# Copyright (c) OpenMMLab. All rights reserved.
from __future__ import division
from typing import Iterator, Optional

import torch
from torch.utils.data import Dataset
from torch.utils.data import DistributedSampler as _DistributedSampler

from mmseg.core.utils import sync_random_seed


class DistributedSampler(_DistributedSampler):
    """DistributedSampler inheriting from
    `torch.utils.data.DistributedSampler`.

    Args:
        dataset (Dataset): The dataset to be loaded.
        num_replicas (int, optional): Number of processes participating in
            distributed training. By default, world_size is retrieved from
            the current distributed group.
        rank (int, optional): Rank of the current process within
            num_replicas. By default, rank is retrieved from the current
            distributed group.
        shuffle (bool): If ``True`` (default), the sampler will shuffle the
            indices.
        seed (int): Random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.
    """

    def __init__(self,
                 dataset: Dataset,
                 num_replicas: Optional[int] = None,
                 rank: Optional[int] = None,
                 shuffle: bool = True,
                 seed: int = 0) -> None:
        super().__init__(
            dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)

        # In distributed sampling, different ranks should sample
        # non-overlapping data from the dataset. This call makes sure that
        # every rank shuffles the data indices in the same order, based on
        # the same seed, so that the ranks can then pick disjoint subsets
        # of the same shuffled index list.
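        # Note (assumption about the helper): ``sync_random_seed`` broadcasts
        # the seed chosen on rank 0 to all other ranks, so every process ends
        # up with the same value even if their local arguments differ.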
        self.seed = sync_random_seed(seed)

    def __iter__(self) -> Iterator:
        """
        Returns:
            Iterator: iterator over the indices assigned to the current rank.
        """
        # deterministically shuffle based on epoch
        if self.shuffle:
            g = torch.Generator()
            # When :attr:`shuffle=True`, this ensures all replicas
            # use a different random ordering for each epoch.
            # Otherwise, the next iteration of this sampler will
            # yield the same ordering.
            g.manual_seed(self.epoch + self.seed)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size
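        # Worked example (hypothetical numbers): with 10 samples and
        # 4 replicas, the parent class sets num_samples = 3 and
        # total_size = 12, so the first 2 shuffled indices are appended
        # again to pad the list to 12 entries.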

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples
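        # The strided slice gives rank r positions r, r + num_replicas,
        # r + 2 * num_replicas, ... of the padded list, so the ranks cover
        # disjoint subsets. In the hypothetical example above, rank 0 keeps
        # positions 0, 4 and 8.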

        return iter(indices)
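

# Minimal usage sketch (illustrative only, not part of the original file):
# it shows one way to wire this sampler into a standard ``DataLoader`` for
# distributed training; ``dataset``, ``max_epochs`` and the batch size are
# placeholder values.
#
#     sampler = DistributedSampler(dataset, shuffle=True, seed=0)
#     loader = torch.utils.data.DataLoader(
#         dataset, batch_size=4, sampler=sampler, num_workers=2)
#     for epoch in range(max_epochs):
#         # a different, but deterministic, order every epoch
#         sampler.set_epoch(epoch)
#         for batch in loader:
#             ...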