rec_postprocess.py

# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import string

import paddle
from paddle.nn import functional as F


class BaseRecLabelDecode(object):
    """ Convert between text-label and text-index """

    def __init__(self,
                 character_dict_path=None,
                 character_type='ch',
                 use_space_char=False):
        support_character_type = [
            'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
            'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
            'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
            'ne', 'EN'
        ]
        assert character_type in support_character_type, "Only {} are supported now but got {}".format(
            support_character_type, character_type)

        self.beg_str = "sos"
        self.end_str = "eos"
        if character_type == "en":
            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
            dict_character = list(self.character_str)
        elif character_type == "EN_symbol":
            # same with ASTER setting (use 94 char).
            self.character_str = string.printable[:-6]
            dict_character = list(self.character_str)
        elif character_type in support_character_type:
            self.character_str = ""
            assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
                character_type)
            with open(character_dict_path, "rb") as fin:
                lines = fin.readlines()
                for line in lines:
                    line = line.decode('utf-8').strip("\n").strip("\r\n")
                    self.character_str += line
            if use_space_char:
                self.character_str += " "
            dict_character = list(self.character_str)
        else:
            raise NotImplementedError

        self.character_type = character_type
        dict_character = self.add_special_char(dict_character)
        self.dict = {}
        for i, char in enumerate(dict_character):
            self.dict[char] = i
        self.character = dict_character
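
    # Illustrative note: with character_type='en' and the base class's no-op
    # add_special_char(), self.dict maps '0' -> 0, '1' -> 1, ..., 'z' -> 35,
    # and self.character is the inverse lookup list used by decode().
    # Subclasses shift these indices when add_special_char() prepends or
    # appends special tokens (e.g. CTC's 'blank').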
    def add_special_char(self, dict_character):
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """ convert text-index into text-label. """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] in ignored_tokens:
                    continue
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
                            batch_idx][idx]:
                        continue
                char_list.append(self.character[int(text_index[batch_idx][
                    idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list)))
        return result_list

    def get_ignored_tokens(self):
        return [0]  # for ctc blank


class CTCLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self,
                 character_dict_path=None,
                 character_type='ch',
                 use_space_char=False,
                 **kwargs):
        super(CTCLabelDecode, self).__init__(character_dict_path,
                                             character_type, use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
        if label is None:
            return text
        label = self.decode(label)
        return text, label

    def add_special_char(self, dict_character):
        dict_character = ['blank'] + dict_character
        return dict_character
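
# Illustrative sketch: greedy CTC decoding as performed by CTCLabelDecode with
# character_type='en'. The per-timestep argmax indices are collapsed by
# skipping the blank (index 0) and consecutive repeats, e.g.
#   [1, 1, 0, 2, 2, 0, 0, 3]  ->  kept indices [1, 2, 3]  ->  text "012"
# under the shifted 'en' dictionary ('blank' -> 0, '0' -> 1, '1' -> 2, ...).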


class AttnLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self,
                 character_dict_path=None,
                 character_type='ch',
                 use_space_char=False,
                 **kwargs):
        super(AttnLabelDecode, self).__init__(character_dict_path,
                                              character_type, use_space_char)

    def add_special_char(self, dict_character):
        self.beg_str = "sos"
        self.end_str = "eos"
        dict_character = [self.beg_str] + dict_character + [self.end_str]
        return dict_character
    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """ convert text-index into text-label. """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        [beg_idx, end_idx] = ignored_tokens
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] in ignored_tokens:
                    continue
                if int(text_index[batch_idx][idx]) == int(end_idx):
                    break
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
                            batch_idx][idx]:
                        continue
                char_list.append(self.character[int(text_index[batch_idx][
                    idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list)))
        return result_list
    def __call__(self, preds, label=None, *args, **kwargs):
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label

    def get_ignored_tokens(self):
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        if beg_or_end == "beg":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "end":
            idx = np.array(self.dict[self.end_str])
        else:
            assert False, "unsupported type %s in get_beg_end_flag_idx" \
                % beg_or_end
        return idx
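
# Illustrative note: unlike CTC, attention decoding has no blank token.
# AttnLabelDecode strips the "sos"/"eos" markers via get_ignored_tokens() and
# stops at the first "eos" index, so duplicate removal stays off
# (is_remove_duplicate=False) in __call__.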


class SRNLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self,
                 character_dict_path=None,
                 character_type='en',
                 use_space_char=False,
                 **kwargs):
        super(SRNLabelDecode, self).__init__(character_dict_path,
                                             character_type, use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        pred = preds['predict']
        char_num = len(self.character_str) + 2
        if isinstance(pred, paddle.Tensor):
            pred = pred.numpy()
        pred = np.reshape(pred, [-1, char_num])

        preds_idx = np.argmax(pred, axis=1)
        preds_prob = np.max(pred, axis=1)

        # SRN predicts a fixed-length sequence; reshape to (batch, 25) steps.
        preds_idx = np.reshape(preds_idx, [-1, 25])
        preds_prob = np.reshape(preds_prob, [-1, 25])

        text = self.decode(preds_idx, preds_prob)
        if label is None:
            return text
        label = self.decode(label)
        return text, label
    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """ convert text-index into text-label. """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] in ignored_tokens:
                    continue
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
                            batch_idx][idx]:
                        continue
                char_list.append(self.character[int(text_index[batch_idx][
                    idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list)))
        return result_list
    def add_special_char(self, dict_character):
        dict_character = dict_character + [self.beg_str, self.end_str]
        return dict_character

    def get_ignored_tokens(self):
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        if beg_or_end == "beg":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "end":
            idx = np.array(self.dict[self.end_str])
        else:
            assert False, "unsupported type %s in get_beg_end_flag_idx" \
                % beg_or_end
        return idx
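

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative; assumes the built-in 'en'
    # dictionary so no character_dict_path is needed): run CTCLabelDecode on
    # random logits to show the expected input shape and output format.
    np.random.seed(0)
    decoder = CTCLabelDecode(character_type="en", use_space_char=False)
    # Fake network output: (batch=2, time_steps=10, num_classes=len(character)).
    fake_preds = np.random.rand(2, 10, len(decoder.character)).astype("float32")
    print(decoder(fake_preds))  # -> [(text, mean_confidence), ...]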