Commit 47544a5

fix eval batch & non-compute_loss in pipeline
1 parent 5c91eb6 commit 47544a5


python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py

1 file changed: 57 additions, 31 deletions
@@ -165,7 +165,7 @@ def _check_data_valid(self, data):
         batch_size = data.shape[0]
         assert self._micro_batch_size * self._acc_steps == batch_size, (
             "batch_size needs to be divisible by micro_batch_size. Currently, "
-            f"batch_size = {batch_size}, micro_batch_size = {self._micro_batch_size}, accumulate_steps = {self._acc_steps}."
+            f"batch_size = {batch_size}, micro_batch_size = {self._micro_batch_size}, accumulate_steps = {self._acc_steps}, data_shape = {data.shape}."
         )

@@ -413,6 +413,7 @@ def __init__(self, layers, hcg, strategy):
         self.loss_fn_idx = 0

         self._compute_loss = True
+        self._return_host_tensor = False
         self.callbacks = pipeline_parallel_callbacks_

         logger.info(
@@ -991,13 +992,16 @@ def train_batch(

         return train_loss

-    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
+    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False):
         self.user_hooks_enabled = False
         # reset the virtual pp rank for each run
         self.set_virtual_pipeline_rank(0)

         self._layers.eval()
+        origin_compute_loss = self._compute_loss
         self._compute_loss = compute_loss
+        origin_return_host_tensor = self._return_host_tensor
+        self._return_host_tensor = return_host_tensor

         # store data id for micro_batch
         self.micro_batch_id = 0
@@ -1070,11 +1074,13 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
         )

         if self._compute_loss:
-            self.train_loss = self._broadcast_final_loss()
+            train_loss = self._broadcast_final_loss()
         else:
-            self.train_loss = output_buffers
+            train_loss = output_buffers

-        return self.train_loss
+        self._compute_loss = origin_compute_loss
+        self._return_host_tensor = origin_return_host_tensor
+        return train_loss

     def _maybe_loss_compute(
         self, output_tensor, micro_dataset, overlap_schedule_mode=False
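
Note: eval_batch now snapshots _compute_loss and _return_host_tensor and restores them before returning, so an evaluation pass no longer leaks its flags into a later train_batch call. A minimal sketch of the same save/restore idea, using the attribute names from this diff (the try/finally form is an equivalent alternative, not what the commit itself does; _run_forward_only_schedule is a hypothetical stand-in for the schedule call):

    origin_compute_loss = self._compute_loss
    origin_return_host_tensor = self._return_host_tensor
    self._compute_loss = compute_loss
    self._return_host_tensor = return_host_tensor
    try:
        # forward-only pipeline schedule goes here
        result = self._run_forward_only_schedule(data)
    finally:
        # always restore the flags, even if the schedule raises
        self._compute_loss = origin_compute_loss
        self._return_host_tensor = origin_return_host_tensor
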
@@ -1384,6 +1390,19 @@ def _optimizer_step(self):
         if self.lr_scheduler:
             self.lr_scheduler.step()

+    def mark_release_tensors(self, output_tensor, can_release=True):
+        if isinstance(output_tensor, (tuple, list)):
+            for t in output_tensor:
+                setattr(t, "can_release", can_release)
+                if self._return_host_tensor:
+                    host_tensor = t.pin_memory() if hasattr(t, "pin_memory") else t.cpu()
+                    host_tensor._share_buffer_to(t)
+        else:
+            setattr(output_tensor, "can_release", can_release)
+            if self._return_host_tensor:
+                host_tensor = output_tensor.pin_memory() if hasattr(output_tensor, "pin_memory") else output_tensor.cpu()
+                host_tensor._share_buffer_to(output_tensor)
+
     def _release_output(self, output):
         def can_free(t):
             return (
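
For reference, mark_release_tensors tags each stage output with a can_release attribute and, when return_host_tensor is requested, moves the data off the device. The host-copy step can be read in isolation as the sketch below; the helper name to_host is hypothetical, while pin_memory, cpu, and _share_buffer_to are exactly the calls used above:

    def to_host(t):
        # Prefer pinned host memory when the tensor supports it, else a plain CPU copy.
        host = t.pin_memory() if hasattr(t, "pin_memory") else t.cpu()
        # Let the host copy take over t's buffer so the device memory can be freed
        # while callers keep using the same tensor object.
        host._share_buffer_to(t)
        return t
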
@@ -1655,10 +1674,12 @@ def _get_forward_input(self, virtual_pp_rank):
         assert hasattr(self, 'output_tensors')
         if not self._forward_only:
             assert hasattr(self, 'output_tensor_grads')
-        assert len(self.input_tensors[virtual_pp_rank]) == (
-            len(self.output_tensors[virtual_pp_rank]) + 1
-        )
-        input_tensor = self.input_tensors[virtual_pp_rank][-1]
+            assert len(self.input_tensors[virtual_pp_rank]) == (
+                len(self.output_tensors[virtual_pp_rank]) + 1
+            )
+            input_tensor = self.input_tensors[virtual_pp_rank][-1]
+        else:
+            input_tensor = self.input_tensors[virtual_pp_rank].pop()
         return input_tensor

     def _store_forward_outputs(
@@ -1673,11 +1694,9 @@ def _store_forward_outputs(
             self.schedule_chunks[virtual_pp_rank].append(schedule_chunk)
         if self.is_pipeline_last_stage():
             self.loss_fn_chunks.append(loss_fn_node)
-
-        if self._forward_only:
-            # no need to store tensor for backward
-            self.input_tensors[virtual_pp_rank].pop()
-            self.output_tensors[virtual_pp_rank].pop()
+        # save output_tensors for return value of eval batch
+        if not self._compute_loss:
+            self.mark_release_tensors(output_tensor, False)

     def _forward_step_helper(
         self,
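
With this change a forward-only pass no longer pops its stored tensors immediately; when the loss is not computed, the outputs are instead marked with can_release=False so they survive until eval_batch returns them. A minimal sketch of how a release check can honor that flag (the real can_free predicate inside _release_output checks additional conditions not shown in this diff):

    def _should_free(t):
        # Hypothetical gate: only tensors not claimed as eval outputs may be freed.
        return getattr(t, "can_release", True)
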
@@ -1981,7 +2000,7 @@ def forward_backward_pipeline(
         # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
         if not compute_loss:
             assert (
-                not forward_only
+                forward_only
             ), "compute_loss can only be set to False when forward_only is set to True"

         if static_scheduler:
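
The old assertion was inverted: with compute_loss=False it demanded forward_only=False, contradicting its own error message and making loss-free evaluation unreachable. The corrected contract matches the call that eval_batch makes later in this diff:

    # valid: forward-only pass that skips the loss function
    self.forward_backward_pipeline(data, None, forward_only=True, compute_loss=False)
    # invalid: compute_loss=False together with forward_only=False now trips the assert
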
@@ -2758,12 +2777,12 @@ def backward_async_comm(
             if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(return_micro_batch_loss)
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
-            # else just return all intermediate output tensor for all micro steps
-            train_loss = self.output_tensors
+            # else just return logits without loss func calc
+            train_loss_or_logits = self.output_tensors.pop()

         if self._clear_every_step_cache:
             self._p2p_helper.clear_meta_cache()
@@ -2781,7 +2800,7 @@ def backward_async_comm(
             ), "p2p dynamic_cnt should equal to send_recv_meta_list"
             self._p2p_helper._dynamic_cnt = 0

-        return train_loss
+        return train_loss_or_logits

     def train_batch(
         self,
@@ -2812,13 +2831,16 @@ def train_batch(

         return train_loss

-    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
+    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False):
         self.user_hooks_enabled = False
         # reset the virtual pp rank for each run
         self.set_virtual_pipeline_rank(0)

         self._layers.eval()
+        origin_compute_loss = self._compute_loss
         self._compute_loss = compute_loss
+        origin_return_host_tensor = self._return_host_tensor
+        self._return_host_tensor = return_host_tensor

         # check loss_fn_idx is valid and loss_fn exists
         assert (
@@ -2827,7 +2849,11 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
         ), f"loss function {loss_fn_idx} should exist to compute loss"
         self.loss_fn_idx = loss_fn_idx

-        return self.forward_backward_pipeline(data, None, forward_only=True)
+        train_loss_or_logits = self.forward_backward_pipeline(data, None, forward_only=True, compute_loss=compute_loss)
+        self._init_buffers()
+        self._compute_loss = origin_compute_loss
+        self._return_host_tensor = origin_return_host_tensor
+        return train_loss_or_logits

     def get_static_scheduler(self):
         return self.forward_backward_pipeline(
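
Taken together, the new signature allows evaluation without a loss function while pulling outputs back to the host. A usage sketch, assuming a pipeline-parallel model wrapped via paddle.distributed.fleet; the names model and val_loader are illustrative and not part of this commit:

    for batch in val_loader:
        outputs = model.eval_batch(
            batch,
            compute_loss=False,       # return stage outputs instead of a broadcast loss
            return_host_tensor=True,  # copy outputs to pinned host / CPU memory
        )
        # Each rank gets its own stage's forward outputs; the final logits live on
        # the last pipeline stage, so post-processing typically runs there.
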
@@ -2918,7 +2944,7 @@ def forward_backward_pipeline(
         get_sync_logger().info("start forward_backward_pipeline")
         if not compute_loss:
             assert (
-                not forward_only
+                forward_only
             ), "compute_loss can only be set to False when forward_only is set to True"

         # NOTE(shenliang03): Due to ring_exchange for pipeline with interleave, cache should be enabled
@@ -3065,12 +3091,12 @@ def forward_backward_pipeline(
             if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(return_micro_batch_loss)
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
-            # else just return all intermediate output tensor for all micro steps
-            train_loss = self.output_tensors
+            # else just return logits without loss func calc
+            train_loss_or_logits = self.output_tensors.pop()

         if self._clear_every_step_cache:
             self._p2p_helper.clear_meta_cache()
@@ -3081,7 +3107,7 @@ def forward_backward_pipeline(
         get_sync_logger().info("end forward_backward_pipeline")
         self.processed_steps += 1
         self._check_user_hooks_status_at_step_end()
-        return train_loss
+        return train_loss_or_logits


 class OffloadQueue(queue.Queue):
@@ -3147,7 +3173,7 @@ def forward_backward_pipeline(
         self._reset_user_hooks_status()
         if not compute_loss:
             assert (
-                not forward_only
+                forward_only
             ), "compute_loss can only be set to False when forward_only is set to True"
         assert (
             self._using_cache
@@ -3406,12 +3432,12 @@ def forward_backward_pipeline(
             if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(return_micro_batch_loss)
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
-            # else just return all intermediate output tensor for all micro steps
-            train_loss = self.output_tensors
+            # else just return logits without loss func calc
+            train_loss_or_logits = self.output_tensors.pop()

         if self._clear_every_step_cache:
             self._p2p_helper.clear_meta_cache()
@@ -3422,4 +3448,4 @@ def forward_backward_pipeline(
         get_sync_logger().info("end forward_backward_pipeline")
         self.processed_steps += 1
         self._check_user_hooks_status_at_step_end()
-        return train_loss
+        return train_loss_or_logits
