Skip to content

Commit 5fcec9a

Browse files
committed
allow specifying num_offset_fcs and num_mask_fcs
1 parent 527629f commit 5fcec9a

File tree

1 file changed

+47
-25
lines changed

1 file changed

+47
-25
lines changed

mmdet/ops/dcn/modules/deform_pool.py

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -44,22 +44,28 @@ def __init__(self,
4444
part_size=None,
4545
sample_per_part=4,
4646
trans_std=.0,
47+
num_offset_fcs=3,
4748
deform_fc_channels=1024):
4849
super(DeformRoIPoolingPack,
4950
self).__init__(spatial_scale, out_size, out_channels, no_trans,
5051
group_size, part_size, sample_per_part, trans_std)
5152

53+
self.num_offset_fcs = num_offset_fcs
5254
self.deform_fc_channels = deform_fc_channels
5355

5456
if not no_trans:
55-
self.offset_fc = nn.Sequential(
56-
nn.Linear(self.out_size * self.out_size * self.out_channels,
57-
self.deform_fc_channels),
58-
nn.ReLU(inplace=True),
59-
nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
60-
nn.ReLU(inplace=True),
61-
nn.Linear(self.deform_fc_channels,
62-
self.out_size * self.out_size * 2))
57+
seq = []
58+
ic = self.out_size * self.out_size * self.out_channels
59+
for i in range(self.num_offset_fcs):
60+
if i < self.num_offset_fcs - 1:
61+
oc = self.deform_fc_channels
62+
else:
63+
oc = self.out_size * self.out_size * 2
64+
seq.append(nn.Linear(ic, oc))
65+
ic = oc
66+
if i < self.num_offset_fcs - 1:
67+
seq.append(nn.ReLU(inplace=True))
68+
self.offset_fc = nn.Sequential(*seq)
6369
self.offset_fc[-1].weight.data.zero_()
6470
self.offset_fc[-1].bias.data.zero_()
6571

@@ -97,33 +103,49 @@ def __init__(self,
97103
part_size=None,
98104
sample_per_part=4,
99105
trans_std=.0,
106+
num_offset_fcs=3,
107+
num_mask_fcs=2,
100108
deform_fc_channels=1024):
101109
super(ModulatedDeformRoIPoolingPack, self).__init__(
102110
spatial_scale, out_size, out_channels, no_trans, group_size,
103111
part_size, sample_per_part, trans_std)
104112

113+
self.num_offset_fcs = num_offset_fcs
114+
self.num_mask_fcs = num_mask_fcs
105115
self.deform_fc_channels = deform_fc_channels
106116

107117
if not no_trans:
108-
self.offset_fc = nn.Sequential(
109-
nn.Linear(self.out_size * self.out_size * self.out_channels,
110-
self.deform_fc_channels),
111-
nn.ReLU(inplace=True),
112-
nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
113-
nn.ReLU(inplace=True),
114-
nn.Linear(self.deform_fc_channels,
115-
self.out_size * self.out_size * 2))
118+
offset_fc_seq = []
119+
ic = self.out_size * self.out_size * self.out_channels
120+
for i in range(self.num_offset_fcs):
121+
if i < self.num_offset_fcs - 1:
122+
oc = self.deform_fc_channels
123+
else:
124+
oc = self.out_size * self.out_size * 2
125+
offset_fc_seq.append(nn.Linear(ic, oc))
126+
ic = oc
127+
if i < self.num_offset_fcs - 1:
128+
offset_fc_seq.append(nn.ReLU(inplace=True))
129+
self.offset_fc = nn.Sequential(*offset_fc_seq)
116130
self.offset_fc[-1].weight.data.zero_()
117131
self.offset_fc[-1].bias.data.zero_()
118-
self.mask_fc = nn.Sequential(
119-
nn.Linear(self.out_size * self.out_size * self.out_channels,
120-
self.deform_fc_channels),
121-
nn.ReLU(inplace=True),
122-
nn.Linear(self.deform_fc_channels,
123-
self.out_size * self.out_size * 1),
124-
nn.Sigmoid())
125-
self.mask_fc[2].weight.data.zero_()
126-
self.mask_fc[2].bias.data.zero_()
132+
133+
mask_fc_seq = []
134+
ic = self.out_size * self.out_size * self.out_channels
135+
for i in range(self.num_mask_fcs):
136+
if i < self.num_mask_fcs - 1:
137+
oc = self.deform_fc_channels
138+
else:
139+
oc = self.out_size * self.out_size
140+
mask_fc_seq.append(nn.Linear(ic, oc))
141+
ic = oc
142+
if i < self.num_mask_fcs - 1:
143+
mask_fc_seq.append(nn.ReLU(inplace=True))
144+
else:
145+
mask_fc_seq.append(nn.Sigmoid())
146+
self.mask_fc = nn.Sequential(*mask_fc_seq)
147+
self.mask_fc[-2].weight.data.zero_()
148+
self.mask_fc[-2].bias.data.zero_()
127149

128150
def forward(self, data, rois):
129151
assert data.size(1) == self.out_channels

0 commit comments

Comments
 (0)