123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286 |
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import paddle
- from paddle import ParamAttr
- import paddle.nn as nn
- import paddle.nn.functional as F
# Public API of this module: only the backbone class is exported.
__all__ = ['ResNet']
class ConvBNLayer(nn.Layer):
    """Bias-free Conv2D followed by BatchNorm (with an optional activation).

    In "vd" mode the downsampling is moved out of the convolution: the input
    is first average-pooled with ``kernel_size = stride = stride`` and the
    convolution then runs with stride 1.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        kernel_size (int): convolution kernel size; padding is
            ``(kernel_size - 1) // 2``.
        stride (int|tuple): convolution stride, or the pooling stride when
            ``is_vd_mode`` is True.
        groups (int): number of convolution groups.
        is_vd_mode (bool): average-pool the input before a stride-1 conv.
        act (str|None): activation applied by BatchNorm, e.g. ``'relu'``.
        name (str|None): prefix used to build parameter names (these names
            must match pretrained weight files). When None, no explicit names
            are given and Paddle generates them.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            groups=1,
            is_vd_mode=False,
            act=None,
            name=None, ):
        super(ConvBNLayer, self).__init__()

        self.is_vd_mode = is_vd_mode
        # Created unconditionally (it has no parameters); only used in vd mode.
        self._pool2d_avg = nn.AvgPool2D(
            kernel_size=stride, stride=stride, padding=0, ceil_mode=True)
        self._conv = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            # In vd mode the pooling layer already performed the stride.
            stride=1 if is_vd_mode else stride,
            padding=(kernel_size - 1) // 2,
            groups=groups,
            weight_attr=ParamAttr(name=self._param_name(name, "_weights")),
            bias_attr=False)

        if name is None:
            bn_name = None
        elif name == "conv1":
            bn_name = "bn_" + name
        else:
            # Strip the leading "res": e.g. "res2a_branch2a" -> "bn2a_branch2a".
            bn_name = "bn" + name[3:]
        self._batch_norm = nn.BatchNorm(
            out_channels,
            act=act,
            param_attr=ParamAttr(name=self._param_name(bn_name, '_scale')),
            bias_attr=ParamAttr(name=self._param_name(bn_name, '_offset')),
            moving_mean_name=self._param_name(bn_name, '_mean'),
            moving_variance_name=self._param_name(bn_name, '_variance'))

    @staticmethod
    def _param_name(prefix, suffix):
        """Return ``prefix + suffix``, or None when no prefix was given."""
        return None if prefix is None else prefix + suffix

    def forward(self, inputs):
        """Apply (optional avg-pool) -> conv -> batch norm (+ activation)."""
        if self.is_vd_mode:
            inputs = self._pool2d_avg(inputs)
        y = self._conv(inputs)
        y = self._batch_norm(y)
        return y
class BottleneckBlock(nn.Layer):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convs with 4x expansion.

    Args:
        in_channels (int): channels of the block input.
        out_channels (int): width of the inner convs; the block emits
            ``out_channels * 4`` channels.
        stride (tuple): stride of the 3x3 conv and of the projection
            shortcut. It is indexed as ``stride[0]`` when building the
            projection, so it must be a sequence in that case.
        shortcut (bool): True adds the input unchanged (identity); False
            projects it through a 1x1 ConvBNLayer first.
        if_first (bool): True disables vd-mode pooling in the projection.
        name (str): prefix for parameter names.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 shortcut=True,
                 if_first=False,
                 name=None):
        super(BottleneckBlock, self).__init__()
        expanded_channels = out_channels * 4

        # 1x1 reduce.
        self.conv0 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            act='relu',
            name=name + "_branch2a")
        # 3x3 spatial conv; carries the block's stride.
        self.conv1 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            act='relu',
            name=name + "_branch2b")
        # 1x1 expand, no activation before the residual add.
        self.conv2 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=expanded_channels,
            kernel_size=1,
            act=None,
            name=name + "_branch2c")

        if not shortcut:
            # Pool-then-conv downsampling is used whenever the projection has
            # a non-unit first stride, unless this is the first block.
            use_vd = (not if_first) and stride[0] != 1
            self.short = ConvBNLayer(
                in_channels=in_channels,
                out_channels=expanded_channels,
                kernel_size=1,
                stride=stride,
                is_vd_mode=use_vd,
                name=name + "_branch1")

        self.shortcut = shortcut

    def forward(self, inputs):
        """Return ``relu(shortcut(inputs) + conv2(conv1(conv0(inputs))))``."""
        branch = self.conv0(inputs)
        branch = self.conv1(branch)
        branch = self.conv2(branch)
        residual = inputs if self.shortcut else self.short(inputs)
        return F.relu(paddle.add(x=residual, y=branch))
class BasicBlock(nn.Layer):
    """Residual basic block: two 3x3 ConvBNLayers plus a shortcut.

    Args:
        in_channels (int): channels of the block input.
        out_channels (int): channels of both convs and of the block output.
        stride (tuple): stride of the first 3x3 conv and of the projection
            shortcut. It is indexed as ``stride[0]`` when building the
            projection, so it must be a sequence in that case.
        shortcut (bool): True adds the input unchanged (identity); False
            projects it through a 1x1 ConvBNLayer first.
        if_first (bool): True disables vd-mode pooling in the projection.
        name (str): prefix for parameter names.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 shortcut=True,
                 if_first=False,
                 name=None):
        super(BasicBlock, self).__init__()
        self.stride = stride

        # First 3x3 conv carries the block's stride.
        self.conv0 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            act='relu',
            name=name + "_branch2a")
        # Second 3x3 conv, no activation before the residual add.
        self.conv1 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            act=None,
            name=name + "_branch2b")

        if not shortcut:
            # Pool-then-conv downsampling is used whenever the projection has
            # a non-unit first stride, unless this is the first block.
            use_vd = (not if_first) and stride[0] != 1
            self.short = ConvBNLayer(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=stride,
                is_vd_mode=use_vd,
                name=name + "_branch1")

        self.shortcut = shortcut

    def forward(self, inputs):
        """Return ``relu(shortcut(inputs) + conv1(conv0(inputs)))``."""
        branch = self.conv1(self.conv0(inputs))
        residual = inputs if self.shortcut else self.short(inputs)
        return F.relu(paddle.add(x=residual, y=branch))
