
Commit 2bd4c84

Rewrite multigpu testing with mmddp (open-mmlab#622)
* rewrite multigpu testing with mmddp
* fix an indent
* update slurm testing script
* update readme and scripts
1 parent 710b8e2 commit 2bd4c84

7 files changed (+191 lines, −81 lines)

GETTING_STARTED.md

Lines changed: 9 additions & 11 deletions
@@ -14,19 +14,17 @@ and also some high-level apis for easier integration to other projects.
 - [x] multiple GPU testing
 - [x] visualize detection results
 
-You can use the following command to test a dataset.
+You can use the following commands to test a dataset.
 
 ```shell
-python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--gpus ${GPU_NUM}] [--proc_per_gpu ${PROC_NUM}] [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}] [--show]
-```
+# single-gpu testing
+python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}] [--show]
 
-Positional arguments:
-- `CONFIG_FILE`: Path to the config file of the corresponding model.
-- `CHECKPOINT_FILE`: Path to the checkpoint file.
+# multi-gpu testing
+./tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}]
+```
 
 Optional arguments:
-- `GPU_NUM`: Number of GPUs used for testing. (default: 1)
-- `PROC_NUM`: Number of processes on each GPU. (default: 1)
 - `RESULT_FILE`: Filename of the output results in pickle format. If not specified, the results will not be saved to a file.
 - `EVAL_METRICS`: Items to be evaluated on the results. Allowed values are: `proposal_fast`, `proposal`, `bbox`, `segm`, `keypoints`.
 - `--show`: If specified, detection results will be plotted on the images and shown in a new window. Only applicable for single GPU testing.
@@ -51,12 +49,12 @@ python tools/test.py configs/mask_rcnn_r50_fpn_1x.py \
     --out results.pkl --eval bbox mask
 ```
 
-3. Test Mask R-CNN with 8 GPUs and 2 processes per GPU, and evaluate the bbox and mask AP.
+3. Test Mask R-CNN with 8 GPUs, and evaluate the bbox and mask AP.
 
 ```shell
-python tools/test.py configs/mask_rcnn_r50_fpn_1x.py \
+./tools/dist_test.sh configs/mask_rcnn_r50_fpn_1x.py \
     checkpoints/mask_rcnn_r50_fpn_1x_20181010-069fa190.pth \
-    --gpus 8 --proc_per_gpu 2 --out results.pkl --eval bbox mask
+    8 --out results.pkl --eval bbox mask
 ```
 
 ### High-level APIs for testing images.
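For readers who script against the Python API instead of the shell entry points, here is a minimal single-GPU sketch mirroring what the rewritten tools/test.py does (shown in full further down); the config and checkpoint paths are placeholders:

```python
import mmcv
import torch
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint
from mmdet.datasets import build_dataloader, get_dataset
from mmdet.models import build_detector

cfg = mmcv.Config.fromfile('configs/mask_rcnn_r50_fpn_1x.py')  # placeholder
cfg.model.pretrained = None
cfg.data.test.test_mode = True

dataset = get_dataset(cfg.data.test)
data_loader = build_dataloader(dataset,
                               imgs_per_gpu=1,
                               workers_per_gpu=cfg.data.workers_per_gpu,
                               dist=False,
                               shuffle=False)

model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
load_checkpoint(model, 'checkpoints/example.pth', map_location='cpu')  # placeholder
model = MMDataParallel(model, device_ids=[0])

# same loop as single_gpu_test in tools/test.py below
model.eval()
results = []
for data in data_loader:
    with torch.no_grad():
        results.append(model(return_loss=False, rescale=True, **data))
```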

mmdet/datasets/loader/build_loader.py

Lines changed: 19 additions & 15 deletions
@@ -4,7 +4,7 @@
 from mmcv.parallel import collate
 from torch.utils.data import DataLoader
 
-from .sampler import GroupSampler, DistributedGroupSampler
+from .sampler import GroupSampler, DistributedGroupSampler, DistributedSampler
 
 # https://github.com/pytorch/pytorch/issues/973
 import resource
@@ -18,27 +18,31 @@ def build_dataloader(dataset,
                      num_gpus=1,
                      dist=True,
                      **kwargs):
+    shuffle = kwargs.get('shuffle', True)
     if dist:
         rank, world_size = get_dist_info()
-        sampler = DistributedGroupSampler(dataset, imgs_per_gpu, world_size,
-                                          rank)
+        if shuffle:
+            sampler = DistributedGroupSampler(dataset, imgs_per_gpu,
+                                              world_size, rank)
+        else:
+            sampler = DistributedSampler(dataset,
+                                         world_size,
+                                         rank,
+                                         shuffle=False)
         batch_size = imgs_per_gpu
         num_workers = workers_per_gpu
     else:
-        if not kwargs.get('shuffle', True):
-            sampler = None
-        else:
-            sampler = GroupSampler(dataset, imgs_per_gpu)
+        sampler = GroupSampler(dataset, imgs_per_gpu) if shuffle else None
         batch_size = num_gpus * imgs_per_gpu
         num_workers = num_gpus * workers_per_gpu
 
-    data_loader = DataLoader(
-        dataset,
-        batch_size=batch_size,
-        sampler=sampler,
-        num_workers=num_workers,
-        collate_fn=partial(collate, samples_per_gpu=imgs_per_gpu),
-        pin_memory=False,
-        **kwargs)
+    data_loader = DataLoader(dataset,
+                             batch_size=batch_size,
+                             sampler=sampler,
+                             num_workers=num_workers,
+                             collate_fn=partial(collate,
+                                                samples_per_gpu=imgs_per_gpu),
+                             pin_memory=False,
+                             **kwargs)
 
     return data_loader
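The practical effect of the new `shuffle` handling: distributed evaluation can now ask for a deterministic, non-shuffled loader. A short sketch of the call the test script makes (assuming a process group has already been initialized, e.g. by `init_dist`; the config path is a placeholder):

```python
import mmcv
from mmdet.datasets import build_dataloader, get_dataset

cfg = mmcv.Config.fromfile('configs/mask_rcnn_r50_fpn_1x.py')  # placeholder
cfg.data.test.test_mode = True

dataset = get_dataset(cfg.data.test)
# dist=True with shuffle=False now selects DistributedSampler(shuffle=False),
# so each rank walks its slice of the test set in a fixed order; with
# shuffle=True it falls back to DistributedGroupSampler as before.
data_loader = build_dataloader(dataset,
                               imgs_per_gpu=1,
                               workers_per_gpu=cfg.data.workers_per_gpu,
                               dist=True,
                               shuffle=False)
```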

mmdet/datasets/loader/sampler.py

Lines changed: 30 additions & 3 deletions
@@ -5,7 +5,34 @@
 import numpy as np
 
 from torch.distributed import get_world_size, get_rank
-from torch.utils.data.sampler import Sampler
+from torch.utils.data import Sampler
+from torch.utils.data import DistributedSampler as _DistributedSampler
+
+
+class DistributedSampler(_DistributedSampler):
+
+    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
+        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
+        self.shuffle = shuffle
+
+    def __iter__(self):
+        # deterministically shuffle based on epoch
+        if self.shuffle:
+            g = torch.Generator()
+            g.manual_seed(self.epoch)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = torch.arange(len(self.dataset)).tolist()
+
+        # add extra samples to make it evenly divisible
+        indices += indices[:(self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        # subsample
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
 
 
 class GroupSampler(Sampler):
@@ -112,8 +139,8 @@ def __iter__(self):
 
         indices = [
             indices[j] for i in list(
-                torch.randperm(
-                    len(indices) // self.samples_per_gpu, generator=g))
+                torch.randperm(len(indices) // self.samples_per_gpu,
+                               generator=g))
             for j in range(i * self.samples_per_gpu, (i + 1) *
                            self.samples_per_gpu)
         ]
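To see why the pad-then-stride arithmetic in `DistributedSampler.__iter__` is safe, here is a standalone sketch of the same index logic (plain Python, no process group needed; the sizes are made up):

```python
import math

def split_indices(dataset_len, world_size):
    # Pad the index list to a multiple of world_size by wrapping around,
    # then give rank r every world_size-th index starting at r.
    num_samples = math.ceil(dataset_len / world_size)
    total_size = num_samples * world_size
    indices = list(range(dataset_len))
    indices += indices[:total_size - len(indices)]  # pad with wrap-around
    return [indices[r:total_size:world_size] for r in range(world_size)]

# 10 samples on 4 ranks: every rank gets exactly 3, two are duplicates.
for rank, part in enumerate(split_indices(10, 4)):
    print(rank, part)
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```

Every rank yields the same number of batches (so collective ops stay in lockstep), and the strided layout is exactly what lets `collect_results` in tools/test.py restore dataset order later.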

tools/dist_test.sh

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+#!/usr/bin/env bash
+
+PYTHON=${PYTHON:-"python"}
+
+CONFIG=$1
+CHECKPOINT=$2
+GPUS=$3
+
+$PYTHON -m torch.distributed.launch --nproc_per_node=$GPUS \
+    $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}

tools/dist_train.sh

Lines changed: 5 additions & 1 deletion
@@ -2,4 +2,8 @@
 
 PYTHON=${PYTHON:-"python"}
 
-$PYTHON -m torch.distributed.launch --nproc_per_node=$2 $(dirname "$0")/train.py $1 --launcher pytorch ${@:3}
+CONFIG=$1
+GPUS=$2
+
+$PYTHON -m torch.distributed.launch --nproc_per_node=$GPUS \
+    $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}

tools/slurm_test.sh

Lines changed: 6 additions & 5 deletions
@@ -7,16 +7,17 @@ JOB_NAME=$2
 CONFIG=$3
 CHECKPOINT=$4
 GPUS=${GPUS:-8}
-CPUS_PER_TASK=${CPUS_PER_TASK:-32}
+GPUS_PER_NODE=${GPUS_PER_NODE:-8}
+CPUS_PER_TASK=${CPUS_PER_TASK:-5}
 PY_ARGS=${@:5}
 SRUN_ARGS=${SRUN_ARGS:-""}
 
 srun -p ${PARTITION} \
     --job-name=${JOB_NAME} \
-    --gres=gpu:${GPUS} \
-    --ntasks=1 \
-    --ntasks-per-node=1 \
+    --gres=gpu:${GPUS_PER_NODE} \
+    --ntasks=${GPUS} \
+    --ntasks-per-node=${GPUS_PER_NODE} \
     --cpus-per-task=${CPUS_PER_TASK} \
     --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
-    python tools/test.py ${CONFIG} ${CHECKPOINT} --gpus ${GPUS} ${PY_ARGS}
+    python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS}
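With `--launcher="slurm"` there is no `torch.distributed.launch`, so each of the `srun` tasks must derive its rank and world size from Slurm's environment. A simplified sketch of that mapping (an assumption about what `init_dist` does for slurm; the real code also resolves `MASTER_ADDR` from the node list):

```python
import os

import torch
import torch.distributed as dist

def init_slurm_dist(backend='nccl', port='29500'):
    # srun exports one SLURM_PROCID per task; with --ntasks=${GPUS} and
    # --ntasks-per-node=${GPUS_PER_NODE} that is one process per GPU.
    rank = int(os.environ['SLURM_PROCID'])
    world_size = int(os.environ['SLURM_NTASKS'])
    torch.cuda.set_device(int(os.environ['SLURM_LOCALID']))
    os.environ['MASTER_PORT'] = port
    # os.environ['MASTER_ADDR'] must point at the first node (assumed set).
    dist.init_process_group(backend=backend,
                            rank=rank,
                            world_size=world_size)
```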

tools/test.py

Lines changed: 112 additions & 46 deletions
@@ -1,17 +1,21 @@
 import argparse
+import os.path as osp
+import shutil
+import tempfile
 
-import torch
 import mmcv
-from mmcv.runner import load_checkpoint, parallel_test, obj_from_dict
-from mmcv.parallel import scatter, collate, MMDataParallel
+import torch
+import torch.distributed as dist
+from mmcv.runner import load_checkpoint, get_dist_info
+from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 
-from mmdet import datasets
+from mmdet.apis import init_dist
 from mmdet.core import results2json, coco_eval
-from mmdet.datasets import build_dataloader
-from mmdet.models import build_detector, detectors
+from mmdet.datasets import build_dataloader, get_dataset
+from mmdet.models import build_detector
 
 
-def single_test(model, data_loader, show=False):
+def single_gpu_test(model, data_loader, show=False):
     model.eval()
     results = []
     dataset = data_loader.dataset
@@ -22,7 +26,9 @@ def single_test(model, data_loader, show=False):
         results.append(result)
 
         if show:
-            model.module.show_result(data, result, dataset.img_norm_cfg,
+            model.module.show_result(data,
+                                     result,
+                                     dataset.img_norm_cfg,
                                      dataset=dataset.CLASSES)
 
         batch_size = data['img'][0].size(0)
@@ -31,22 +37,76 @@ def single_test(model, data_loader, show=False):
     return results
 
 
-def _data_func(data, device_id):
-    data = scatter(collate([data], samples_per_gpu=1), [device_id])[0]
-    return dict(return_loss=False, rescale=True, **data)
+def multi_gpu_test(model, data_loader, tmpdir=None):
+    model.eval()
+    results = []
+    dataset = data_loader.dataset
+    rank, world_size = get_dist_info()
+    if rank == 0:
+        prog_bar = mmcv.ProgressBar(len(dataset))
+    for i, data in enumerate(data_loader):
+        with torch.no_grad():
+            result = model(return_loss=False, rescale=True, **data)
+        results.append(result)
+
+        if rank == 0:
+            batch_size = data['img'][0].size(0)
+            for _ in range(batch_size * world_size):
+                prog_bar.update()
+
+    # collect results from all ranks
+    results = collect_results(results, len(dataset), tmpdir)
+
+    return results
+
+
+def collect_results(result_part, size, tmpdir=None):
+    rank, world_size = get_dist_info()
+    # create a tmp dir if it is not specified
+    if tmpdir is None:
+        MAX_LEN = 512
+        # 32 is whitespace
+        dir_tensor = torch.full((MAX_LEN, ),
+                                32,
+                                dtype=torch.uint8,
+                                device='cuda')
+        if rank == 0:
+            tmpdir = tempfile.mkdtemp()
+            tmpdir = torch.tensor(bytearray(tmpdir.encode()),
+                                  dtype=torch.uint8,
+                                  device='cuda')
+            dir_tensor[:len(tmpdir)] = tmpdir
+        dist.broadcast(dir_tensor, 0)
+        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+    else:
+        mmcv.mkdir_or_exist(tmpdir)
+    # dump the part result to the dir
+    mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank)))
+    dist.barrier()
+    # collect all parts
+    if rank != 0:
+        return None
+    else:
+        # load results of all parts from tmp dir
+        part_list = []
+        for i in range(world_size):
+            part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i))
+            part_list.append(mmcv.load(part_file))
+        # sort the results
+        ordered_results = []
+        for res in zip(*part_list):
+            ordered_results.extend(list(res))
+        # the dataloader may pad some samples
+        ordered_results = ordered_results[:size]
+        # remove tmp dir
+        shutil.rmtree(tmpdir)
+        return ordered_results
 
 
 def parse_args():
     parser = argparse.ArgumentParser(description='MMDet test detector')
     parser.add_argument('config', help='test config file path')
     parser.add_argument('checkpoint', help='checkpoint file')
-    parser.add_argument(
-        '--gpus', default=1, type=int, help='GPU number used for testing')
-    parser.add_argument(
-        '--proc_per_gpu',
-        default=1,
-        type=int,
-        help='Number of processes per GPU')
     parser.add_argument('--out', help='output result file')
     parser.add_argument(
         '--eval',
@@ -55,6 +115,12 @@ def parse_args():
         choices=['proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'],
         help='eval types')
     parser.add_argument('--show', action='store_true', help='show results')
+    parser.add_argument('--tmpdir', help='tmp dir for writing some results')
+    parser.add_argument('--launcher',
+                        choices=['none', 'pytorch', 'slurm', 'mpi'],
+                        default='none',
+                        help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
     return args
@@ -72,36 +138,36 @@ def main():
     cfg.model.pretrained = None
     cfg.data.test.test_mode = True
 
-    dataset = obj_from_dict(cfg.data.test, datasets, dict(test_mode=True))
-    if args.gpus == 1:
-        model = build_detector(
-            cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
-        load_checkpoint(model, args.checkpoint)
+    # init distributed env first, since logger depends on the dist info.
+    if args.launcher == 'none':
+        distributed = False
+    else:
+        distributed = True
+        init_dist(args.launcher, **cfg.dist_params)
+
+    # build the dataloader
+    # TODO: support multiple images per gpu (only minor changes are needed)
+    dataset = get_dataset(cfg.data.test)
+    data_loader = build_dataloader(dataset,
+                                   imgs_per_gpu=1,
+                                   workers_per_gpu=cfg.data.workers_per_gpu,
+                                   dist=distributed,
+                                   shuffle=False)
+
+    # build the model and load checkpoint
+    model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
+    load_checkpoint(model, args.checkpoint, map_location='cpu')
+
+    if not distributed:
         model = MMDataParallel(model, device_ids=[0])
-
-        data_loader = build_dataloader(
-            dataset,
-            imgs_per_gpu=1,
-            workers_per_gpu=cfg.data.workers_per_gpu,
-            num_gpus=1,
-            dist=False,
-            shuffle=False)
-        outputs = single_test(model, data_loader, args.show)
+        outputs = single_gpu_test(model, data_loader, args.show)
     else:
-        model_args = cfg.model.copy()
-        model_args.update(train_cfg=None, test_cfg=cfg.test_cfg)
-        model_type = getattr(detectors, model_args.pop('type'))
-        outputs = parallel_test(
-            model_type,
-            model_args,
-            args.checkpoint,
-            dataset,
-            _data_func,
-            range(args.gpus),
-            workers_per_gpu=args.proc_per_gpu)
-
-    if args.out:
-        print('writing results to {}'.format(args.out))
+        model = MMDistributedDataParallel(model.cuda())
+        outputs = multi_gpu_test(model, data_loader, args.tmpdir)
+
+    rank, _ = get_dist_info()
+    if args.out and rank == 0:
+        print('\nwriting results to {}'.format(args.out))
         mmcv.dump(outputs, args.out)
         eval_types = args.eval
         if eval_types:
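The `zip(*part_list)` step in `collect_results` recovers dataset order because `DistributedSampler` gave rank `r` the strided indices `r, r + world_size, r + 2 * world_size, ...`; taking one result from each rank in turn therefore interleaves them back. A tiny standalone illustration (made-up results, no GPUs required):

```python
world_size = 4
dataset = list(range(10))  # pretend each result is just its sample index

# reproduce the sampler's padded, strided split
num_samples = -(-len(dataset) // world_size)  # ceil division -> 3
total = num_samples * world_size
padded = dataset + dataset[:total - len(dataset)]
parts = [padded[r:total:world_size] for r in range(world_size)]

ordered = []
for res in zip(*parts):  # one result per rank, in rank order
    ordered.extend(res)
ordered = ordered[:len(dataset)]  # drop the padded duplicates

assert ordered == dataset  # original order restored
```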
