 
 import mmcv
 import torch
+from mmcv.cnn.utils import revert_sync_batchnorm
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                          wrap_fp16_model)
 from mmcv.utils import DictAction
 
+from mmseg import digit_version
 from mmseg.apis import multi_gpu_test, single_gpu_test
 from mmseg.datasets import build_dataloader, build_dataset
 from mmseg.models import build_segmentor
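The `digit_version` import added above backs the MMCV version guard that appears in the second hunk below. As a rough illustration (not part of this commit), it turns a version string into a comparable tuple of integers, so the guard is a plain tuple comparison:

import mmcv

from mmseg import digit_version

# True when the installed MMCV is at least 1.4.4, the minimum required by
# the CPU-testing guard in the diff below.
mmcv_ok = digit_version(mmcv.__version__) >= digit_version('1.4.4')
print(f'MMCV {mmcv.__version__} is new enough for the CPU-testing guard: {mmcv_ok}')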
@@ -147,11 +149,18 @@ def main():
     cfg.model.pretrained = None
     cfg.data.test.test_mode = True
 
-    cfg.gpu_ids = [args.gpu_id]
+    if args.gpu_id is not None:
+        cfg.gpu_ids = [args.gpu_id]
 
     # init distributed env first, since logger depends on the dist info.
     if args.launcher == 'none':
+        cfg.gpu_ids = [args.gpu_id]
         distributed = False
+        if len(cfg.gpu_ids) > 1:
+            warnings.warn(f'The gpu-ids is reset from {cfg.gpu_ids} to '
+                          f'{cfg.gpu_ids[0:1]} to avoid potential error in '
+                          'non-distribute testing time.')
+            cfg.gpu_ids = cfg.gpu_ids[0:1]
     else:
         distributed = True
         init_dist(args.launcher, **cfg.dist_params)
@@ -236,7 +245,15 @@ def main():
         tmpdir = None
 
     if not distributed:
-        model = MMDataParallel(model, device_ids=[0])
+        warnings.warn(
+            'SyncBN is only supported with DDP. To be compatible with DP, '
+            'we convert SyncBN to BN. Please use dist_train.sh which can '
+            'avoid this error.')
+        if not torch.cuda.is_available():
+            assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \
+                'Please use MMCV >= 1.4.4 for CPU training!'
+        model = revert_sync_batchnorm(model)
+        model = MMDataParallel(model, device_ids=cfg.gpu_ids)
         results = single_gpu_test(
             model,
             data_loader,
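For context, here is a small self-contained sketch (not part of the commit; the toy `nn.Sequential` merely stands in for the segmentor built by `build_segmentor`) of what the new non-distributed branch does: SyncBN layers are reverted to regular BN before the model is wrapped in `MMDataParallel`:

import torch
import torch.nn as nn
from mmcv.cnn.utils import revert_sync_batchnorm
from mmcv.parallel import MMDataParallel

# Toy stand-in for the segmentor; it contains a SyncBatchNorm layer,
# which only works under a distributed process group.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.SyncBatchNorm(8),
    nn.ReLU(),
)

# The non-distributed branch converts every SyncBatchNorm layer into an
# ordinary (non-sync) batch-norm module so DP/CPU inference works.
model = revert_sync_batchnorm(model)
print(model)  # the SyncBatchNorm layer is replaced by a plain BN module

if torch.cuda.is_available():
    # Mirrors `MMDataParallel(model, device_ids=cfg.gpu_ids)` from the diff,
    # hard-coding a single GPU here for illustration.
    model = MMDataParallel(model.cuda(), device_ids=[0])

model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 64, 64))
print(out.shape)  # torch.Size([1, 8, 64, 64])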