20 changes: 9 additions & 11 deletions GETTING_STARTED.md
@@ -14,19 +14,17 @@ and also some high-level apis for easier integration to other projects.
- [x] multiple GPU testing
- [x] visualize detection results

You can use the following command to test a dataset.
You can use the following commands to test a dataset.

```shell
python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--gpus ${GPU_NUM}] [--proc_per_gpu ${PROC_NUM}] [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}] [--show]
```
# single-gpu testing
python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}] [--show]

Positional arguments:
- `CONFIG_FILE`: Path to the config file of the corresponding model.
- `CHECKPOINT_FILE`: Path to the checkpoint file.
# multi-gpu testing
./tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}]
```

Optional arguments:
- `GPU_NUM`: Number of GPUs used for testing. (default: 1)
- `PROC_NUM`: Number of processes on each GPU. (default: 1)
- `RESULT_FILE`: Filename of the output results in pickle format. If not specified, the results will not be saved to a file.
- `EVAL_METRICS`: Items to be evaluated on the results. Allowed values are: `proposal_fast`, `proposal`, `bbox`, `segm`, `keypoints`.
- `--show`: If specified, detection results will be plotted on the images and shown in a new window. Only applicable for single GPU testing.
@@ -51,12 +49,12 @@ python tools/test.py configs/mask_rcnn_r50_fpn_1x.py \
--out results.pkl --eval bbox mask
```

3. Test Mask R-CNN with 8 GPUs and 2 processes per GPU, and evaluate the bbox and mask AP.
3. Test Mask R-CNN with 8 GPUs, and evaluate the bbox and mask AP.

```shell
python tools/test.py configs/mask_rcnn_r50_fpn_1x.py \
./tools/dist_test.sh configs/mask_rcnn_r50_fpn_1x.py \
checkpoints/mask_rcnn_r50_fpn_1x_20181010-069fa190.pth \
--gpus 8 --proc_per_gpu 2 --out results.pkl --eval bbox mask
8 --out results.pkl --eval bbox mask
```
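The `RESULT_FILE` written by `--out` is a plain pickle dump with one entry per test image, so it can be inspected offline. A minimal sketch, assuming `results.pkl` was produced by one of the commands above:

```python
import mmcv

# Load the detection results saved via --out by tools/test.py or dist_test.sh.
# Entries follow the dataset order, one per image.
results = mmcv.load('results.pkl')
print(len(results))
```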

### High-level APIs for testing images.
34 changes: 19 additions & 15 deletions mmdet/datasets/loader/build_loader.py
@@ -4,7 +4,7 @@
from mmcv.parallel import collate
from torch.utils.data import DataLoader

from .sampler import GroupSampler, DistributedGroupSampler
from .sampler import GroupSampler, DistributedGroupSampler, DistributedSampler

# https://github.com/pytorch/pytorch/issues/973
import resource
@@ -18,27 +18,31 @@ def build_dataloader(dataset,
num_gpus=1,
dist=True,
**kwargs):
shuffle = kwargs.get('shuffle', True)
if dist:
rank, world_size = get_dist_info()
sampler = DistributedGroupSampler(dataset, imgs_per_gpu, world_size,
rank)
if shuffle:
sampler = DistributedGroupSampler(dataset, imgs_per_gpu,
world_size, rank)
else:
sampler = DistributedSampler(dataset,
world_size,
rank,
shuffle=False)
batch_size = imgs_per_gpu
num_workers = workers_per_gpu
else:
if not kwargs.get('shuffle', True):
sampler = None
else:
sampler = GroupSampler(dataset, imgs_per_gpu)
sampler = GroupSampler(dataset, imgs_per_gpu) if shuffle else None
batch_size = num_gpus * imgs_per_gpu
num_workers = num_gpus * workers_per_gpu

data_loader = DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers,
collate_fn=partial(collate, samples_per_gpu=imgs_per_gpu),
pin_memory=False,
**kwargs)
data_loader = DataLoader(dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers,
collate_fn=partial(collate,
samples_per_gpu=imgs_per_gpu),
pin_memory=False,
**kwargs)

return data_loader
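With this refactor, distributed testing can request a deterministic, non-shuffled loader by passing `shuffle=False`, while the non-distributed path keeps its grouped sampling. A rough usage sketch under those assumptions (the `workers_per_gpu` and `imgs_per_gpu` values here are illustrative):

```python
import mmcv
from mmdet.datasets import build_dataloader, get_dataset

# Config path taken from the examples in GETTING_STARTED.md.
cfg = mmcv.Config.fromfile('configs/mask_rcnn_r50_fpn_1x.py')
cfg.data.test.test_mode = True
dataset = get_dataset(cfg.data.test)

# Distributed testing: each rank gets an ordered, evenly padded shard.
test_loader = build_dataloader(
    dataset, imgs_per_gpu=1, workers_per_gpu=2, dist=True, shuffle=False)

# Non-distributed path: a GroupSampler is used as long as shuffle stays True.
train_loader = build_dataloader(
    dataset, imgs_per_gpu=2, workers_per_gpu=2, num_gpus=1, dist=False)
```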
33 changes: 30 additions & 3 deletions mmdet/datasets/loader/sampler.py
@@ -5,7 +5,34 @@
import numpy as np

from torch.distributed import get_world_size, get_rank
from torch.utils.data.sampler import Sampler
from torch.utils.data import Sampler
from torch.utils.data import DistributedSampler as _DistributedSampler


class DistributedSampler(_DistributedSampler):

def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
super().__init__(dataset, num_replicas=num_replicas, rank=rank)
self.shuffle = shuffle

def __iter__(self):
# deterministically shuffle based on epoch
if self.shuffle:
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = torch.arange(len(self.dataset)).tolist()

# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size

# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples

return iter(indices)


class GroupSampler(Sampler):
@@ -112,8 +139,8 @@ def __iter__(self):

indices = [
indices[j] for i in list(
torch.randperm(
len(indices) // self.samples_per_gpu, generator=g))
torch.randperm(len(indices) // self.samples_per_gpu,
generator=g))
for j in range(i * self.samples_per_gpu, (i + 1) *
self.samples_per_gpu)
]
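The new `DistributedSampler` mirrors the PyTorch one but makes shuffling optional, so testing can iterate the dataset in order. The padding and rank-strided slicing are easiest to see on a toy example (the numbers below are purely illustrative):

```python
# Suppose len(dataset) == 10 and num_replicas == 4 (one process per GPU):
# num_samples = ceil(10 / 4) = 3, so total_size = 12, with shuffle=False.
indices = list(range(10))
indices += indices[:12 - len(indices)]        # pad -> [0..9, 0, 1]
shards = [indices[rank:12:4] for rank in range(4)]
# rank 0 -> [0, 4, 8], rank 1 -> [1, 5, 9], rank 2 -> [2, 6, 0], rank 3 -> [3, 7, 1]
```

This interleaved layout is what lets `collect_results` in tools/test.py restore dataset order by zipping the per-rank result lists and dropping the padded duplicates at the end.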
10 changes: 10 additions & 0 deletions tools/dist_test.sh
@@ -0,0 +1,10 @@
#!/usr/bin/env bash

PYTHON=${PYTHON:-"python"}

CONFIG=$1
CHECKPOINT=$2
GPUS=$3

$PYTHON -m torch.distributed.launch --nproc_per_node=$GPUS \
$(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
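The positional arguments mirror GETTING_STARTED.md: config, checkpoint, then the number of GPUs; everything after that is forwarded to tools/test.py. A hedged example, reusing the checkpoint from the docs and metric names accepted by `--eval`:

```shell
./tools/dist_test.sh configs/mask_rcnn_r50_fpn_1x.py \
    checkpoints/mask_rcnn_r50_fpn_1x_20181010-069fa190.pth \
    8 --out results.pkl --eval bbox segm
```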
6 changes: 5 additions & 1 deletion tools/dist_train.sh
@@ -2,4 +2,8 @@

PYTHON=${PYTHON:-"python"}

$PYTHON -m torch.distributed.launch --nproc_per_node=$2 $(dirname "$0")/train.py $1 --launcher pytorch ${@:3}
CONFIG=$1
GPUS=$2

$PYTHON -m torch.distributed.launch --nproc_per_node=$GPUS \
$(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}
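Invocation is analogous to dist_test.sh: the config path, then the GPU count, with any remaining arguments passed through to tools/train.py. A minimal sketch using a config name from this repository:

```shell
# Train Mask R-CNN with 8 GPUs; extra flags would be forwarded to tools/train.py.
./tools/dist_train.sh configs/mask_rcnn_r50_fpn_1x.py 8
```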
11 changes: 6 additions & 5 deletions tools/slurm_test.sh
@@ -7,16 +7,17 @@ JOB_NAME=$2
CONFIG=$3
CHECKPOINT=$4
GPUS=${GPUS:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-32}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
PY_ARGS=${@:5}
SRUN_ARGS=${SRUN_ARGS:-""}

srun -p ${PARTITION} \
--job-name=${JOB_NAME} \
--gres=gpu:${GPUS} \
--ntasks=1 \
--ntasks-per-node=1 \
--gres=gpu:${GPUS_PER_NODE} \
--ntasks=${GPUS} \
--ntasks-per-node=${GPUS_PER_NODE} \
--cpus-per-task=${CPUS_PER_TASK} \
--kill-on-bad-exit=1 \
${SRUN_ARGS} \
python tools/test.py ${CONFIG} ${CHECKPOINT} --gpus ${GPUS} ${PY_ARGS}
python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS}
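The script now launches one task per GPU and defers to the `slurm` launcher inside tools/test.py. A hedged invocation sketch; the partition and job names are placeholders, and `GPUS`, `GPUS_PER_NODE` and `CPUS_PER_TASK` keep their defaults (8, 8 and 5) unless overridden in the environment:

```shell
# 16 GPUs spread over two 8-GPU nodes; "dev" and "mask_test" are made-up names.
GPUS=16 ./tools/slurm_test.sh dev mask_test \
    configs/mask_rcnn_r50_fpn_1x.py \
    checkpoints/mask_rcnn_r50_fpn_1x_20181010-069fa190.pth \
    --out results.pkl --eval bbox segm
```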
158 changes: 112 additions & 46 deletions tools/test.py
@@ -1,17 +1,21 @@
import argparse
import os.path as osp
import shutil
import tempfile

import torch
import mmcv
from mmcv.runner import load_checkpoint, parallel_test, obj_from_dict
from mmcv.parallel import scatter, collate, MMDataParallel
import torch
import torch.distributed as dist
from mmcv.runner import load_checkpoint, get_dist_info
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel

from mmdet import datasets
from mmdet.apis import init_dist
from mmdet.core import results2json, coco_eval
from mmdet.datasets import build_dataloader
from mmdet.models import build_detector, detectors
from mmdet.datasets import build_dataloader, get_dataset
from mmdet.models import build_detector


def single_test(model, data_loader, show=False):
def single_gpu_test(model, data_loader, show=False):
model.eval()
results = []
dataset = data_loader.dataset
@@ -22,7 +26,9 @@ def single_test(model, data_loader, show=False):
results.append(result)

if show:
model.module.show_result(data, result, dataset.img_norm_cfg,
model.module.show_result(data,
result,
dataset.img_norm_cfg,
dataset=dataset.CLASSES)

batch_size = data['img'][0].size(0)
@@ -31,22 +37,76 @@ def _data_func(data, device_id):
return results


def _data_func(data, device_id):
data = scatter(collate([data], samples_per_gpu=1), [device_id])[0]
return dict(return_loss=False, rescale=True, **data)
def multi_gpu_test(model, data_loader, tmpdir=None):
model.eval()
results = []
dataset = data_loader.dataset
rank, world_size = get_dist_info()
if rank == 0:
prog_bar = mmcv.ProgressBar(len(dataset))
for i, data in enumerate(data_loader):
with torch.no_grad():
result = model(return_loss=False, rescale=True, **data)
results.append(result)

if rank == 0:
batch_size = data['img'][0].size(0)
for _ in range(batch_size * world_size):
prog_bar.update()

# collect results from all ranks
results = collect_results(results, len(dataset), tmpdir)

return results


def collect_results(result_part, size, tmpdir=None):
rank, world_size = get_dist_info()
# create a tmp dir if it is not specified
if tmpdir is None:
MAX_LEN = 512
# 32 is whitespace
dir_tensor = torch.full((MAX_LEN, ),
32,
dtype=torch.uint8,
device='cuda')
if rank == 0:
tmpdir = tempfile.mkdtemp()
tmpdir = torch.tensor(bytearray(tmpdir.encode()),
dtype=torch.uint8,
device='cuda')
dir_tensor[:len(tmpdir)] = tmpdir
dist.broadcast(dir_tensor, 0)
tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
else:
mmcv.mkdir_or_exist(tmpdir)
# dump the part result to the dir
mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank)))
dist.barrier()
# collect all parts
if rank != 0:
return None
else:
# load results of all parts from tmp dir
part_list = []
for i in range(world_size):
part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i))
part_list.append(mmcv.load(part_file))
# sort the results
ordered_results = []
for res in zip(*part_list):
ordered_results.extend(list(res))
# the dataloader may pad some samples
ordered_results = ordered_results[:size]
# remove tmp dir
shutil.rmtree(tmpdir)
return ordered_results


def parse_args():
parser = argparse.ArgumentParser(description='MMDet test detector')
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument(
'--gpus', default=1, type=int, help='GPU number used for testing')
parser.add_argument(
'--proc_per_gpu',
default=1,
type=int,
help='Number of processes per GPU')
parser.add_argument('--out', help='output result file')
parser.add_argument(
'--eval',
@@ -55,6 +115,12 @@ def parse_args():
choices=['proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'],
help='eval types')
parser.add_argument('--show', action='store_true', help='show results')
parser.add_argument('--tmpdir', help='tmp dir for writing some results')
parser.add_argument('--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
return args

@@ -72,36 +138,36 @@ def main():
cfg.model.pretrained = None
cfg.data.test.test_mode = True

dataset = obj_from_dict(cfg.data.test, datasets, dict(test_mode=True))
if args.gpus == 1:
model = build_detector(
cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
load_checkpoint(model, args.checkpoint)
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
distributed = False
else:
distributed = True
init_dist(args.launcher, **cfg.dist_params)

# build the dataloader
# TODO: support multiple images per gpu (only minor changes are needed)
dataset = get_dataset(cfg.data.test)
data_loader = build_dataloader(dataset,
imgs_per_gpu=1,
workers_per_gpu=cfg.data.workers_per_gpu,
dist=distributed,
shuffle=False)

# build the model and load checkpoint
model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
load_checkpoint(model, args.checkpoint, map_location='cpu')

if not distributed:
model = MMDataParallel(model, device_ids=[0])

data_loader = build_dataloader(
dataset,
imgs_per_gpu=1,
workers_per_gpu=cfg.data.workers_per_gpu,
num_gpus=1,
dist=False,
shuffle=False)
outputs = single_test(model, data_loader, args.show)
outputs = single_gpu_test(model, data_loader, args.show)
else:
model_args = cfg.model.copy()
model_args.update(train_cfg=None, test_cfg=cfg.test_cfg)
model_type = getattr(detectors, model_args.pop('type'))
outputs = parallel_test(
model_type,
model_args,
args.checkpoint,
dataset,
_data_func,
range(args.gpus),
workers_per_gpu=args.proc_per_gpu)

if args.out:
print('writing results to {}'.format(args.out))
model = MMDistributedDataParallel(model.cuda())
outputs = multi_gpu_test(model, data_loader, args.tmpdir)

rank, _ = get_dist_info()
if args.out and rank == 0:
print('\nwriting results to {}'.format(args.out))
mmcv.dump(outputs, args.out)
eval_types = args.eval
if eval_types: