module.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. # -*- coding:utf-8 -*-
  2. from __future__ import absolute_import
  3. from __future__ import division
  4. from __future__ import print_function
  5. import os
  6. import re
  7. import sys
  8. sys.path.insert(0, ".")
  9. import time
  10. # from paddlehub.common.logger import logger
  11. # from paddlehub.module.module import moduleinfo, runnable, serving
  12. import cv2
  13. import numpy as np
  14. # import paddlehub as hub
  15. from tools.infer.utility import base64_to_cv2
  16. from tools.infer.predict_system import TextSystem
  17. class OCRSystem:
  18. def __init__(self, use_gpu=False, enable_mkldnn=False):
  19. """
  20. initialize with the necessary elements
  21. """
  22. from deploy.hubserving.ocr_system.params import read_params
  23. cfg = read_params()
  24. cfg.use_gpu = use_gpu
  25. if use_gpu:
  26. try:
  27. _places = os.environ["CUDA_VISIBLE_DEVICES"]
  28. int(_places[0])
  29. print("use gpu: ", use_gpu)
  30. print("CUDA_VISIBLE_DEVICES: ", _places)
  31. cfg.gpu_mem = 8000
  32. except:
  33. raise RuntimeError(
  34. "Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES via export CUDA_VISIBLE_DEVICES=cuda_device_id."
  35. )
  36. cfg.ir_optim = True
  37. cfg.enable_mkldnn = enable_mkldnn
  38. self.text_sys = TextSystem(cfg)
  39. def read_images(self, paths=[]):
  40. images = []
  41. for img_path in paths:
  42. assert os.path.isfile(
  43. img_path), "The {} isn't a valid file.".format(img_path)
  44. img = cv2.imread(img_path)
  45. if img is None:
  46. print("error in loading image:{}".format(img_path))
  47. continue
  48. images.append(img)
  49. return images
  50. def predict(self, images=[], paths=[], **kwargs):
  51. """
  52. Get the chinese texts in the predicted images.
  53. Args:
  54. images (list(numpy.ndarray)): images data, shape of each is [H, W, C]. If images not paths
  55. paths (list[str]): The paths of images. If paths not images
  56. Returns:
  57. res (list): The result of chinese texts and save path of images.
  58. """
  59. if images != [] and isinstance(images, list) and paths == []:
  60. predicted_data = images
  61. elif images == [] and isinstance(paths, list) and paths != []:
  62. predicted_data = self.read_images(paths)
  63. else:
  64. raise TypeError("The input data is inconsistent with expectations.")
  65. assert predicted_data != [], "There is not any image to be predicted. Please check the input data."
  66. all_results = []
  67. for img in predicted_data:
  68. # 初始化关键字
  69. self.inv_no = ""
  70. self.inv_id = ""
  71. self.inv_date = ""
  72. self.inv_money = []
  73. self.inv_payee = ""
  74. self.inv_review = ""
  75. self.inv_drawer = ""
  76. self.inv_company = ["", ""]
  77. self.inv_identifier = ["", ""]
  78. if img is None:
  79. print("error in loading image")
  80. all_results.append([])
  81. continue
  82. # 图片预处理
  83. img = self.resizeImg(img, 640)
  84. starttime = time.time()
  85. dt_boxes, rec_res = self.text_sys(img)
  86. elapse = time.time() - starttime
  87. print("Predict time: {}".format(elapse))
  88. dt_num = len(dt_boxes)
  89. rec_res_final = []
  90. text_list = [] # 结果集
  91. for dno in range(dt_num):
  92. text, score = rec_res[dno]
  93. if score > 0.8:
  94. text_list.append(text)
  95. # print(text)
  96. self.getInformation(text, kwargs['invoice_type'])
  97. inv_text = ''.join(text_list)
  98. # print(inv_text)
  99. self.getInformationAgain(inv_text, kwargs['invoice_type'])
  100. all_results.append({
  101. 'no': self.inv_no,
  102. 'id': self.inv_id,
  103. 'date': self.inv_date,
  104. 'money': (max(self.inv_money) if len(self.inv_money) > 0 else ""),
  105. 'payee': self.inv_payee,
  106. 'review': self.inv_review,
  107. 'drawer': self.inv_drawer,
  108. 'company': self.inv_company,
  109. 'identifier': self.inv_identifier
  110. })
  111. return all_results
  112. def getInformationAgain(self, string, invoice_type):
  113. if self.inv_no == "":
  114. pt = re.compile(r'N[\w|\s]?(\d{8})', re.M)
  115. information_list = pt.findall(string)
  116. self.inv_no = information_list[0] if len(information_list) != 0 else ""
  117. if self.inv_no == "":
  118. pt = re.compile(r'号码:(\d{8})', re.M)
  119. information_list = pt.findall(string)
  120. self.inv_no = information_list[0] if len(information_list) != 0 else ""
  121. if self.inv_id == "":
  122. if invoice_type == 1:
  123. pt = re.compile(r'(\d{12})N', re.M)
  124. else:
  125. pt = re.compile(r'(\d{10})N', re.M)
  126. information_list = pt.findall(string)
  127. self.inv_id = information_list[0] if len(information_list) != 0 else ""
  128. if self.inv_id == "":
  129. if invoice_type == 1:
  130. pt = re.compile(r'代码:(\d{12})', re.M)
  131. else:
  132. pt = re.compile(r'代码:(\d{10})', re.M)
  133. information_list = pt.findall(string)
  134. self.inv_id = information_list[0] if len(information_list) != 0 else ""
  135. if self.inv_company[1] == '':
  136. pt = re.compile(r'称:(.*?)[-*+></\d购]?[纳税]', re.M)
  137. information_list = pt.findall(string)
  138. if len(information_list) != 0:
  139. for i in range(len(self.inv_company)):
  140. if len(information_list) != 0:
  141. self.inv_company[i] = information_list.pop(0)
  142. if self.inv_identifier[1] == '':
  143. pt = re.compile(r'别号:([a-zA-Z\d]{18})', re.M)
  144. information_list = pt.findall(string)
  145. if len(information_list) != 0:
  146. for i in range(len(self.inv_identifier)):
  147. if len(information_list) != 0:
  148. self.inv_identifier[i] = information_list.pop(0)
  149. if self.inv_payee == "":
  150. pt = re.compile(r'款人:(.*)复', re.M)
  151. information_list = pt.findall(string)
  152. self.inv_payee = information_list[0] if len(information_list) != 0 else ""
  153. if self.inv_review == "":
  154. pt = re.compile(r'复核:(.*)开', re.M)
  155. information_list = pt.findall(string)
  156. self.inv_review = information_list[0] if len(information_list) != 0 else ""
  157. if self.inv_drawer == "":
  158. pt = re.compile(r'票人:(.*)', re.M)
  159. information_list = pt.findall(string)
  160. self.inv_drawer = information_list[0] if len(information_list) != 0 else ""
  161. # 信息关键字提取
  162. def getInformation(self, string, invoice_type):
  163. if self.inv_no == "":
  164. pt = re.compile(r'N[\w]?(\d{8})')
  165. information_list = pt.findall(string)
  166. self.inv_no = information_list[0] if len(information_list) != 0 else ""
  167. if self.inv_no != "":
  168. return True
  169. if self.inv_id == "":
  170. if invoice_type == 1:
  171. pt = re.compile(r'^\d{12}$')
  172. else:
  173. pt = re.compile(r'^\d{10}$')
  174. information_list = pt.findall(string)
  175. self.inv_id = information_list[0] if len(information_list) != 0 else ""
  176. if self.inv_id != "":
  177. return True
  178. if self.inv_date == "":
  179. pt = re.compile(r'(\d*年\d*月\d*日)', re.M)
  180. information_list = pt.findall(string)
  181. self.inv_date = information_list[0] if len(information_list) != 0 else ""
  182. if self.inv_date != "":
  183. return True
  184. if self.inv_payee == "":
  185. pt = re.compile(r'款人:(.*)$', re.M)
  186. information_list = pt.findall(string)
  187. self.inv_payee = information_list[0] if len(information_list) != 0 else ""
  188. if self.inv_payee != "":
  189. return True
  190. if self.inv_review == "":
  191. pt = re.compile(r'复核:(.*)$', re.M)
  192. information_list = pt.findall(string)
  193. self.inv_review = information_list[0] if len(information_list) != 0 else ""
  194. if self.inv_review != "":
  195. return True
  196. if self.inv_drawer == "":
  197. pt = re.compile(r'票人:(.*)$', re.M)
  198. information_list = pt.findall(string)
  199. self.inv_drawer = information_list[0] if len(information_list) != 0 else ""
  200. if self.inv_drawer != "":
  201. return True
  202. if self.inv_identifier[1] == '':
  203. pt = re.compile(r'^[纳税人识别号:]?([a-zA-Z\d]{18})$', re.M)
  204. information_list = pt.findall(string)
  205. if len(information_list) != 0:
  206. for i in range(len(self.inv_identifier)):
  207. if self.inv_identifier[i] == '':
  208. self.inv_identifier[i] = information_list[0]
  209. return True
  210. if self.inv_identifier[1] == '':
  211. pt = re.compile(r'称:(.*)$', re.M)
  212. information_list = pt.findall(string)
  213. if len(information_list) != 0:
  214. for i in range(len(self.inv_company)):
  215. if self.inv_company[i] == '':
  216. self.inv_company[i] = information_list[0]
  217. return True
  218. pt = re.compile(r'(\d*\.\d*)')
  219. information_list = pt.findall(string)
  220. if len(information_list) != 0:
  221. self.inv_money.append(float(information_list[0]))
  222. return True
  223. # 固定尺寸
  224. def resizeImg(self, image, height=900):
  225. h, w = image.shape[:2]
  226. pro = height / h
  227. size = (int(w * pro), int(height))
  228. img = cv2.resize(image, size)
  229. return img
  230. if __name__ == '__main__':
  231. ocr = OCRSystem()
  232. image_path = [
  233. './doc/imgs/11.jpg',
  234. './doc/imgs/12.jpg',
  235. ]
  236. res = ocr.predict(paths=image_path)
  237. print(res)