@@ -179,15 +179,8 @@ def get_loss(self, head_outs, gt_meta):
179179 num_level_anchors )
180180 num_total_pos = sum (pos_num_l )
181181 try :
182- cloned_num_total_pos = num_total_pos .clone ()
183- reduced_cloned_num_total_pos = paddle .distributed .all_reduce (
184- cloned_num_total_pos )
185- if reduced_cloned_num_total_pos is not None :
186- num_total_pos = reduced_cloned_num_total_pos / paddle .distributed .get_world_size (
187- )
188- else :
189- num_total_pos = cloned_num_total_pos / paddle .distributed .get_world_size (
190- )
182+ paddle .distributed .all_reduce (num_total_pos )
183+ num_total_pos = num_total_pos / paddle .distributed .get_world_size ()
191184 except :
192185 num_total_pos = max (num_total_pos , 1 )
193186
@@ -262,12 +255,7 @@ def get_loss(self, head_outs, gt_meta):
262255
263256 avg_factor = sum (avg_factor )
264257 try :
265- avg_factor_clone = avg_factor .clone ()
266- tmp_avg_factor = paddle .distributed .all_reduce (avg_factor_clone )
267- if tmp_avg_factor is not None :
268- avg_factor = tmp_avg_factor
269- else :
270- avg_factor = avg_factor_clone
258+ paddle .distributed .all_reduce (avg_factor )
271259 avg_factor = paddle .clip (
272260 avg_factor / paddle .distributed .get_world_size (), min = 1 )
273261 except :
@@ -408,15 +396,8 @@ def get_loss(self, head_outs, gt_meta):
408396 num_level_anchors )
409397 num_total_pos = sum (pos_num_l )
410398 try :
411- cloned_num_total_pos = num_total_pos .clone ()
412- reduced_cloned_num_total_pos = paddle .distributed .all_reduce (
413- cloned_num_total_pos )
414- if reduced_cloned_num_total_pos is not None :
415- num_total_pos = reduced_cloned_num_total_pos / paddle .distributed .get_world_size (
416- )
417- else :
418- num_total_pos = cloned_num_total_pos / paddle .distributed .get_world_size (
419- )
399+ paddle .distributed .all_reduce (num_total_pos )
400+ num_total_pos = num_total_pos / paddle .distributed .get_world_size ()
420401 except :
421402 num_total_pos = max (num_total_pos , 1 )
422403
@@ -494,12 +475,7 @@ def get_loss(self, head_outs, gt_meta):
494475
495476 avg_factor = sum (avg_factor )
496477 try :
497- avg_factor_clone = avg_factor .clone ()
498- tmp_avg_factor = paddle .distributed .all_reduce (avg_factor_clone )
499- if tmp_avg_factor is not None :
500- avg_factor = tmp_avg_factor
501- else :
502- avg_factor = avg_factor_clone
478+ paddle .distributed .all_reduce (avg_factor )
503479 avg_factor = paddle .clip (
504480 avg_factor / paddle .distributed .get_world_size (), min = 1 )
505481 except :
0 commit comments