Commit f8ed148

[Fix] Fix dist training infinite waiting issue (open-mmlab#1035)
* [open-mmlab#1034] fix dist training infinite waiting issue
* print log_vars keys in assertion msg
* linting issue
1 parent a357419 commit f8ed148
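
Why missing or extra loss keys cause a hang: a few lines below the added check (only the first lines appear as context in the hunk shown further down), _parse_losses issues one dist.all_reduce per logged key to average the values across GPUs. If one rank produces a different set of log_vars keys, the ranks no longer make the same sequence of collective calls, and every GPU blocks waiting on a reduction its peers never enter. A rough, self-contained paraphrase of that reduce loop follows (written from the surrounding mmseg code, not quoted verbatim from this commit; details may differ slightly):

import torch.distributed as dist


def reduce_log_vars(log_vars):
    # Average each logged scalar across ranks. Every key triggers its own
    # collective call, so all ranks must iterate exactly the same keys.
    for loss_name, loss_value in log_vars.items():
        if dist.is_available() and dist.is_initialized():
            loss_value = loss_value.data.clone()
            dist.all_reduce(loss_value.div_(dist.get_world_size()))
        log_vars[loss_name] = loss_value.item()
    return log_vars

Summing the number of logged variables across ranks before this loop, which is what this commit adds, turns the silent hang into an immediate assertion error.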

File tree

1 file changed, +11 -0 lines changed

mmseg/models/segmentors/base.py

Lines changed: 11 additions & 0 deletions
@@ -188,6 +188,17 @@ def _parse_losses(losses):
         loss = sum(_value for _key, _value in log_vars.items()
                    if 'loss' in _key)
 
+        # If the loss_vars has different length, raise assertion error
+        # to prevent GPUs from infinite waiting.
+        if dist.is_available() and dist.is_initialized():
+            log_var_length = torch.tensor(len(log_vars), device=loss.device)
+            dist.all_reduce(log_var_length)
+            message = (f'rank {dist.get_rank()}' +
+                       f' len(log_vars): {len(log_vars)}' + ' keys: ' +
+                       ','.join(log_vars.keys()) + '\n')
+            assert log_var_length == len(log_vars) * dist.get_world_size(), \
+                'loss log variables are different across GPUs!\n' + message
+
         log_vars['loss'] = loss
         for loss_name, loss_value in log_vars.items():
             # reduce loss when distributed training
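
As a standalone illustration of the guard itself, here is a minimal sketch (not part of mmsegmentation; the helper name check_log_vars_consistent is invented for this example). It assumes a distributed process group has already been initialized, e.g. via torchrun:

import torch
import torch.distributed as dist


def check_log_vars_consistent(log_vars, device):
    """Fail fast if ranks disagree on how many loss variables they log."""
    if dist.is_available() and dist.is_initialized():
        log_var_length = torch.tensor(len(log_vars), device=device)
        # all_reduce defaults to SUM: equal lengths sum to len * world_size.
        dist.all_reduce(log_var_length)
        message = (f'rank {dist.get_rank()} '
                   f'len(log_vars): {len(log_vars)} '
                   'keys: ' + ','.join(log_vars.keys()))
        assert log_var_length == len(log_vars) * dist.get_world_size(), \
            'loss log variables are different across GPUs!\n' + message

The check costs a single scalar all_reduce per iteration, so a mismatch is reported on the exact step where it first appears instead of surfacing as an unexplained hang.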
