123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277 |
- # -*- coding:utf-8 -*-
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import os
- import re
- import sys
- sys.path.insert(0, ".")
- import time
- # from paddlehub.common.logger import logger
- # from paddlehub.module.module import moduleinfo, runnable, serving
- import cv2
- import numpy as np
- # import paddlehub as hub
- from tools.infer.utility import base64_to_cv2
- from tools.infer.predict_system import TextSystem
class OCRSystem:
    """Extract invoice fields from images using ``TextSystem``.

    For each image the text system returns detected boxes and
    ``(text, score)`` recognition results; regex heuristics then pull
    invoice fields out of the recognized text: number (号码), code (代码),
    date, amount, payee (收款人), reviewer (复核), drawer (开票人), two
    company names (名称) and two taxpayer identifiers (纳税人识别号).

    NOTE(review): per-image extraction state lives on ``self`` (the
    ``inv_*`` attributes), so one instance should not run ``predict``
    concurrently.
    """

    # Recognized lines with a confidence at or below this are ignored.
    _MIN_SCORE = 0.8

    def __init__(self, use_gpu=False, enable_mkldnn=False):
        """Build the underlying TextSystem.

        Args:
            use_gpu (bool): run inference on GPU. Requires the
                CUDA_VISIBLE_DEVICES environment variable to start with a
                device id digit.
            enable_mkldnn (bool): value stored as ``cfg.enable_mkldnn``.

        Raises:
            RuntimeError: if ``use_gpu`` is true and CUDA_VISIBLE_DEVICES is
                unset, empty, or does not start with a digit.
        """
        # Deferred import: only needed when an instance is constructed.
        from deploy.hubserving.ocr_system.params import read_params
        cfg = read_params()
        cfg.use_gpu = use_gpu
        if use_gpu:
            try:
                _places = os.environ["CUDA_VISIBLE_DEVICES"]
                int(_places[0])
            except (KeyError, IndexError, ValueError):
                # KeyError: variable unset; IndexError: empty value;
                # ValueError: first character is not a device id digit.
                raise RuntimeError(
                    "Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES via export CUDA_VISIBLE_DEVICES=cuda_device_id."
                )
            print("use gpu: ", use_gpu)
            print("CUDA_VISIBLE_DEVICES: ", _places)
            cfg.gpu_mem = 8000
        cfg.ir_optim = True
        cfg.enable_mkldnn = enable_mkldnn
        self.text_sys = TextSystem(cfg)

    def read_images(self, paths=None):
        """Load images from disk with OpenCV.

        Args:
            paths (list[str] | None): image file paths.

        Returns:
            list[numpy.ndarray]: the decoded images; files that
            ``cv2.imread`` cannot decode are reported and skipped.

        Raises:
            AssertionError: if a path is not an existing file.
        """
        paths = [] if paths is None else paths
        images = []
        for img_path in paths:
            assert os.path.isfile(
                img_path), "The {} isn't a valid file.".format(img_path)
            img = cv2.imread(img_path)
            if img is None:
                print("error in loading image:{}".format(img_path))
                continue
            images.append(img)
        return images

    def predict(self, images=None, paths=None, **kwargs):
        """Extract invoice fields from images.

        Exactly one of ``images`` / ``paths`` must be a non-empty list.

        Args:
            images (list(numpy.ndarray) | None): image data, each [H, W, C].
            paths (list[str] | None): image file paths to load instead.
            **kwargs: ``invoice_type`` selects the invoice-code length the
                regexes look for: 1 -> 12 digits, any other value
                (default ``None``) -> 10 digits.

        Returns:
            list: one entry per input image - a dict with keys ``no``,
            ``id``, ``date``, ``money``, ``payee``, ``review``, ``drawer``,
            ``company`` (2-item list) and ``identifier`` (2-item list), or
            ``[]`` for an input image that is ``None``.

        Raises:
            TypeError: if the inputs match neither accepted form.
            AssertionError: if there is no image to predict.
        """
        images = [] if images is None else images
        paths = [] if paths is None else paths
        # isinstance is checked first so array inputs are rejected instead
        # of being compared element-wise against [].
        if isinstance(images, list) and images and paths == []:
            predicted_data = images
        elif (isinstance(paths, list) and paths
              and isinstance(images, list) and not images):
            predicted_data = self.read_images(paths)
        else:
            raise TypeError("The input data is inconsistent with expectations.")
        assert predicted_data != [], "There is not any image to be predicted. Please check the input data."

        # .get: a missing invoice_type used to raise KeyError.
        invoice_type = kwargs.get('invoice_type')
        all_results = []
        for img in predicted_data:
            self._reset_fields()
            if img is None:
                print("error in loading image")
                all_results.append([])
                continue
            # Image preprocessing: normalize the height before OCR.
            img = self.resizeImg(img, 640)
            starttime = time.time()
            dt_boxes, rec_res = self.text_sys(img)
            elapse = time.time() - starttime
            print("Predict time: {}".format(elapse))

            # Only confident lines feed both extraction passes.
            text_list = []
            for text, score in rec_res[:len(dt_boxes)]:
                if score > self._MIN_SCORE:
                    text_list.append(text)
                    self.getInformation(text, invoice_type)
            self.getInformationAgain(''.join(text_list), invoice_type)
            all_results.append(self._collect_fields())
        return all_results

    def _reset_fields(self):
        """Clear the per-image extraction state."""
        self.inv_no = ""
        self.inv_id = ""
        self.inv_date = ""
        self.inv_money = []
        self.inv_payee = ""
        self.inv_review = ""
        self.inv_drawer = ""
        self.inv_company = ["", ""]
        self.inv_identifier = ["", ""]

    def _collect_fields(self):
        """Return the fields extracted for the current image as a dict."""
        return {
            'no': self.inv_no,
            'id': self.inv_id,
            'date': self.inv_date,
            # The largest decimal number seen is taken as the amount.
            'money': max(self.inv_money) if self.inv_money else "",
            'payee': self.inv_payee,
            'review': self.inv_review,
            'drawer': self.inv_drawer,
            'company': self.inv_company,
            'identifier': self.inv_identifier,
        }

    @staticmethod
    def _first_match(pattern, string, flags=0):
        """Return the first ``re.findall`` result of *pattern*, or ""."""
        found = re.findall(pattern, string, flags)
        return found[0] if found else ""

    @staticmethod
    def _fill_first_empty(slots, value):
        """Store *value* in the first empty ('') entry of *slots* only."""
        for i, current in enumerate(slots):
            if current == '':
                slots[i] = value
                break

    def getInformationAgain(self, string, invoice_type):
        """Second pass: fill still-missing fields from the joined text.

        Args:
            string (str): all confident recognized lines concatenated.
            invoice_type: 1 selects 12-digit invoice codes, otherwise 10.
        """
        if self.inv_no == "":
            # NOTE: '|' inside the character class is a literal pipe.
            self.inv_no = self._first_match(r'N[\w|\s]?(\d{8})', string, re.M)
        if self.inv_no == "":
            self.inv_no = self._first_match(r'号码:(\d{8})', string, re.M)
        if self.inv_id == "":
            pattern = r'(\d{12})N' if invoice_type == 1 else r'(\d{10})N'
            self.inv_id = self._first_match(pattern, string, re.M)
        if self.inv_id == "":
            pattern = r'代码:(\d{12})' if invoice_type == 1 else r'代码:(\d{10})'
            self.inv_id = self._first_match(pattern, string, re.M)
        if self.inv_company[1] == '':
            names = re.findall(r'称:(.*?)[-*+></\d购]?[纳税]', string, re.M)
            # Matches from the full text overwrite the slots in order.
            for i, name in enumerate(names[:len(self.inv_company)]):
                self.inv_company[i] = name
        if self.inv_identifier[1] == '':
            idents = re.findall(r'别号:([a-zA-Z\d]{18})', string, re.M)
            for i, ident in enumerate(idents[:len(self.inv_identifier)]):
                self.inv_identifier[i] = ident
        # Non-greedy: stop at the next label rather than its last occurrence.
        if self.inv_payee == "":
            self.inv_payee = self._first_match(r'款人:(.*?)复', string, re.M)
        if self.inv_review == "":
            self.inv_review = self._first_match(r'复核:(.*?)开', string, re.M)
        if self.inv_drawer == "":
            self.inv_drawer = self._first_match(r'票人:(.*)', string, re.M)

    def getInformation(self, string, invoice_type):
        """First pass: try to capture one field from a single recognized line.

        Fields are tried in a fixed order and the method stops at the first
        one it fills. A line matching no labelled field but containing a
        decimal number is recorded as an amount candidate.

        Args:
            string (str): one recognized text line.
            invoice_type: 1 selects 12-digit invoice codes, otherwise 10.

        Returns:
            bool: True if something was captured from *string*.
        """
        if self.inv_no == "":
            self.inv_no = self._first_match(r'N[\w]?(\d{8})', string)
            if self.inv_no != "":
                return True
        if self.inv_id == "":
            pattern = r'^\d{12}$' if invoice_type == 1 else r'^\d{10}$'
            self.inv_id = self._first_match(pattern, string)
            if self.inv_id != "":
                return True
        if self.inv_date == "":
            self.inv_date = self._first_match(r'(\d*年\d*月\d*日)', string, re.M)
            if self.inv_date != "":
                return True
        if self.inv_payee == "":
            self.inv_payee = self._first_match(r'款人:(.*)$', string, re.M)
            if self.inv_payee != "":
                return True
        if self.inv_review == "":
            self.inv_review = self._first_match(r'复核:(.*)$', string, re.M)
            if self.inv_review != "":
                return True
        if self.inv_drawer == "":
            self.inv_drawer = self._first_match(r'票人:(.*)$', string, re.M)
            if self.inv_drawer != "":
                return True
        if self.inv_identifier[1] == '':
            # '*' (not '?') so the full "纳税人识别号:" label, or any OCR'd
            # fragment of it, may precede the 18-character identifier.
            ident = self._first_match(
                r'^[纳税人识别号:]*([a-zA-Z\d]{18})$', string, re.M)
            if ident:
                self._fill_first_empty(self.inv_identifier, ident)
                return True
        if self.inv_company[1] == '':
            company = self._first_match(r'称:(.*)$', string, re.M)
            if company:
                self._fill_first_empty(self.inv_company, company)
                return True
        # Require at least one digit so a bare "." never reaches float().
        amount = self._first_match(r'(\d+\.\d*|\.\d+)', string)
        if amount:
            self.inv_money.append(float(amount))
            return True
        return False

    def resizeImg(self, image, height=900):
        """Resize *image* to *height* pixels tall, keeping its aspect ratio."""
        h, w = image.shape[:2]
        scale = height / h
        size = (int(w * scale), int(height))  # cv2.resize takes (width, height)
        return cv2.resize(image, size)
if __name__ == '__main__':
    # Demo: run invoice extraction on two sample images.
    ocr = OCRSystem()
    image_path = [
        './doc/imgs/11.jpg',
        './doc/imgs/12.jpg',
    ]
    # predict() looks up the invoice_type keyword: 1 selects the 12-digit
    # invoice-code patterns, any other value the 10-digit ones.
    res = ocr.predict(paths=image_path, invoice_type=0)
    print(res)
|