# rnn.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from paddle import nn
  18. from ppocr.modeling.heads.rec_ctc_head import get_para_bias_attr
  19. class Im2Seq(nn.Layer):
  20. def __init__(self, in_channels, **kwargs):
  21. super().__init__()
  22. self.out_channels = in_channels
  23. def forward(self, x):
  24. B, C, H, W = x.shape
  25. assert H == 1
  26. x = x.squeeze(axis=2)
  27. x = x.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
  28. return x
  29. class EncoderWithRNN(nn.Layer):
  30. def __init__(self, in_channels, hidden_size):
  31. super(EncoderWithRNN, self).__init__()
  32. self.out_channels = hidden_size * 2
  33. self.lstm = nn.LSTM(
  34. in_channels, hidden_size, direction='bidirectional', num_layers=2)
  35. def forward(self, x):
  36. x, _ = self.lstm(x)
  37. return x
  38. class EncoderWithFC(nn.Layer):
  39. def __init__(self, in_channels, hidden_size):
  40. super(EncoderWithFC, self).__init__()
  41. self.out_channels = hidden_size
  42. weight_attr, bias_attr = get_para_bias_attr(
  43. l2_decay=0.00001, k=in_channels, name='reduce_encoder_fea')
  44. self.fc = nn.Linear(
  45. in_channels,
  46. hidden_size,
  47. weight_attr=weight_attr,
  48. bias_attr=bias_attr,
  49. name='reduce_encoder_fea')
  50. def forward(self, x):
  51. x = self.fc(x)
  52. return x
  53. class SequenceEncoder(nn.Layer):
  54. def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs):
  55. super(SequenceEncoder, self).__init__()
  56. self.encoder_reshape = Im2Seq(in_channels)
  57. self.out_channels = self.encoder_reshape.out_channels
  58. if encoder_type == 'reshape':
  59. self.only_reshape = True
  60. else:
  61. support_encoder_dict = {
  62. 'reshape': Im2Seq,
  63. 'fc': EncoderWithFC,
  64. 'rnn': EncoderWithRNN
  65. }
  66. assert encoder_type in support_encoder_dict, '{} must in {}'.format(
  67. encoder_type, support_encoder_dict.keys())
  68. self.encoder = support_encoder_dict[encoder_type](
  69. self.encoder_reshape.out_channels, hidden_size)
  70. self.out_channels = self.encoder.out_channels
  71. self.only_reshape = False
  72. def forward(self, x):
  73. x = self.encoder_reshape(x)
  74. if not self.only_reshape:
  75. x = self.encoder(x)
  76. return x