class ResNet(nn.Layer):
    """ResNet-vd backbone used for text recognition.

    Structure:
      * stem: three 3x3 ConvBNLayers (stride 1, ReLU), then a 3x3 max-pool
        with stride 2 and padding 1;
      * four residual stages; the first block of every stage except the
        first uses stride ``(2, 1)``, all other blocks use ``(1, 1)``;
      * a final 2x2 max-pool with stride 2.

    Args:
        in_channels (int): channels of the input image.
        layers (int): network depth, one of 18, 34, 50, 101, 152, 200.
            Depths >= 50 use BottleneckBlock (4x channel expansion), the
            others use BasicBlock.
        **kwargs: accepted and ignored.

    Attributes:
        out_channels (int): number of channels produced by ``forward``.
    """

    # Number of residual blocks in each of the four stages, keyed by depth.
    _DEPTHS = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
        200: [3, 12, 48, 3],
    }

    def __init__(self, in_channels=3, layers=50, **kwargs):
        super(ResNet, self).__init__()
        self.layers = layers
        supported_layers = [18, 34, 50, 101, 152, 200]
        assert layers in supported_layers, \
            "supported layers are {} but input layer is {}".format(
                supported_layers, layers)
        # Table lookup: if the assert is stripped (python -O), an unsupported
        # depth fails here with a KeyError instead of a NameError later.
        depth = self._DEPTHS[layers]

        use_bottleneck = layers >= 50
        block_cls = BottleneckBlock if use_bottleneck else BasicBlock
        # Ratio between a block's output channels and its ``out_channels``
        # argument (BottleneckBlock's last conv emits out_channels * 4).
        expansion = 4 if use_bottleneck else 1
        num_channels = [64, 256, 512,
                        1024] if use_bottleneck else [64, 64, 128, 256]
        num_filters = [64, 128, 256, 512]

        self.conv1_1 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=32,
            kernel_size=3,
            stride=1,
            act='relu',
            name="conv1_1")
        self.conv1_2 = ConvBNLayer(
            in_channels=32,
            out_channels=32,
            kernel_size=3,
            stride=1,
            act='relu',
            name="conv1_2")
        self.conv1_3 = ConvBNLayer(
            in_channels=32,
            out_channels=64,
            kernel_size=3,
            stride=1,
            act='relu',
            name="conv1_3")
        self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)

        self.block_list = []
        for block in range(len(depth)):
            # The first block of each stage projects its shortcut.
            shortcut = False
            for i in range(depth[block]):
                stride = (2, 1) if i == 0 and block != 0 else (1, 1)
                residual_block = self.add_sublayer(
                    'bb_%d_%d' % (block, i),
                    block_cls(
                        in_channels=num_channels[block]
                        if i == 0 else num_filters[block] * expansion,
                        out_channels=num_filters[block],
                        stride=stride,
                        shortcut=shortcut,
                        if_first=block == i == 0,
                        name=self._block_name(layers, block, i)))
                shortcut = True
                self.block_list.append(residual_block)
        # Channels actually produced by the last stage. The previous code
        # reported num_filters[-1] (512) for bottleneck depths, although the
        # blocks emit num_filters[-1] * 4 (2048) channels.
        self.out_channels = num_filters[-1] * expansion
        self.out_pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)

    @staticmethod
    def _block_name(layers, block, i):
        """Return the parameter-name prefix for block ``i`` of stage ``block``."""
        stage = "res" + str(block + 2)
        if layers in (101, 152, 200) and block == 2:
            # Long third stages are named res4a, res4b1, res4b2, ... instead
            # of consecutive letters.
            return stage + ("a" if i == 0 else "b" + str(i))
        return stage + chr(97 + i)

    def forward(self, inputs):
        """Run the stem, all residual blocks and the final max-pool."""
        y = self.conv1_1(inputs)
        y = self.conv1_2(y)
        y = self.conv1_3(y)
        y = self.pool2d_max(y)
        for block in self.block_list:
            y = block(y)
        y = self.out_pool(y)
        return y
|