Skip to content

Commit 527629f

Browse files
committed
add DeformConvPack and refactoring
1 parent ee7e679 commit 527629f

File tree

2 files changed

+47
-33
lines changed

2 files changed

+47
-33
lines changed

mmdet/ops/dcn/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from .functions.deform_conv import deform_conv, modulated_deform_conv
22
from .functions.deform_pool import deform_roi_pooling
33
from .modules.deform_conv import (DeformConv, ModulatedDeformConv,
4-
ModulatedDeformConvPack)
4+
DeformConvPack, ModulatedDeformConvPack)
55
from .modules.deform_pool import (DeformRoIPooling, DeformRoIPoolingPack,
66
ModulatedDeformRoIPoolingPack)
77

88
__all__ = [
9-
'DeformConv', 'DeformRoIPooling', 'DeformRoIPoolingPack',
10-
'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv',
11-
'ModulatedDeformConvPack', 'deform_conv',
12-
'modulated_deform_conv', 'deform_roi_pooling'
9+
'DeformConv', 'DeformConvPack', 'ModulatedDeformConv',
10+
'ModulatedDeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack',
11+
'ModulatedDeformRoIPoolingPack', 'deform_conv', 'modulated_deform_conv',
12+
'deform_roi_pooling'
1313
]

mmdet/ops/dcn/modules/deform_conv.py

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,16 @@ def __init__(self,
1919
groups=1,
2020
deformable_groups=1,
2121
bias=False):
22-
assert not bias
2322
super(DeformConv, self).__init__()
2423

24+
assert not bias
2525
assert in_channels % groups == 0, \
2626
'in_channels {} cannot be divisible by groups {}'.format(
2727
in_channels, groups)
2828
assert out_channels % groups == 0, \
2929
'out_channels {} cannot be divisible by groups {}'.format(
3030
out_channels, groups)
31+
3132
self.in_channels = in_channels
3233
self.out_channels = out_channels
3334
self.kernel_size = _pair(kernel_size)
@@ -50,10 +51,34 @@ def reset_parameters(self):
5051
stdv = 1. / math.sqrt(n)
5152
self.weight.data.uniform_(-stdv, stdv)
5253

53-
def forward(self, input, offset):
54-
return deform_conv(input, offset, self.weight, self.stride,
55-
self.padding, self.dilation, self.groups,
56-
self.deformable_groups)
54+
def forward(self, x, offset):
55+
return deform_conv(x, offset, self.weight, self.stride, self.padding,
56+
self.dilation, self.groups, self.deformable_groups)
57+
58+
59+
class DeformConvPack(DeformConv):
60+
61+
def __init__(self, *args, **kwargs):
62+
super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
63+
64+
self.conv_offset = nn.Conv2d(
65+
self.in_channels,
66+
self.deformable_groups * 2 * self.kernel_size[0] *
67+
self.kernel_size[1],
68+
kernel_size=self.kernel_size,
69+
stride=_pair(self.stride),
70+
padding=_pair(self.padding),
71+
bias=True)
72+
self.init_offset()
73+
74+
def init_offset(self):
75+
self.conv_offset.weight.data.zero_()
76+
self.conv_offset.bias.data.zero_()
77+
78+
def forward(self, x):
79+
offset = self.conv_offset(x)
80+
return deform_conv(x, offset, self.weight, self.stride, self.padding,
81+
self.dilation, self.groups, self.deformable_groups)
5782

5883

5984
class ModulatedDeformConv(nn.Module):
@@ -97,30 +122,19 @@ def reset_parameters(self):
97122
if self.bias is not None:
98123
self.bias.data.zero_()
99124

100-
def forward(self, input, offset, mask):
101-
return modulated_deform_conv(
102-
input, offset, mask, self.weight, self.bias, self.stride,
103-
self.padding, self.dilation, self.groups, self.deformable_groups)
125+
def forward(self, x, offset, mask):
126+
return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
127+
self.stride, self.padding, self.dilation,
128+
self.groups, self.deformable_groups)
104129

105130

106131
class ModulatedDeformConvPack(ModulatedDeformConv):
107132

108-
def __init__(self,
109-
in_channels,
110-
out_channels,
111-
kernel_size,
112-
stride=1,
113-
padding=0,
114-
dilation=1,
115-
groups=1,
116-
deformable_groups=1,
117-
bias=True):
118-
super(ModulatedDeformConvPack, self).__init__(
119-
in_channels, out_channels, kernel_size, stride, padding, dilation,
120-
groups, deformable_groups, bias)
133+
def __init__(self, *args, **kwargs):
134+
super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
121135

122136
self.conv_offset_mask = nn.Conv2d(
123-
self.in_channels // self.groups,
137+
self.in_channels,
124138
self.deformable_groups * 3 * self.kernel_size[0] *
125139
self.kernel_size[1],
126140
kernel_size=self.kernel_size,
@@ -133,11 +147,11 @@ def init_offset(self):
133147
self.conv_offset_mask.weight.data.zero_()
134148
self.conv_offset_mask.bias.data.zero_()
135149

136-
def forward(self, input):
137-
out = self.conv_offset_mask(input)
150+
def forward(self, x):
151+
out = self.conv_offset_mask(x)
138152
o1, o2, mask = torch.chunk(out, 3, dim=1)
139153
offset = torch.cat((o1, o2), dim=1)
140154
mask = torch.sigmoid(mask)
141-
return modulated_deform_conv(
142-
input, offset, mask, self.weight, self.bias, self.stride,
143-
self.padding, self.dilation, self.groups, self.deformable_groups)
155+
return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
156+
self.stride, self.padding, self.dilation,
157+
self.groups, self.deformable_groups)

0 commit comments

Comments
 (0)