|
| 1 | +import torch.nn as nn |
| 2 | +import torch.nn.functional as F |
| 3 | +from mmcv.cnn import ConvModule |
| 4 | + |
| 5 | +from ..builder import NECKS |
| 6 | + |
| 7 | + |
| 8 | +@NECKS.register_module() |
| 9 | +class MultiLevelNeck(nn.Module): |
| 10 | + """MultiLevelNeck. |
| 11 | +
|
| 12 | + A neck structure connect vit backbone and decoder_heads. |
| 13 | + Args: |
| 14 | + in_channels (List[int]): Number of input channels per scale. |
| 15 | + out_channels (int): Number of output channels (used at each scale). |
| 16 | + scales (List[int]): Scale factors for each input feature map. |
| 17 | + norm_cfg (dict): Config dict for normalization layer. Default: None. |
| 18 | + act_cfg (dict): Config dict for activation layer in ConvModule. |
| 19 | + Default: None. |
| 20 | + """ |
| 21 | + |
| 22 | + def __init__(self, |
| 23 | + in_channels, |
| 24 | + out_channels, |
| 25 | + scales=[0.5, 1, 2, 4], |
| 26 | + norm_cfg=None, |
| 27 | + act_cfg=None): |
| 28 | + super(MultiLevelNeck, self).__init__() |
| 29 | + assert isinstance(in_channels, list) |
| 30 | + self.in_channels = in_channels |
| 31 | + self.out_channels = out_channels |
| 32 | + self.scales = scales |
| 33 | + self.num_outs = len(scales) |
| 34 | + self.lateral_convs = nn.ModuleList() |
| 35 | + self.convs = nn.ModuleList() |
| 36 | + for in_channel in in_channels: |
| 37 | + self.lateral_convs.append( |
| 38 | + ConvModule( |
| 39 | + in_channel, |
| 40 | + out_channels, |
| 41 | + kernel_size=1, |
| 42 | + norm_cfg=norm_cfg, |
| 43 | + act_cfg=act_cfg)) |
| 44 | + for _ in range(self.num_outs): |
| 45 | + self.convs.append( |
| 46 | + ConvModule( |
| 47 | + out_channels, |
| 48 | + out_channels, |
| 49 | + kernel_size=3, |
| 50 | + padding=1, |
| 51 | + stride=1, |
| 52 | + norm_cfg=norm_cfg, |
| 53 | + act_cfg=act_cfg)) |
| 54 | + |
| 55 | + def forward(self, inputs): |
| 56 | + assert len(inputs) == len(self.in_channels) |
| 57 | + print(inputs[0].shape) |
| 58 | + inputs = [ |
| 59 | + lateral_conv(inputs[i]) |
| 60 | + for i, lateral_conv in enumerate(self.lateral_convs) |
| 61 | + ] |
| 62 | + # for len(inputs) not equal to self.num_outs |
| 63 | + if len(inputs) == 1: |
| 64 | + inputs = [inputs[0] for _ in range(self.num_outs)] |
| 65 | + outs = [] |
| 66 | + for i in range(self.num_outs): |
| 67 | + x_resize = F.interpolate( |
| 68 | + inputs[i], scale_factor=self.scales[i], mode='bilinear') |
| 69 | + outs.append(self.convs[i](x_resize)) |
| 70 | + return tuple(outs) |
0 commit comments