Skip to content

Commit cae715a

Browse files
authored
[Fix] Convert SyncBN to BN when training on DP (open-mmlab#772)
* [Fix] Convert SyncBN to BN when training on DP. * Modify SyncBN2BN. * Add SyncBN2BN unit test. * Resolve some comments. * use mmcv official revert_sync_batchnorm * Remove local syncbn2bn unit tests. * Update mmcv version. * Fix bugs of gather model tools. * Modify warnings. * Modify docker mmcv version. * Update mmcv version table.
1 parent 5a7996d commit cae715a

File tree

9 files changed

+26
-30
lines changed

9 files changed

+26
-30
lines changed

.dev/gather_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def get_final_results(log_json_path, iter_num):
7575
def parse_args():
7676
parser = argparse.ArgumentParser(description='Gather benchmarked models')
7777
parser.add_argument(
78-
'-c', '--config-name', type=str, help='Process the selected config.')
78+
'-f', '--config-name', type=str, help='Process the selected config.')
7979
parser.add_argument(
8080
'-w',
8181
'--work-dir',

docker/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
ARG PYTORCH="1.6.0"
22
ARG CUDA="10.1"
33
ARG CUDNN="7"
4-
ARG MMCV="1.3.12"
4+
ARG MMCV="1.3.13"
55

66
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
77

docker/serve/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ ARG CUDA="10.1"
33
ARG CUDNN="7"
44
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
55

6-
ARG MMCV="1.3.12"
6+
ARG MMCV="1.3.13"
77
ARG MMSEG="0.17.0"
88

99
ENV PYTHONUNBUFFERED TRUE

docs/get_started.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ The compatible MMSegmentation and MMCV versions are as below. Please install the
1111

1212
| MMSegmentation version | MMCV version |
1313
|:-------------------:|:-------------------:|
14-
| master | mmcv-full>=1.3.7, <1.4.0 |
14+
| master | mmcv-full>=1.3.13, <1.4.0 |
1515
| 0.17.0 | mmcv-full>=1.3.7, <1.4.0 |
1616
| 0.16.0 | mmcv-full>=1.3.7, <1.4.0 |
1717
| 0.15.0 | mmcv-full>=1.3.7, <1.4.0 |

docs/train.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@ To trade speed with GPU memory, you may pass in `--options model.backbone.with_c
1919

2020
### Train with a single GPU
2121

22+
official support:
23+
24+
```shell
25+
./tools/dist_train.sh ${CONFIG_FILE} 1 [optional arguments]
26+
```
27+
28+
experimental support (Convert SyncBN to BN):
29+
2230
```shell
2331
python tools/train.py ${CONFIG_FILE} [optional arguments]
2432
```

docs_zh-CN/get_started.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
| MMSegmentation 版本 | MMCV 版本 |
1313
|:-------------------:|:-------------------:|
14-
| master | mmcv-full>=1.3.7, <1.4.0 |
14+
| master | mmcv-full>=1.3.13, <1.4.0 |
1515
| 0.17.0 | mmcv-full>=1.3.7, <1.4.0 |
1616
| 0.16.0 | mmcv-full>=1.3.7, <1.4.0 |
1717
| 0.15.0 | mmcv-full>=1.3.7, <1.4.0 |

mmseg/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from .version import __version__, version_info
88

9-
MMCV_MIN = '1.3.7'
9+
MMCV_MIN = '1.3.13'
1010
MMCV_MAX = '1.4.0'
1111

1212

tests/test_models/test_forward.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pytest
99
import torch
1010
import torch.nn as nn
11-
from mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm
11+
from mmcv.cnn.utils import revert_sync_batchnorm
1212

1313

1414
def _demo_mm_inputs(input_shape=(2, 3, 8, 16), num_classes=10):
@@ -189,28 +189,6 @@ def _check_input_dim(self, inputs):
189189
pass
190190

191191

192-
def _convert_batchnorm(module):
193-
module_output = module
194-
if isinstance(module, SyncBatchNorm):
195-
# to be consistent with SyncBN, we hack dim check function in BN
196-
module_output = _BatchNorm(module.num_features, module.eps,
197-
module.momentum, module.affine,
198-
module.track_running_stats)
199-
if module.affine:
200-
module_output.weight.data = module.weight.data.clone().detach()
201-
module_output.bias.data = module.bias.data.clone().detach()
202-
# keep requires_grad unchanged
203-
module_output.weight.requires_grad = module.weight.requires_grad
204-
module_output.bias.requires_grad = module.bias.requires_grad
205-
module_output.running_mean = module.running_mean
206-
module_output.running_var = module.running_var
207-
module_output.num_batches_tracked = module.num_batches_tracked
208-
for name, child in module.named_children():
209-
module_output.add_module(name, _convert_batchnorm(child))
210-
del module
211-
return module_output
212-
213-
214192
@patch('torch.nn.modules.batchnorm._BatchNorm._check_input_dim',
215193
_check_input_dim)
216194
@patch('torch.distributed.get_world_size', get_world_size)
@@ -241,7 +219,7 @@ def _test_encoder_decoder_forward(cfg_file):
241219
imgs = imgs.cuda()
242220
gt_semantic_seg = gt_semantic_seg.cuda()
243221
else:
244-
segmentor = _convert_batchnorm(segmentor)
222+
segmentor = revert_sync_batchnorm(segmentor)
245223

246224
# Test forward train
247225
losses = segmentor.forward(

tools/train.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
import os
55
import os.path as osp
66
import time
7+
import warnings
78

89
import mmcv
910
import torch
11+
from mmcv.cnn.utils import revert_sync_batchnorm
1012
from mmcv.runner import get_dist_info, init_dist
1113
from mmcv.utils import Config, DictAction, get_git_hash
1214

@@ -137,6 +139,14 @@ def main():
137139
test_cfg=cfg.get('test_cfg'))
138140
model.init_weights()
139141

142+
# SyncBN is not support for DP
143+
if not distributed:
144+
warnings.warn(
145+
'SyncBN is only supported with DDP. To be compatible with DP, '
146+
'we convert SyncBN to BN. Please use dist_train.sh which can '
147+
'avoid this error.')
148+
model = revert_sync_batchnorm(model)
149+
140150
logger.info(model)
141151

142152
datasets = [build_dataset(cfg.data.train)]

0 commit comments

Comments
 (0)