# rec_srn_head.py

# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
from paddle import nn, ParamAttr
from paddle.nn import functional as F
import numpy as np
from collections import OrderedDict

from .self_attention import WrapEncoderForFeature
from .self_attention import WrapEncoder

gradient_clip = 10
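
# SRN head: PVAM (parallel visual attention module), GSRM (global semantic
# reasoning module) and VSFD (visual-semantic fusion decoder); see "Towards
# Accurate Scene Text Recognition with Semantic Reasoning Networks"
# (Yu et al., CVPR 2020).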


class PVAM(nn.Layer):
    def __init__(self, in_channels, char_num, max_text_length, num_heads,
                 num_encoder_tus, hidden_dims):
        super(PVAM, self).__init__()
        self.char_num = char_num
        self.max_length = max_text_length
        self.num_heads = num_heads
        self.num_encoder_TUs = num_encoder_tus
        self.hidden_dims = hidden_dims
        # Transformer encoder
        t = 256
        c = 512
        self.wrap_encoder_for_feature = WrapEncoderForFeature(
            src_vocab_size=1,
            max_length=t,
            n_layer=self.num_encoder_TUs,
            n_head=self.num_heads,
            d_key=int(self.hidden_dims / self.num_heads),
            d_value=int(self.hidden_dims / self.num_heads),
            d_model=self.hidden_dims,
            d_inner_hid=self.hidden_dims,
            prepostprocess_dropout=0.1,
            attention_dropout=0.1,
            relu_dropout=0.1,
            preprocess_cmd="n",
            postprocess_cmd="da",
            weight_sharing=True)

        # PVAM
        self.flatten0 = paddle.nn.Flatten(start_axis=0, stop_axis=1)
        self.fc0 = paddle.nn.Linear(
            in_features=in_channels,
            out_features=in_channels, )
        self.emb = paddle.nn.Embedding(
            num_embeddings=self.max_length, embedding_dim=in_channels)
        self.flatten1 = paddle.nn.Flatten(start_axis=0, stop_axis=2)
        self.fc1 = paddle.nn.Linear(
            in_features=in_channels, out_features=1, bias_attr=False)
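
    # Shape flow of forward (reference note): the backbone feature map
    # [b, c, h, w] is flattened to [b, h * w, c], encoded into word_features
    # [b, t, c] with t == h * w, and attended by max_length reading-order
    # queries, giving pvam_features of shape [b, max_length, c].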
    def forward(self, inputs, encoder_word_pos, gsrm_word_pos):
        b, c, h, w = inputs.shape
        conv_features = paddle.reshape(inputs, shape=[-1, c, h * w])
        conv_features = paddle.transpose(conv_features, perm=[0, 2, 1])

        # transformer encoder
        b, t, c = conv_features.shape
        enc_inputs = [conv_features, encoder_word_pos, None]
        word_features = self.wrap_encoder_for_feature(enc_inputs)

        # pvam
        b, t, c = word_features.shape
        word_features = self.fc0(word_features)
        word_features_ = paddle.reshape(word_features, [-1, 1, t, c])
        word_features_ = paddle.tile(word_features_, [1, self.max_length, 1, 1])
        word_pos_feature = self.emb(gsrm_word_pos)
        word_pos_feature_ = paddle.reshape(word_pos_feature,
                                           [-1, self.max_length, 1, c])
        word_pos_feature_ = paddle.tile(word_pos_feature_, [1, 1, t, 1])
        y = word_pos_feature_ + word_features_
        y = F.tanh(y)
        attention_weight = self.fc1(y)
        attention_weight = paddle.reshape(
            attention_weight, shape=[-1, self.max_length, t])
        attention_weight = F.softmax(attention_weight, axis=-1)
        pvam_features = paddle.matmul(attention_weight,
                                      word_features)  # [b, max_length, c]
        return pvam_features


class GSRM(nn.Layer):
    def __init__(self, in_channels, char_num, max_text_length, num_heads,
                 num_encoder_tus, num_decoder_tus, hidden_dims):
        super(GSRM, self).__init__()
        self.char_num = char_num
        self.max_length = max_text_length
        self.num_heads = num_heads
        self.num_encoder_TUs = num_encoder_tus
        self.num_decoder_TUs = num_decoder_tus
        self.hidden_dims = hidden_dims

        self.fc0 = paddle.nn.Linear(
            in_features=in_channels, out_features=self.char_num)
        self.wrap_encoder0 = WrapEncoder(
            src_vocab_size=self.char_num + 1,
            max_length=self.max_length,
            n_layer=self.num_decoder_TUs,
            n_head=self.num_heads,
            d_key=int(self.hidden_dims / self.num_heads),
            d_value=int(self.hidden_dims / self.num_heads),
            d_model=self.hidden_dims,
            d_inner_hid=self.hidden_dims,
            prepostprocess_dropout=0.1,
            attention_dropout=0.1,
            relu_dropout=0.1,
            preprocess_cmd="n",
            postprocess_cmd="da",
            weight_sharing=True)
        self.wrap_encoder1 = WrapEncoder(
            src_vocab_size=self.char_num + 1,
            max_length=self.max_length,
            n_layer=self.num_decoder_TUs,
            n_head=self.num_heads,
            d_key=int(self.hidden_dims / self.num_heads),
            d_value=int(self.hidden_dims / self.num_heads),
            d_model=self.hidden_dims,
            d_inner_hid=self.hidden_dims,
            prepostprocess_dropout=0.1,
            attention_dropout=0.1,
            relu_dropout=0.1,
            preprocess_cmd="n",
            postprocess_cmd="da",
            weight_sharing=True)

        self.mul = lambda x: paddle.matmul(
            x=x,
            y=self.wrap_encoder0.prepare_decoder.emb0.weight,
            transpose_y=True)

    def forward(self, inputs, gsrm_word_pos, gsrm_slf_attn_bias1,
                gsrm_slf_attn_bias2):
        # ===== GSRM Visual-to-semantic embedding block =====
        b, t, c = inputs.shape
        pvam_features = paddle.reshape(inputs, [-1, c])
        word_out = self.fc0(pvam_features)
        word_ids = paddle.argmax(F.softmax(word_out), axis=1)
        word_ids = paddle.reshape(x=word_ids, shape=[-1, t, 1])

        # ===== GSRM Semantic reasoning block =====
  147. """
  148. This module is achieved through bi-transformers,
  149. ngram_feature1 is the froward one, ngram_fetaure2 is the backward one
  150. """
        pad_idx = self.char_num

        word1 = paddle.cast(word_ids, "float32")
        word1 = F.pad(word1, [1, 0], value=1.0 * pad_idx, data_format="NLC")
        word1 = paddle.cast(word1, "int64")
        word1 = word1[:, :-1, :]
        word2 = word_ids

        enc_inputs_1 = [word1, gsrm_word_pos, gsrm_slf_attn_bias1]
        enc_inputs_2 = [word2, gsrm_word_pos, gsrm_slf_attn_bias2]

        gsrm_feature1 = self.wrap_encoder0(enc_inputs_1)
        gsrm_feature2 = self.wrap_encoder1(enc_inputs_2)

        gsrm_feature2 = F.pad(gsrm_feature2, [0, 1],
                              value=0.,
                              data_format="NLC")
        gsrm_feature2 = gsrm_feature2[:, 1:, ]
        gsrm_features = gsrm_feature1 + gsrm_feature2

        gsrm_out = self.mul(gsrm_features)

        b, t, c = gsrm_out.shape
        gsrm_out = paddle.reshape(gsrm_out, [-1, c])

        return gsrm_features, word_out, gsrm_out


class VSFD(nn.Layer):
    def __init__(self, in_channels=512, pvam_ch=512, char_num=38):
        super(VSFD, self).__init__()
        self.char_num = char_num
        self.fc0 = paddle.nn.Linear(
            in_features=in_channels * 2, out_features=pvam_ch)
        self.fc1 = paddle.nn.Linear(
            in_features=pvam_ch, out_features=self.char_num)
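
    # Fusion rule used in forward (reference note): with
    # g = sigmoid(fc0(concat(pvam_feature, gsrm_feature))), the fused feature
    # is g * pvam_feature + (1 - g) * gsrm_feature, and fc1 maps it to logits.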
    def forward(self, pvam_feature, gsrm_feature):
        b, t, c1 = pvam_feature.shape
        b, t, c2 = gsrm_feature.shape
        combine_feature_ = paddle.concat([pvam_feature, gsrm_feature], axis=2)
        img_comb_feature_ = paddle.reshape(
            combine_feature_, shape=[-1, c1 + c2])
        img_comb_feature_map = self.fc0(img_comb_feature_)
        img_comb_feature_map = F.sigmoid(img_comb_feature_map)
        img_comb_feature_map = paddle.reshape(
            img_comb_feature_map, shape=[-1, t, c1])
        combine_feature = img_comb_feature_map * pvam_feature + (
            1.0 - img_comb_feature_map) * gsrm_feature
        img_comb_feature = paddle.reshape(combine_feature, shape=[-1, c1])

        out = self.fc1(img_comb_feature)
        return out


class SRNHead(nn.Layer):
    def __init__(self, in_channels, out_channels, max_text_length, num_heads,
                 num_encoder_TUs, num_decoder_TUs, hidden_dims, **kwargs):
        super(SRNHead, self).__init__()
        self.char_num = out_channels
        self.max_length = max_text_length
        self.num_heads = num_heads
        self.num_encoder_TUs = num_encoder_TUs
        self.num_decoder_TUs = num_decoder_TUs
        self.hidden_dims = hidden_dims

        self.pvam = PVAM(
            in_channels=in_channels,
            char_num=self.char_num,
            max_text_length=self.max_length,
            num_heads=self.num_heads,
            num_encoder_tus=self.num_encoder_TUs,
            hidden_dims=self.hidden_dims)
        self.gsrm = GSRM(
            in_channels=in_channels,
            char_num=self.char_num,
            max_text_length=self.max_length,
            num_heads=self.num_heads,
            num_encoder_tus=self.num_encoder_TUs,
            num_decoder_tus=self.num_decoder_TUs,
            hidden_dims=self.hidden_dims)
        self.vsfd = VSFD(in_channels=in_channels, char_num=self.char_num)
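
        # Share the token embedding (emb0) between the two GSRM branches so
        # wrap_encoder0 and wrap_encoder1 use the same embedding table.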
        self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0

    def forward(self, inputs, others):
        encoder_word_pos = others[0]
        gsrm_word_pos = others[1]
        gsrm_slf_attn_bias1 = others[2]
        gsrm_slf_attn_bias2 = others[3]

        pvam_feature = self.pvam(inputs, encoder_word_pos, gsrm_word_pos)

        gsrm_feature, word_out, gsrm_out = self.gsrm(
            pvam_feature, gsrm_word_pos, gsrm_slf_attn_bias1,
            gsrm_slf_attn_bias2)

        final_out = self.vsfd(pvam_feature, gsrm_feature)
        if not self.training:
            final_out = F.softmax(final_out, axis=1)

        _, decoded_out = paddle.topk(final_out, k=1)

        predicts = OrderedDict([
            ('predict', final_out),
            ('pvam_feature', pvam_feature),
            ('decoded_out', decoded_out),
            ('word_out', word_out),
            ('gsrm_out', gsrm_out),
        ])
        return predicts
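

if __name__ == "__main__":
    # Minimal shape-check sketch (editor addition, not part of the original
    # head). The values below assume the usual PaddleOCR SRN setting: a
    # 512-channel 8 x 32 backbone feature map (so h * w == 256, matching the
    # encoder position length hard-coded in PVAM), 38 output classes,
    # max_text_length = 25, 8 heads and hidden_dims = 512. The attention-bias
    # tensors are plain zeros here; in training they come from the SRN data
    # pipeline. Running this requires the sibling self_attention module.
    b, c, h, w = 2, 512, 8, 32
    max_len, num_heads = 25, 8

    head = SRNHead(
        in_channels=c,
        out_channels=38,
        max_text_length=max_len,
        num_heads=num_heads,
        num_encoder_TUs=2,
        num_decoder_TUs=4,
        hidden_dims=512)
    head.eval()

    feat = paddle.randn([b, c, h, w])
    encoder_word_pos = paddle.tile(
        paddle.arange(h * w, dtype="int64").reshape([1, h * w, 1]), [b, 1, 1])
    gsrm_word_pos = paddle.tile(
        paddle.arange(max_len, dtype="int64").reshape([1, max_len, 1]), [b, 1, 1])
    attn_bias = paddle.zeros([b, num_heads, max_len, max_len])

    preds = head(feat, [encoder_word_pos, gsrm_word_pos, attn_bias, attn_bias])
    print({name: tensor.shape for name, tensor in preds.items()})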