@@ -8,7 +8,7 @@
 import pytest
 import torch
 import torch.nn as nn
-from mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm
+from mmcv.cnn.utils import revert_sync_batchnorm
 
 
 def _demo_mm_inputs(input_shape=(2, 3, 8, 16), num_classes=10):
@@ -189,28 +189,6 @@ def _check_input_dim(self, inputs): |
     pass
 
 
-def _convert_batchnorm(module):
-    module_output = module
-    if isinstance(module, SyncBatchNorm):
-        # to be consistent with SyncBN, we hack dim check function in BN
-        module_output = _BatchNorm(module.num_features, module.eps,
-                                   module.momentum, module.affine,
-                                   module.track_running_stats)
-        if module.affine:
-            module_output.weight.data = module.weight.data.clone().detach()
-            module_output.bias.data = module.bias.data.clone().detach()
-            # keep requires_grad unchanged
-            module_output.weight.requires_grad = module.weight.requires_grad
-            module_output.bias.requires_grad = module.bias.requires_grad
-        module_output.running_mean = module.running_mean
-        module_output.running_var = module.running_var
-        module_output.num_batches_tracked = module.num_batches_tracked
-    for name, child in module.named_children():
-        module_output.add_module(name, _convert_batchnorm(child))
-    del module
-    return module_output
-
-
 @patch('torch.nn.modules.batchnorm._BatchNorm._check_input_dim',
        _check_input_dim)
 @patch('torch.distributed.get_world_size', get_world_size)
@@ -241,7 +219,7 @@ def _test_encoder_decoder_forward(cfg_file): |
         imgs = imgs.cuda()
         gt_semantic_seg = gt_semantic_seg.cuda()
     else:
-        segmentor = _convert_batchnorm(segmentor)
+        segmentor = revert_sync_batchnorm(segmentor)
 
     # Test forward train
     losses = segmentor.forward(
|
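For reference, a minimal sketch of what the replacement helper does, assuming an mmcv version that exports `revert_sync_batchnorm` from `mmcv.cnn.utils` (as the new import above implies). The toy model and shape check here are illustrative only, not part of the test suite:

```python
import torch
import torch.nn as nn
from mmcv.cnn.utils import revert_sync_batchnorm

# Toy model wrapped with SyncBatchNorm, as a segmentor built with
# norm_cfg=dict(type='SyncBN') would be. (Illustrative model, not from the test.)
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.SyncBatchNorm(8))

# SyncBatchNorm needs an initialized distributed process group, so a plain
# forward pass on a single device would fail. Reverting the SyncBN layers to
# regular BN (copying weights, running stats, and requires_grad flags, much
# like the deleted _convert_batchnorm did) makes the model runnable locally.
model = revert_sync_batchnorm(model)

out = model(torch.rand(2, 3, 8, 8))
assert out.shape == (2, 8, 8, 8)
assert not any(isinstance(m, nn.SyncBatchNorm) for m in model.modules())
```

Using the upstream helper removes the duplicated conversion logic from the test file, which is the whole point of this diff.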