Commit 810b711

Merge pull request open-mmlab#54 from hellock/hotfix
Bug fix for ConvFCBBoxHead arguments
2 parents c8cc01e + 8e09835

2 files changed, 8 insertions(+), 15 deletions(-)

mmdet/models/bbox_heads/convfc_bbox_head.py
Lines changed: 3 additions & 0 deletions

@@ -22,6 +22,7 @@ def __init__(self,
                  num_reg_fcs=0,
                  conv_out_channels=256,
                  fc_out_channels=1024,
+                 normalize=None,
                  *args,
                  **kwargs):
         super(ConvFCBBoxHead, self).__init__(*args, **kwargs)
@@ -41,6 +42,8 @@ def __init__(self,
         self.num_reg_fcs = num_reg_fcs
         self.conv_out_channels = conv_out_channels
         self.fc_out_channels = fc_out_channels
+        self.normalize = normalize
+        self.with_bias = normalize is None
 
         # add shared convs and fcs
         self.shared_convs, self.shared_fcs, last_layer_dim = \
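
The new normalize argument threads a norm-layer config through to the head's conv layers, and with_bias turns the conv bias off whenever a norm layer is configured, since the norm's own shift parameter makes a preceding bias redundant. A minimal sketch of that rule (make_conv is a hypothetical helper, with plain nn.BatchNorm2d standing in for mmdet's build_norm_layer; the config dict is illustrative):

import torch.nn as nn

def make_conv(in_channels, out_channels, normalize=None):
    # The rule this commit introduces: no conv bias when a norm layer follows.
    with_bias = normalize is None
    layers = [
        nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=with_bias)
    ]
    if normalize is not None:
        layers.append(nn.BatchNorm2d(out_channels))  # stand-in for build_norm_layer
    return nn.Sequential(*layers)

block = make_conv(256, 256, normalize=dict(type='BN'))
print(block[0].bias)  # None: the bias was dropped in favor of the norm layer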

mmdet/models/utils/conv_module.py
Lines changed: 5 additions & 15 deletions

@@ -1,6 +1,7 @@
 import warnings
 
 import torch.nn as nn
+from mmcv.cnn import kaiming_init, constant_init
 
 from .norm import build_norm_layer
 
@@ -51,15 +52,8 @@ def __init__(self,
         self.groups = self.conv.groups
 
         if self.with_norm:
-            # self.norm_type, self.norm_params = parse_norm(normalize)
-            # assert self.norm_type in [None, 'BN', 'SyncBN', 'GN', 'SN']
-            # self.Norm2d = norm_cfg[self.norm_type]
-            if self.activate_last:
-                self.norm = build_norm_layer(normalize, out_channels)
-                # self.norm = self.Norm2d(out_channels, **self.norm_params)
-            else:
-                self.norm = build_norm_layer(normalize, in_channels)
-                # self.norm = self.Norm2d(in_channels, **self.norm_params)
+            norm_channels = out_channels if self.activate_last else in_channels
+            self.norm = build_norm_layer(normalize, norm_channels)
 
         if self.with_activatation:
             assert activation in ['relu'], 'Only ReLU supported.'
@@ -71,13 +65,9 @@ def __init__(self,
 
     def init_weights(self):
         nonlinearity = 'relu' if self.activation is None else self.activation
-        nn.init.kaiming_normal_(
-            self.conv.weight, mode='fan_out', nonlinearity=nonlinearity)
-        if self.with_bias:
-            nn.init.constant_(self.conv.bias, 0)
+        kaiming_init(self.conv, nonlinearity=nonlinearity)
         if self.with_norm:
-            nn.init.constant_(self.norm.weight, 1)
-            nn.init.constant_(self.norm.bias, 0)
+            constant_init(self.norm, 1, bias=0)
 
     def forward(self, x, activate=True, norm=True):
         if self.activate_last:
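
The init refactor swaps the hand-rolled nn.init calls for mmcv.cnn's helpers; note the old `if self.with_bias:` guard disappears because a conv built with bias=False simply exposes bias as None. A rough sketch of what the two helpers do, reconstructed from the lines they replace (mmcv's actual implementations accept more options):

import torch.nn as nn

def kaiming_init(module, mode='fan_out', nonlinearity='relu'):
    # Equivalent of the removed nn.init.kaiming_normal_ call, plus zeroing
    # the bias when the module has one.
    nn.init.kaiming_normal_(module.weight, mode=mode, nonlinearity=nonlinearity)
    if module.bias is not None:
        nn.init.constant_(module.bias, 0)

def constant_init(module, val, bias=0):
    # Equivalent of the removed norm-layer init: fill the weight with one
    # constant and the bias with another (here: weight=1, bias=0).
    nn.init.constant_(module.weight, val)
    if module.bias is not None:
        nn.init.constant_(module.bias, bias)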
