Skip to content

Commit b0620a7

Browse files
authored
correct the use of all_reduce (PaddlePaddle#7108) (PaddlePaddle#7199)
1 parent e066d8d commit b0620a7

File tree

2 files changed

+7
-36
lines changed

2 files changed

+7
-36
lines changed

ppdet/modeling/heads/gfl_head.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -388,12 +388,7 @@ def get_loss(self, gfl_head_outs, gt_meta):
388388

389389
avg_factor = sum(avg_factor)
390390
try:
391-
avg_factor_clone = avg_factor.clone()
392-
tmp_avg_factor = paddle.distributed.all_reduce(avg_factor_clone)
393-
if tmp_avg_factor is not None:
394-
avg_factor = tmp_avg_factor
395-
else:
396-
avg_factor = avg_factor_clone
391+
paddle.distributed.all_reduce(avg_factor)
397392
avg_factor = paddle.clip(
398393
avg_factor / paddle.distributed.get_world_size(), min=1)
399394
except:

ppdet/modeling/heads/simota_head.py

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -179,15 +179,8 @@ def get_loss(self, head_outs, gt_meta):
179179
num_level_anchors)
180180
num_total_pos = sum(pos_num_l)
181181
try:
182-
cloned_num_total_pos = num_total_pos.clone()
183-
reduced_cloned_num_total_pos = paddle.distributed.all_reduce(
184-
cloned_num_total_pos)
185-
if reduced_cloned_num_total_pos is not None:
186-
num_total_pos = reduced_cloned_num_total_pos / paddle.distributed.get_world_size(
187-
)
188-
else:
189-
num_total_pos = cloned_num_total_pos / paddle.distributed.get_world_size(
190-
)
182+
paddle.distributed.all_reduce(num_total_pos)
183+
num_total_pos = num_total_pos / paddle.distributed.get_world_size()
191184
except:
192185
num_total_pos = max(num_total_pos, 1)
193186

@@ -262,12 +255,7 @@ def get_loss(self, head_outs, gt_meta):
262255

263256
avg_factor = sum(avg_factor)
264257
try:
265-
avg_factor_clone = avg_factor.clone()
266-
tmp_avg_factor = paddle.distributed.all_reduce(avg_factor_clone)
267-
if tmp_avg_factor is not None:
268-
avg_factor = tmp_avg_factor
269-
else:
270-
avg_factor = avg_factor_clone
258+
paddle.distributed.all_reduce(avg_factor)
271259
avg_factor = paddle.clip(
272260
avg_factor / paddle.distributed.get_world_size(), min=1)
273261
except:
@@ -408,15 +396,8 @@ def get_loss(self, head_outs, gt_meta):
408396
num_level_anchors)
409397
num_total_pos = sum(pos_num_l)
410398
try:
411-
cloned_num_total_pos = num_total_pos.clone()
412-
reduced_cloned_num_total_pos = paddle.distributed.all_reduce(
413-
cloned_num_total_pos)
414-
if reduced_cloned_num_total_pos is not None:
415-
num_total_pos = reduced_cloned_num_total_pos / paddle.distributed.get_world_size(
416-
)
417-
else:
418-
num_total_pos = cloned_num_total_pos / paddle.distributed.get_world_size(
419-
)
399+
paddle.distributed.all_reduce(num_total_pos)
400+
num_total_pos = num_total_pos / paddle.distributed.get_world_size()
420401
except:
421402
num_total_pos = max(num_total_pos, 1)
422403

@@ -494,12 +475,7 @@ def get_loss(self, head_outs, gt_meta):
494475

495476
avg_factor = sum(avg_factor)
496477
try:
497-
avg_factor_clone = avg_factor.clone()
498-
tmp_avg_factor = paddle.distributed.all_reduce(avg_factor_clone)
499-
if tmp_avg_factor is not None:
500-
avg_factor = tmp_avg_factor
501-
else:
502-
avg_factor = avg_factor_clone
478+
paddle.distributed.all_reduce(avg_factor)
503479
avg_factor = paddle.clip(
504480
avg_factor / paddle.distributed.get_world_size(), min=1)
505481
except:

0 commit comments

Comments (0)