
Commit 8bf38df

zhangtemplarhellock authored and committed
Only import torch.distributed when needed (open-mmlab#882)
* Fix an import error for `get_world_size` and `get_rank`
* Only import torch.distributed when needed: torch.distributed is only used in DistributedGroupSampler
* Use `get_dist_info` to obtain the world size and rank: `get_dist_info` from `mmcv.runner.utils` handles the case where `distributed_c10d` does not exist
1 parent f080ccb commit 8bf38df

File tree

1 file changed: +4 −3 lines changed


mmdet/datasets/loader/sampler.py

Lines changed: 4 additions & 3 deletions
@@ -4,7 +4,7 @@
 import torch
 import numpy as np
-from torch.distributed import get_world_size, get_rank
+from mmcv.runner.utils import get_dist_info
 from torch.utils.data import Sampler
 from torch.utils.data import DistributedSampler as _DistributedSampler

@@ -95,10 +95,11 @@ def __init__(self,
                  samples_per_gpu=1,
                  num_replicas=None,
                  rank=None):
+        _rank, _num_replicas = get_dist_info()
         if num_replicas is None:
-            num_replicas = get_world_size()
+            num_replicas = _num_replicas
         if rank is None:
-            rank = get_rank()
+            rank = _rank
         self.dataset = dataset
         self.samples_per_gpu = samples_per_gpu
         self.num_replicas = num_replicas
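The point of the change is that the module-level `from torch.distributed import ...` fails on builds where `torch.distributed` (the `distributed_c10d` backend) is unavailable, even though only DistributedGroupSampler needs it. A minimal sketch of the pattern `get_dist_info` relies on, with the import deferred and a `(0, 1)` single-process fallback (function name and fallback values mirror mmcv's helper; the exact mmcv implementation may differ):

```python
def get_dist_info():
    """Return (rank, world_size), falling back to (0, 1) when
    torch.distributed is unavailable or not yet initialized."""
    try:
        # Deferred import: environments without distributed support
        # (or without torch at all) never pay for this at module load.
        from torch import distributed as dist
    except ImportError:
        return 0, 1
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1
```

In a plain single-process run this returns `(0, 1)`, so the sampler degrades gracefully instead of raising an ImportError at import time.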
