Skip to content

Commit ddfb38e

Browse files
xvjiaruihellock
authored andcommitted
add pytorch 1.1.0 SyncBN support (open-mmlab#577)
* add pytorch 1.1.0 SyncBN support * change BatchNorm2d to _BatchNorm and call freeze after train * add freeze back to init function * fixed indentation typo in adding freeze * use SyncBN protect member func to set ddp_gpu_num * Update README.md update pytorch version to 1.1
1 parent c52cdd6 commit ddfb38e

File tree

3 files changed

+8
-4
lines changed

3 files changed

+8
-4
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
## Introduction
55

6-
The master branch works with **PyTorch 1.0** or higher. If you would like to use PyTorch 0.4.1,
6+
The master branch works with **PyTorch 1.1** or higher. If you would like to use PyTorch 0.4.1,
77
please checkout to the [pytorch-0.4.1](https://github.com/open-mmlab/mmdetection/tree/pytorch-0.4.1) branch.
88

99
mmdetection is an open source object detection toolbox based on PyTorch. It is

mmdet/models/backbones/resnet.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import torch.nn as nn
44
import torch.utils.checkpoint as cp
5+
from torch.nn.modules.batchnorm import _BatchNorm
56

67
from mmcv.cnn import constant_init, kaiming_init
78
from mmcv.runner import load_checkpoint
@@ -437,7 +438,7 @@ def init_weights(self, pretrained=None):
437438
for m in self.modules():
438439
if isinstance(m, nn.Conv2d):
439440
kaiming_init(m)
440-
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
441+
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
441442
constant_init(m, 1)
442443

443444
if self.dcn is not None:
@@ -470,8 +471,9 @@ def forward(self, x):
470471

471472
def train(self, mode=True):
472473
super(ResNet, self).train(mode)
474+
self._freeze_stages()
473475
if mode and self.norm_eval:
474476
for m in self.modules():
475477
# trick: eval have effect on BatchNorm only
476-
if isinstance(m, nn.BatchNorm2d):
478+
if isinstance(m, _BatchNorm):
477479
m.eval()

mmdet/models/utils/norm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
norm_cfg = {
55
# format: layer_type: (abbreviation, module)
66
'BN': ('bn', nn.BatchNorm2d),
7-
'SyncBN': ('bn', None),
7+
'SyncBN': ('bn', nn.SyncBatchNorm),
88
'GN': ('gn', nn.GroupNorm),
99
# and potentially 'SN'
1010
}
@@ -44,6 +44,8 @@ def build_norm_layer(cfg, num_features, postfix=''):
4444
cfg_.setdefault('eps', 1e-5)
4545
if layer_type != 'GN':
4646
layer = norm_layer(num_features, **cfg_)
47+
if layer_type == 'SyncBN':
48+
layer._specify_ddp_gpu_num(1)
4749
else:
4850
assert 'num_groups' in cfg_
4951
layer = norm_layer(num_channels=num_features, **cfg_)

0 commit comments

Comments
 (0)