Skip to content

Commit 7dba7a2

Browse files
authored
[Feature]: Add diff seeds to diff ranks and set torch seed in worker_init_fn (open-mmlab#1362)
1 parent b028824 commit 7dba7a2

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

mmseg/datasets/builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,4 @@ def worker_init_fn(worker_id, num_workers, rank, seed):
186186
worker_seed = num_workers * rank + worker_id + seed
187187
np.random.seed(worker_seed)
188188
random.seed(worker_seed)
189+
torch.manual_seed(worker_seed)

tools/train.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import mmcv
1010
import torch
11+
import torch.distributed as dist
1112
from mmcv.cnn.utils import revert_sync_batchnorm
1213
from mmcv.runner import get_dist_info, init_dist
1314
from mmcv.utils import Config, DictAction, get_git_hash
@@ -50,6 +51,10 @@ def parse_args():
5051
help='id of gpu to use '
5152
'(only applicable to non-distributed training)')
5253
parser.add_argument('--seed', type=int, default=None, help='random seed')
54+
parser.add_argument(
55+
'--diff_seed',
56+
action='store_true',
57+
help='Whether or not set different seeds for different ranks')
5358
parser.add_argument(
5459
'--deterministic',
5560
action='store_true',
@@ -180,6 +185,7 @@ def main():
180185

181186
# set random seeds
182187
seed = init_random_seed(args.seed)
188+
seed = seed + dist.get_rank() if args.diff_seed else seed
183189
logger.info(f'Set random seed to {seed}, '
184190
f'deterministic: {args.deterministic}')
185191
set_random_seed(seed, deterministic=args.deterministic)

0 commit comments

Comments
 (0)