
Commit f3ae234

[Enhance] Support random seed for distributed sampler (open-mmlab#1411)
* support random seed for distributed sampler
* move mmseg/utils/dist_util.py to mmseg/core/utils/dist_util.py
* change dist sampler
* fix docstring in sync_random_seed
1 parent 1a33d50 commit f3ae234

File tree

5 files changed: +130 lines, -3 lines

mmseg/core/utils/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -1,6 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .dist_util import check_dist_init, sync_random_seed
 from .layer_decay_optimizer_constructor import \
     LearningRateDecayOptimizerConstructor
 from .misc import add_prefix

-__all__ = ['add_prefix', 'LearningRateDecayOptimizerConstructor']
+__all__ = [
+    'add_prefix', 'LearningRateDecayOptimizerConstructor', 'check_dist_init',
+    'sync_random_seed'
+]

mmseg/core/utils/dist_util.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+import torch.distributed as dist
+from mmcv.runner import get_dist_info
+
+
+def check_dist_init():
+    return dist.is_available() and dist.is_initialized()
+
+
+def sync_random_seed(seed=None, device='cuda'):
+    """Make sure different ranks share the same seed. All workers must call
+    this function, otherwise it will deadlock. This method is generally used in
+    `DistributedSampler`, because the seed should be identical across all
+    processes in the distributed group.
+
+    In distributed sampling, different ranks should sample non-overlapped
+    data in the dataset. Therefore, this function is used to make sure that
+    each rank shuffles the data indices in the same order based
+    on the same seed. Then different ranks could use different indices
+    to select non-overlapped data from the same data list.
+
+    Args:
+        seed (int, Optional): The seed. Default to None.
+        device (str): The device where the seed will be put on.
+            Default to 'cuda'.
+    Returns:
+        int: Seed to be used.
+    """
+
+    if seed is None:
+        seed = np.random.randint(2**31)
+    assert isinstance(seed, int)
+
+    rank, world_size = get_dist_info()
+
+    if world_size == 1:
+        return seed
+
+    if rank == 0:
+        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
+    else:
+        random_num = torch.tensor(0, dtype=torch.int32, device=device)
+    dist.broadcast(random_num, src=0)
+    return random_num.item()
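For context, a minimal usage sketch of the new helpers (not part of the diff). It assumes mmseg is installed and, for the distributed branch, that a process group has already been initialized, e.g. via mmcv.runner.init_dist; the print call is only illustrative.

# Sketch only: assumes the process group is already set up (e.g. via
# mmcv.runner.init_dist) when running distributed; otherwise the
# single-process branch below applies.
import torch.distributed as dist

from mmseg.core.utils import check_dist_init, sync_random_seed

if check_dist_init():
    # Every rank must call this, or the broadcast deadlocks. Rank 0 draws
    # (or keeps) the seed and broadcasts it, so all ranks get the same int.
    seed = sync_random_seed()
    print(f'rank {dist.get_rank()} uses seed {seed}')
else:
    # With world_size == 1 the seed is simply returned (drawn locally
    # if None); no broadcast takes place.
    seed = sync_random_seed(42)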

mmseg/datasets/builder.py

Lines changed: 4 additions & 2 deletions
@@ -9,7 +9,9 @@
 from mmcv.parallel import collate
 from mmcv.runner import get_dist_info
 from mmcv.utils import Registry, build_from_cfg, digit_version
-from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data import DataLoader
+
+from .samplers import DistributedSampler

 if platform.system() != 'Windows':
     # https://github.com/pytorch/pytorch/issues/973
@@ -129,7 +131,7 @@ def build_dataloader(dataset,
     rank, world_size = get_dist_info()
     if dist:
         sampler = DistributedSampler(
-            dataset, world_size, rank, shuffle=shuffle)
+            dataset, world_size, rank, shuffle=shuffle, seed=seed)
         shuffle = False
         batch_size = samples_per_gpu
         num_workers = workers_per_gpu
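To show how the forwarded seed reaches the sampler, here is a hedged sketch of a build_dataloader call in distributed mode; the cfg object and its train split are placeholders, while the keyword names follow the existing build_dataloader signature.

# Sketch only: `cfg` is a placeholder config object; the keyword arguments
# follow the existing mmseg build_dataloader signature.
from mmseg.datasets import build_dataloader, build_dataset

dataset = build_dataset(cfg.data.train)
loader = build_dataloader(
    dataset,
    samples_per_gpu=2,
    workers_per_gpu=2,
    dist=True,    # picks the DistributedSampler from mmseg's .samplers package
    shuffle=True,
    seed=42)      # now also forwarded to DistributedSampler(..., seed=seed)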
mmseg/datasets/samplers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .distributed_sampler import DistributedSampler
+
+__all__ = ['DistributedSampler']
mmseg/datasets/samplers/distributed_sampler.py

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from __future__ import division
+from typing import Iterator, Optional
+
+import torch
+from torch.utils.data import Dataset
+from torch.utils.data import DistributedSampler as _DistributedSampler
+
+from mmseg.core.utils import sync_random_seed
+
+
+class DistributedSampler(_DistributedSampler):
+    """DistributedSampler inheriting from
+    `torch.utils.data.DistributedSampler`.
+
+    Args:
+        dataset (Dataset): the dataset to be loaded.
+        num_replicas (int, optional): Number of processes participating in
+            distributed training. By default, world_size is retrieved from the
+            current distributed group.
+        rank (int, optional): Rank of the current process within num_replicas.
+            By default, rank is retrieved from the current distributed group.
+        shuffle (bool): If True (default), sampler will shuffle the indices.
+        seed (int): random seed used to shuffle the sampler if
+            :attr:`shuffle=True`. This number should be identical across all
+            processes in the distributed group. Default: ``0``.
+    """
+
+    def __init__(self,
+                 dataset: Dataset,
+                 num_replicas: Optional[int] = None,
+                 rank: Optional[int] = None,
+                 shuffle: bool = True,
+                 seed=0) -> None:
+        super().__init__(
+            dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
+
+        # In distributed sampling, different ranks should sample
+        # non-overlapped data in the dataset. Therefore, this function
+        # is used to make sure that each rank shuffles the data indices
+        # in the same order based on the same seed. Then different ranks
+        # could use different indices to select non-overlapped data from the
+        # same data list.
+        self.seed = sync_random_seed(seed)
+
+    def __iter__(self) -> Iterator:
+        """
+        Yields:
+            Iterator: iterator of indices for rank.
+        """
+        # deterministically shuffle based on epoch
+        if self.shuffle:
+            g = torch.Generator()
+            # When :attr:`shuffle=True`, this ensures all replicas
+            # use a different random ordering for each epoch.
+            # Otherwise, the next iteration of this sampler will
+            # yield the same ordering.
+            g.manual_seed(self.epoch + self.seed)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = torch.arange(len(self.dataset)).tolist()
+
+        # add extra samples to make it evenly divisible
+        indices += indices[:(self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        # subsample
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
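As a sanity-check sketch (not part of the commit), the example below simulates two ranks in a single process by passing num_replicas and rank explicitly; with a world size of 1, sync_random_seed just returns the given seed, so no process group is needed. The import path assumes the samplers package location shown above.

# Sketch only: single-process simulation of a 2-GPU job; the import path
# assumes mmseg/datasets/samplers as added in this commit.
import torch
from torch.utils.data import TensorDataset

from mmseg.datasets.samplers import DistributedSampler

dataset = TensorDataset(torch.arange(10))

# One sampler per simulated rank, sharing the same seed.
samplers = [
    DistributedSampler(dataset, num_replicas=2, rank=r, shuffle=True, seed=42)
    for r in range(2)
]
for epoch in range(2):
    for s in samplers:
        s.set_epoch(epoch)  # changes the ordering every epoch via self.epoch
    idx0 = list(iter(samplers[0]))
    idx1 = list(iter(samplers[1]))
    # Same seed + same epoch => both ranks shuffle the same global
    # permutation and take disjoint slices of it (rank r gets [r::2]).
    assert set(idx0).isdisjoint(idx1)
    assert len(idx0) == len(idx1) == 5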
