# db_fpn.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
  21. class DBFPN(nn.Layer):
  22. def __init__(self, in_channels, out_channels, **kwargs):
  23. super(DBFPN, self).__init__()
  24. self.out_channels = out_channels
  25. weight_attr = paddle.nn.initializer.KaimingUniform()
  26. self.in2_conv = nn.Conv2D(
  27. in_channels=in_channels[0],
  28. out_channels=self.out_channels,
  29. kernel_size=1,
  30. weight_attr=ParamAttr(
  31. name='conv2d_51.w_0', initializer=weight_attr),
  32. bias_attr=False)
  33. self.in3_conv = nn.Conv2D(
  34. in_channels=in_channels[1],
  35. out_channels=self.out_channels,
  36. kernel_size=1,
  37. weight_attr=ParamAttr(
  38. name='conv2d_50.w_0', initializer=weight_attr),
  39. bias_attr=False)
  40. self.in4_conv = nn.Conv2D(
  41. in_channels=in_channels[2],
  42. out_channels=self.out_channels,
  43. kernel_size=1,
  44. weight_attr=ParamAttr(
  45. name='conv2d_49.w_0', initializer=weight_attr),
  46. bias_attr=False)
  47. self.in5_conv = nn.Conv2D(
  48. in_channels=in_channels[3],
  49. out_channels=self.out_channels,
  50. kernel_size=1,
  51. weight_attr=ParamAttr(
  52. name='conv2d_48.w_0', initializer=weight_attr),
  53. bias_attr=False)
  54. self.p5_conv = nn.Conv2D(
  55. in_channels=self.out_channels,
  56. out_channels=self.out_channels // 4,
  57. kernel_size=3,
  58. padding=1,
  59. weight_attr=ParamAttr(
  60. name='conv2d_52.w_0', initializer=weight_attr),
  61. bias_attr=False)
  62. self.p4_conv = nn.Conv2D(
  63. in_channels=self.out_channels,
  64. out_channels=self.out_channels // 4,
  65. kernel_size=3,
  66. padding=1,
  67. weight_attr=ParamAttr(
  68. name='conv2d_53.w_0', initializer=weight_attr),
  69. bias_attr=False)
  70. self.p3_conv = nn.Conv2D(
  71. in_channels=self.out_channels,
  72. out_channels=self.out_channels // 4,
  73. kernel_size=3,
  74. padding=1,
  75. weight_attr=ParamAttr(
  76. name='conv2d_54.w_0', initializer=weight_attr),
  77. bias_attr=False)
  78. self.p2_conv = nn.Conv2D(
  79. in_channels=self.out_channels,
  80. out_channels=self.out_channels // 4,
  81. kernel_size=3,
  82. padding=1,
  83. weight_attr=ParamAttr(
  84. name='conv2d_55.w_0', initializer=weight_attr),
  85. bias_attr=False)
  86. def forward(self, x):
  87. c2, c3, c4, c5 = x
  88. in5 = self.in5_conv(c5)
  89. in4 = self.in4_conv(c4)
  90. in3 = self.in3_conv(c3)
  91. in2 = self.in2_conv(c2)
  92. out4 = in4 + F.upsample(
  93. in5, scale_factor=2, mode="nearest", align_mode=1) # 1/16
  94. out3 = in3 + F.upsample(
  95. out4, scale_factor=2, mode="nearest", align_mode=1) # 1/8
  96. out2 = in2 + F.upsample(
  97. out3, scale_factor=2, mode="nearest", align_mode=1) # 1/4
  98. p5 = self.p5_conv(in5)
  99. p4 = self.p4_conv(out4)
  100. p3 = self.p3_conv(out3)
  101. p2 = self.p2_conv(out2)
  102. p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1)
  103. p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1)
  104. p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1)
  105. fuse = paddle.concat([p5, p4, p3, p2], axis=1)
  106. return fuse