# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import re
import sys
sys.path.insert(0, ".")

import time

# from paddlehub.common.logger import logger
# from paddlehub.module.module import moduleinfo, runnable, serving
import cv2
import numpy as np
# import paddlehub as hub

from tools.infer.utility import base64_to_cv2
from tools.infer.predict_system import TextSystem


class OCRSystem:
    """Invoice field extractor built on top of a ``TextSystem`` OCR pipeline.

    ``predict`` runs OCR on each image and fills a set of per-image fields
    (``inv_no``, ``inv_id``, ``inv_date``, ``inv_money``, ``inv_payee``,
    ``inv_review``, ``inv_drawer``, ``inv_company``, ``inv_identifier``) by
    matching regular expressions against the recognised text.

    The working fields live on the instance and are reset for every image,
    so one instance must not run ``predict`` concurrently from several threads.
    """

    def __init__(self, use_gpu=False, enable_mkldnn=False):
        """
        Build the OCR pipeline from the hub-serving default parameters.

        Args:
            use_gpu (bool): run inference on GPU. Requires the environment
                variable CUDA_VISIBLE_DEVICES to start with a digit.
            enable_mkldnn (bool): forwarded to the config as ``enable_mkldnn``.
        Raises:
            RuntimeError: ``use_gpu`` is true but CUDA_VISIBLE_DEVICES is
                missing, empty, or does not start with a digit.
        """
        from deploy.hubserving.ocr_system.params import read_params
        cfg = read_params()

        cfg.use_gpu = use_gpu
        if use_gpu:
            try:
                _places = os.environ["CUDA_VISIBLE_DEVICES"]
                int(_places[0])
            except (KeyError, IndexError, ValueError):
                # KeyError: variable unset; IndexError: empty string;
                # ValueError: first character is not a digit.
                raise RuntimeError(
                    "Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES via export CUDA_VISIBLE_DEVICES=cuda_device_id."
                )
            print("use gpu: ", use_gpu)
            print("CUDA_VISIBLE_DEVICES: ", _places)
            cfg.gpu_mem = 8000
        cfg.ir_optim = True
        cfg.enable_mkldnn = enable_mkldnn

        self.text_sys = TextSystem(cfg)

    @staticmethod
    def _first_match(pattern, string, flags=0):
        """Return the first ``re.findall`` result of *pattern* in *string*, or ""."""
        found = re.findall(pattern, string, flags)
        return found[0] if found else ""

    def _reset_fields(self):
        """Clear every per-image extraction field before a new image is parsed."""
        self.inv_no = ""
        self.inv_id = ""
        self.inv_date = ""
        self.inv_money = []
        self.inv_payee = ""
        self.inv_review = ""
        self.inv_drawer = ""
        self.inv_company = ["", ""]
        self.inv_identifier = ["", ""]

    def read_images(self, paths=None):
        """
        Load images from disk with ``cv2.imread``.

        Args:
            paths (list[str]): image file paths. ``None`` is treated as empty.
        Returns:
            list: the decoded images; files that fail to decode are reported
            on stdout and skipped.
        Raises:
            AssertionError: a path does not point to an existing file.
        """
        images = []
        for img_path in (paths or []):
            assert os.path.isfile(
                img_path), "The {} isn't a valid file.".format(img_path)
            img = cv2.imread(img_path)
            if img is None:
                print("error in loading image:{}".format(img_path))
                continue
            images.append(img)
        return images

    def predict(self, images=None, paths=None, **kwargs):
        """
        Run OCR on the given images and extract the invoice fields.

        Exactly one of ``images`` / ``paths`` must be a non-empty list.

        Args:
            images (list(numpy.ndarray)): images data, shape of each is [H, W, C].
            paths (list[str]): The paths of images.
            **kwargs: ``invoice_type`` selects the invoice-code length used by
                the extraction regexes: 1 means 12 digits, any other value
                (including omitted) means 10 digits.
        Returns:
            list: one entry per input image. Each entry is a dict with keys
            ``no``, ``id``, ``date``, ``money``, ``payee``, ``review``,
            ``drawer``, ``company``, ``identifier``; an image that is ``None``
            yields ``[]`` instead.
        Raises:
            TypeError: the images/paths combination is not as described above.
            AssertionError: no image could be loaded.
        """
        # None replaces the former mutable `[]` defaults.
        images = [] if images is None else images
        paths = [] if paths is None else paths

        if images != [] and isinstance(images, list) and paths == []:
            predicted_data = images
        elif images == [] and isinstance(paths, list) and paths != []:
            predicted_data = self.read_images(paths)
        else:
            raise TypeError("The input data is inconsistent with expectations.")

        assert predicted_data != [], "There is not any image to be predicted. Please check the input data."

        # A missing invoice_type used to raise KeyError part-way through the
        # first image; None now falls through to the 10-digit branch.
        invoice_type = kwargs.get('invoice_type')

        all_results = []
        for img in predicted_data:
            # Initialise the keyword fields for this image.
            self._reset_fields()

            if img is None:
                print("error in loading image")
                all_results.append([])
                continue

            # Image preprocessing: scale to a fixed height of 640 px.
            img = self.resizeImg(img, 640)

            starttime = time.time()
            dt_boxes, rec_res = self.text_sys(img)
            elapse = time.time() - starttime
            print("Predict time: {}".format(elapse))

            # NOTE(review): the upstream TextSystem appears to return
            # (None, None) when nothing is detected - confirm. Guarded so such
            # an image yields empty fields instead of a TypeError on len(None).
            if dt_boxes is None or rec_res is None:
                dt_num = 0
            else:
                dt_num = len(dt_boxes)

            text_list = []  # confidently recognised lines, in detection order
            for dno in range(dt_num):
                text, score = rec_res[dno]
                if score > 0.8:
                    text_list.append(text)
                    self.getInformation(text, invoice_type)

            # Second pass over the concatenated text picks up fields whose
            # label and value were recognised as separate lines.
            inv_text = ''.join(text_list)
            self.getInformationAgain(inv_text, invoice_type)
            all_results.append({
                'no': self.inv_no,
                'id': self.inv_id,
                'date': self.inv_date,
                'money': (max(self.inv_money) if len(self.inv_money) > 0 else ""),
                'payee': self.inv_payee,
                'review': self.inv_review,
                'drawer': self.inv_drawer,
                'company': self.inv_company,
                'identifier': self.inv_identifier
            })
        return all_results

    def getInformationAgain(self, string, invoice_type):
        """
        Fill still-empty fields from the concatenation of all recognised lines.

        Only fields that are still empty are searched, except company and
        identifier: when their second slot is empty, both slots are
        overwritten with the first matches found in ``string``.
        ``inv_date`` and ``inv_money`` are not handled here.

        Args:
            string (str): all accepted OCR lines joined without separator.
            invoice_type: 1 selects a 12-digit invoice code, anything else
                a 10-digit code.
        """
        if self.inv_no == "":
            self.inv_no = self._first_match(r'N[\w|\s]?(\d{8})', string, re.M)
        if self.inv_no == "":
            self.inv_no = self._first_match(r'号码:(\d{8})', string, re.M)

        if self.inv_id == "":
            if invoice_type == 1:
                pattern = r'(\d{12})N'
            else:
                pattern = r'(\d{10})N'
            self.inv_id = self._first_match(pattern, string, re.M)
        if self.inv_id == "":
            if invoice_type == 1:
                pattern = r'代码:(\d{12})'
            else:
                pattern = r'代码:(\d{10})'
            self.inv_id = self._first_match(pattern, string, re.M)

        if self.inv_company[1] == '':
            names = re.findall(r'称:(.*?)[-*+></\d购]?[纳税]', string, re.M)
            # Assign matches to the slots in order; surplus matches are dropped.
            for slot, name in enumerate(names[:len(self.inv_company)]):
                self.inv_company[slot] = name

        if self.inv_identifier[1] == '':
            codes = re.findall(r'别号:([a-zA-Z\d]{18})', string, re.M)
            for slot, code in enumerate(codes[:len(self.inv_identifier)]):
                self.inv_identifier[slot] = code

        if self.inv_payee == "":
            self.inv_payee = self._first_match(r'款人:(.*)复', string, re.M)

        if self.inv_review == "":
            self.inv_review = self._first_match(r'复核:(.*)开', string, re.M)

        if self.inv_drawer == "":
            self.inv_drawer = self._first_match(r'票人:(.*)', string, re.M)

    # Extract information keywords from a single recognised line.
    def getInformation(self, string, invoice_type):
        """
        Try to assign one recognised text line to a still-empty field.

        Fields are tried in a fixed order and the method returns as soon as
        one is filled. A line that fills nothing is finally scanned for a
        decimal number, which is appended to ``inv_money``.

        Args:
            string (str): one recognised text line.
            invoice_type: 1 selects a 12-digit invoice code, anything else
                a 10-digit code.
        Returns:
            bool: always True.
        """
        if self.inv_no == "":
            self.inv_no = self._first_match(r'N[\w]?(\d{8})', string)
            if self.inv_no != "":
                return True

        if self.inv_id == "":
            if invoice_type == 1:
                pattern = r'^\d{12}$'
            else:
                pattern = r'^\d{10}$'
            self.inv_id = self._first_match(pattern, string)
            if self.inv_id != "":
                return True

        if self.inv_date == "":
            self.inv_date = self._first_match(r'(\d*年\d*月\d*日)', string, re.M)
            if self.inv_date != "":
                return True

        if self.inv_payee == "":
            self.inv_payee = self._first_match(r'款人:(.*)$', string, re.M)
            if self.inv_payee != "":
                return True

        if self.inv_review == "":
            self.inv_review = self._first_match(r'复核:(.*)$', string, re.M)
            if self.inv_review != "":
                return True

        if self.inv_drawer == "":
            self.inv_drawer = self._first_match(r'票人:(.*)$', string, re.M)
            if self.inv_drawer != "":
                return True

        if self.inv_identifier[1] == '':
            # NOTE(review): `[...]?` is a character class, so it accepts at
            # most ONE leading label character, not the whole label - confirm
            # whether a full optional prefix was intended.
            codes = re.findall(r'^[纳税人识别号:]?([a-zA-Z\d]{18})$', string, re.M)
            if codes:
                for slot in range(len(self.inv_identifier)):
                    if self.inv_identifier[slot] == '':
                        self.inv_identifier[slot] = codes[0]
                        return True

        # This guard previously tested inv_identifier[1] (copy-paste slip),
        # which disabled company extraction once both identifiers were known.
        if self.inv_company[1] == '':
            names = re.findall(r'称:(.*)$', string, re.M)
            if names:
                for slot in range(len(self.inv_company)):
                    if self.inv_company[slot] == '':
                        self.inv_company[slot] = names[0]
                        return True

        amounts = re.findall(r'(\d*\.\d*)', string)
        if amounts:
            # The pattern also matches a lone "." which float() rejects;
            # skip such a candidate instead of aborting the whole prediction.
            try:
                self.inv_money.append(float(amounts[0]))
            except ValueError:
                pass
        return True

    # Resize to a fixed height.
    def resizeImg(self, image, height=900):
        """
        Scale ``image`` to the given height, keeping the aspect ratio.

        Args:
            image (numpy.ndarray): image whose first two dimensions are
                height and width.
            height (int): target height in pixels.
        Returns:
            numpy.ndarray: the result of ``cv2.resize``.
        """
        h, w = image.shape[:2]
        pro = height / h
        size = (int(w * pro), int(height))  # cv2.resize expects (width, height)
        img = cv2.resize(image, size)
        return img


if __name__ == '__main__':
    # Demo: run the extractor on two sample images and print the raw results.
    ocr = OCRSystem()
    image_path = [
        './doc/imgs/11.jpg',
        './doc/imgs/12.jpg',
    ]
    # invoice_type selects the invoice-code length used by the extraction
    # regexes (1 -> 12 digits, any other value -> 10 digits). It is passed
    # explicitly so the demo does not depend on predict() tolerating its
    # absence. NOTE(review): adjust the value to match the sample invoices.
    res = ocr.predict(paths=image_path, invoice_type=0)
    print(res)