# -*- coding:utf-8 -*- from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import re import sys sys.path.insert(0, ".") import time # from paddlehub.common.logger import logger # from paddlehub.module.module import moduleinfo, runnable, serving import cv2 import numpy as np # import paddlehub as hub from tools.infer.utility import base64_to_cv2 from tools.infer.predict_system import TextSystem class OCRSystem: def __init__(self, use_gpu=False, enable_mkldnn=False): """ initialize with the necessary elements """ from deploy.hubserving.ocr_system.params import read_params cfg = read_params() cfg.use_gpu = use_gpu if use_gpu: try: _places = os.environ["CUDA_VISIBLE_DEVICES"] int(_places[0]) print("use gpu: ", use_gpu) print("CUDA_VISIBLE_DEVICES: ", _places) cfg.gpu_mem = 8000 except: raise RuntimeError( "Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES via export CUDA_VISIBLE_DEVICES=cuda_device_id." ) cfg.ir_optim = True cfg.enable_mkldnn = enable_mkldnn self.text_sys = TextSystem(cfg) def read_images(self, paths=[]): images = [] for img_path in paths: assert os.path.isfile( img_path), "The {} isn't a valid file.".format(img_path) img = cv2.imread(img_path) if img is None: print("error in loading image:{}".format(img_path)) continue images.append(img) return images def predict(self, images=[], paths=[], **kwargs): """ Get the chinese texts in the predicted images. Args: images (list(numpy.ndarray)): images data, shape of each is [H, W, C]. If images not paths paths (list[str]): The paths of images. If paths not images Returns: res (list): The result of chinese texts and save path of images. """ if images != [] and isinstance(images, list) and paths == []: predicted_data = images elif images == [] and isinstance(paths, list) and paths != []: predicted_data = self.read_images(paths) else: raise TypeError("The input data is inconsistent with expectations.") assert predicted_data != [], "There is not any image to be predicted. Please check the input data." all_results = [] for img in predicted_data: # 初始化关键字 self.inv_no = "" self.inv_id = "" self.inv_date = "" self.inv_money = [] self.inv_payee = "" self.inv_review = "" self.inv_drawer = "" self.inv_company = ["", ""] self.inv_identifier = ["", ""] if img is None: print("error in loading image") all_results.append([]) continue # 图片预处理 img = self.resizeImg(img, 640) starttime = time.time() dt_boxes, rec_res = self.text_sys(img) elapse = time.time() - starttime print("Predict time: {}".format(elapse)) dt_num = len(dt_boxes) rec_res_final = [] text_list = [] # 结果集 for dno in range(dt_num): text, score = rec_res[dno] if score > 0.8: text_list.append(text) # print(text) self.getInformation(text, kwargs['invoice_type']) inv_text = ''.join(text_list) # print(inv_text) self.getInformationAgain(inv_text, kwargs['invoice_type']) all_results.append({ 'no': self.inv_no, 'id': self.inv_id, 'date': self.inv_date, 'money': (max(self.inv_money) if len(self.inv_money) > 0 else ""), 'payee': self.inv_payee, 'review': self.inv_review, 'drawer': self.inv_drawer, 'company': self.inv_company, 'identifier': self.inv_identifier }) return all_results def getInformationAgain(self, string, invoice_type): if self.inv_no == "": pt = re.compile(r'N[\w|\s]?(\d{8})', re.M) information_list = pt.findall(string) self.inv_no = information_list[0] if len(information_list) != 0 else "" if self.inv_no == "": pt = re.compile(r'号码:(\d{8})', re.M) information_list = pt.findall(string) self.inv_no = information_list[0] if len(information_list) != 0 else "" if self.inv_id == "": if invoice_type == 1: pt = re.compile(r'(\d{12})N', re.M) else: pt = re.compile(r'(\d{10})N', re.M) information_list = pt.findall(string) self.inv_id = information_list[0] if len(information_list) != 0 else "" if self.inv_id == "": if invoice_type == 1: pt = re.compile(r'代码:(\d{12})', re.M) else: pt = re.compile(r'代码:(\d{10})', re.M) information_list = pt.findall(string) self.inv_id = information_list[0] if len(information_list) != 0 else "" if self.inv_company[1] == '': pt = re.compile(r'称:(.*?)[-*+></\d购]?[纳税]', re.M) information_list = pt.findall(string) if len(information_list) != 0: for i in range(len(self.inv_company)): if len(information_list) != 0: self.inv_company[i] = information_list.pop(0) if self.inv_identifier[1] == '': pt = re.compile(r'别号:([a-zA-Z\d]{18})', re.M) information_list = pt.findall(string) if len(information_list) != 0: for i in range(len(self.inv_identifier)): if len(information_list) != 0: self.inv_identifier[i] = information_list.pop(0) if self.inv_payee == "": pt = re.compile(r'款人:(.*)复', re.M) information_list = pt.findall(string) self.inv_payee = information_list[0] if len(information_list) != 0 else "" if self.inv_review == "": pt = re.compile(r'复核:(.*)开', re.M) information_list = pt.findall(string) self.inv_review = information_list[0] if len(information_list) != 0 else "" if self.inv_drawer == "": pt = re.compile(r'票人:(.*)', re.M) information_list = pt.findall(string) self.inv_drawer = information_list[0] if len(information_list) != 0 else "" # 信息关键字提取 def getInformation(self, string, invoice_type): if self.inv_no == "": pt = re.compile(r'N[\w]?(\d{8})') information_list = pt.findall(string) self.inv_no = information_list[0] if len(information_list) != 0 else "" if self.inv_no != "": return True if self.inv_id == "": if invoice_type == 1: pt = re.compile(r'^\d{12}$') else: pt = re.compile(r'^\d{10}$') information_list = pt.findall(string) self.inv_id = information_list[0] if len(information_list) != 0 else "" if self.inv_id != "": return True if self.inv_date == "": pt = re.compile(r'(\d*年\d*月\d*日)', re.M) information_list = pt.findall(string) self.inv_date = information_list[0] if len(information_list) != 0 else "" if self.inv_date != "": return True if self.inv_payee == "": pt = re.compile(r'款人:(.*)$', re.M) information_list = pt.findall(string) self.inv_payee = information_list[0] if len(information_list) != 0 else "" if self.inv_payee != "": return True if self.inv_review == "": pt = re.compile(r'复核:(.*)$', re.M) information_list = pt.findall(string) self.inv_review = information_list[0] if len(information_list) != 0 else "" if self.inv_review != "": return True if self.inv_drawer == "": pt = re.compile(r'票人:(.*)$', re.M) information_list = pt.findall(string) self.inv_drawer = information_list[0] if len(information_list) != 0 else "" if self.inv_drawer != "": return True if self.inv_identifier[1] == '': pt = re.compile(r'^[纳税人识别号:]?([a-zA-Z\d]{18})$', re.M) information_list = pt.findall(string) if len(information_list) != 0: for i in range(len(self.inv_identifier)): if self.inv_identifier[i] == '': self.inv_identifier[i] = information_list[0] return True if self.inv_identifier[1] == '': pt = re.compile(r'称:(.*)$', re.M) information_list = pt.findall(string) if len(information_list) != 0: for i in range(len(self.inv_company)): if self.inv_company[i] == '': self.inv_company[i] = information_list[0] return True pt = re.compile(r'(\d*\.\d*)') information_list = pt.findall(string) if len(information_list) != 0: self.inv_money.append(float(information_list[0])) return True # 固定尺寸 def resizeImg(self, image, height=900): h, w = image.shape[:2] pro = height / h size = (int(w * pro), int(height)) img = cv2.resize(image, size) return img if __name__ == '__main__': ocr = OCRSystem() image_path = [ './doc/imgs/11.jpg', './doc/imgs/12.jpg', ] res = ocr.predict(paths=image_path) print(res)