@@ -188,6 +188,17 @@ def _parse_losses(losses):
     loss = sum(_value for _key, _value in log_vars.items()
                if 'loss' in _key)

+    # If log_vars has a different length on different ranks, raise an
+    # assertion error to prevent the GPUs from waiting on each other
+    # indefinitely.
+    if dist.is_available() and dist.is_initialized():
+        log_var_length = torch.tensor(len(log_vars), device=loss.device)
+        dist.all_reduce(log_var_length)
+        message = (f'rank {dist.get_rank()}' +
+                   f' len(log_vars): {len(log_vars)}' + ' keys: ' +
+                   ','.join(log_vars.keys()) + '\n')
+        assert log_var_length == len(log_vars) * dist.get_world_size(), \
+            'loss log variables are different across GPUs!\n' + message
+
     log_vars['loss'] = loss
     for loss_name, loss_value in log_vars.items():
         # reduce loss when distributed training
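The check works by summing `len(log_vars)` across all ranks with `all_reduce`: if every rank logs the same set of losses, the sum equals `len(log_vars) * world_size` on every rank; otherwise the assertion fires immediately instead of one rank silently skipping a later collective op and deadlocking the job. Below is a minimal, self-contained sketch of the same idea, runnable on CPU with the gloo backend; the helper name `check_log_vars` and the two-process setup are illustrative only and are not part of this patch (the patched method is `_parse_losses`).

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def check_log_vars(log_vars):
    """Assert that every rank logs the same number of loss variables."""
    if dist.is_available() and dist.is_initialized():
        log_var_length = torch.tensor(len(log_vars))
        # Sum the lengths over all ranks; if every rank agrees, the sum
        # equals len(log_vars) * world_size on every rank.
        dist.all_reduce(log_var_length)
        message = (f'rank {dist.get_rank()}'
                   f' len(log_vars): {len(log_vars)}'
                   ' keys: ' + ','.join(log_vars.keys()) + '\n')
        assert log_var_length == len(log_vars) * dist.get_world_size(), \
            'loss log variables are different across GPUs!\n' + message


def worker(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    # Rank 1 is missing 'loss_bbox', so the lengths disagree across
    # ranks and the assertion fires on both of them.
    log_vars = {'loss_cls': 1.0}
    if rank == 0:
        log_vars['loss_bbox'] = 2.0

    try:
        check_log_vars(log_vars)
    except AssertionError as e:
        print(e)
    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(worker, args=(2,), nprocs=2)
```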