@@ -165,7 +165,7 @@ def _check_data_valid(self, data):
         batch_size = data.shape[0]
         assert self._micro_batch_size * self._acc_steps == batch_size, (
             "batch_size needs to be divisible by micro_batch_size. Currently, "
-            f"batch_size = {batch_size}, micro_batch_size = {self._micro_batch_size}, accumulate_steps = {self._acc_steps}."
+            f"batch_size = {batch_size}, micro_batch_size = {self._micro_batch_size}, accumulate_steps = {self._acc_steps}, data_shape = {data.shape}."
         )
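For reference, a small self-contained illustration of the invariant the assertion above enforces; the concrete numbers are illustrative and not taken from this patch:

    # Illustrative values only: a global batch of 32 samples split into
    # micro-batches of 4 requires 8 accumulation steps, since 4 * 8 == 32.
    micro_batch_size, acc_steps = 4, 8
    batch_size = 32
    assert micro_batch_size * acc_steps == batch_size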
@@ -413,6 +413,7 @@ def __init__(self, layers, hcg, strategy):
         self.loss_fn_idx = 0
 
         self._compute_loss = True
+        self._return_host_tensor = False
         self.callbacks = pipeline_parallel_callbacks_
 
         logger.info(
@@ -991,13 +992,16 @@ def train_batch(
         return train_loss
 
-    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
+    def eval_batch(
+        self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False
+    ):
         self.user_hooks_enabled = False
         # reset the virtual pp rank for each run
         self.set_virtual_pipeline_rank(0)
 
         self._layers.eval()
+        origin_compute_loss = self._compute_loss
         self._compute_loss = compute_loss
+        origin_return_host_tensor = self._return_host_tensor
+        self._return_host_tensor = return_host_tensor
 
         # store data id for micro_batch
         self.micro_batch_id = 0
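A minimal usage sketch of the extended signature, assuming an already constructed pipeline engine; the names `engine` and `eval_data` are placeholders and not part of this patch:

    # Forward-only evaluation returning raw last-stage outputs instead of a
    # loss; return_host_tensor=True additionally swaps the outputs to (pinned)
    # host memory so device buffers can be freed between batches.
    logits = engine.eval_batch(
        eval_data,
        compute_loss=False,
        loss_fn_idx=0,
        return_host_tensor=True,
    )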
@@ -1070,11 +1074,13 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
         )
 
         if self._compute_loss:
-            self.train_loss = self._broadcast_final_loss()
+            train_loss = self._broadcast_final_loss()
         else:
-            self.train_loss = output_buffers
+            train_loss = output_buffers
 
-        return self.train_loss
+        self._compute_loss = origin_compute_loss
+        self._return_host_tensor = origin_return_host_tensor
+        return train_loss
 
     def _maybe_loss_compute(
         self, output_tensor, micro_dataset, overlap_schedule_mode=False
@@ -1384,6 +1390,19 @@ def _optimizer_step(self):
         if self.lr_scheduler:
             self.lr_scheduler.step()
 
+    def mark_release_tensors(self, output_tensor, can_release=True):
+        # Tag forward outputs so the release logic knows whether their device
+        # memory may be freed; optionally keep a (pinned) host copy alive by
+        # sharing its buffer back into the original tensor handle.
+        if isinstance(output_tensor, (tuple, list)):
+            for t in output_tensor:
+                setattr(t, "can_release", can_release)
+                if self._return_host_tensor:
+                    host_tensor = (
+                        t.pin_memory() if hasattr(t, "pin_memory") else t.cpu()
+                    )
+                    host_tensor._share_buffer_to(t)
+        else:
+            setattr(output_tensor, "can_release", can_release)
+            if self._return_host_tensor:
+                host_tensor = (
+                    output_tensor.pin_memory()
+                    if hasattr(output_tensor, "pin_memory")
+                    else output_tensor.cpu()
+                )
+                host_tensor._share_buffer_to(output_tensor)
+
     def _release_output(self, output):
         def can_free(t):
             return (
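A standalone sketch of the attribute tagging the new helper relies on, assuming Paddle is installed and running in dynamic-graph mode, where Tensor objects accept ad-hoc Python attributes the way `mark_release_tensors` uses them above:

    import paddle

    t = paddle.ones([2, 2])
    # Tag the tensor the same way mark_release_tensors does; consumers can
    # read the flag back with a safe default when it was never set.
    setattr(t, "can_release", False)
    print(getattr(t, "can_release", True))  # -> False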
@@ -1655,10 +1674,12 @@ def _get_forward_input(self, virtual_pp_rank):
         assert hasattr(self, 'output_tensors')
         if not self._forward_only:
             assert hasattr(self, 'output_tensor_grads')
-        assert len(self.input_tensors[virtual_pp_rank]) == (
-            len(self.output_tensors[virtual_pp_rank]) + 1
-        )
-        input_tensor = self.input_tensors[virtual_pp_rank][-1]
+            assert len(self.input_tensors[virtual_pp_rank]) == (
+                len(self.output_tensors[virtual_pp_rank]) + 1
+            )
+            input_tensor = self.input_tensors[virtual_pp_rank][-1]
+        else:
+            input_tensor = self.input_tensors[virtual_pp_rank].pop()
         return input_tensor
 
     def _store_forward_outputs(
@@ -1673,11 +1694,9 @@ def _store_forward_outputs(
         self.schedule_chunks[virtual_pp_rank].append(schedule_chunk)
         if self.is_pipeline_last_stage():
             self.loss_fn_chunks.append(loss_fn_node)
-
-        if self._forward_only:
-            # no need to store tensor for backward
-            self.input_tensors[virtual_pp_rank].pop()
-            self.output_tensors[virtual_pp_rank].pop()
+        # keep output_tensors so eval_batch can return them
+        if not self._compute_loss:
+            self.mark_release_tensors(output_tensor, False)
 
     def _forward_step_helper(
         self,
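The body of `can_free` is truncated in the hunk above, so purely as a hypothetical sketch (not the actual implementation), a release check that honors the new flag could default to releasable whenever the attribute is absent:

    import paddle

    def can_free(t):
        # Hypothetical check: only free real Tensors that have not been
        # tagged with can_release=False by mark_release_tensors.
        return (
            t is not None
            and isinstance(t, paddle.Tensor)
            and getattr(t, "can_release", True)
        )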
@@ -1981,7 +2000,7 @@ def forward_backward_pipeline(
         # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
         if not compute_loss:
             assert (
-                not forward_only
+                forward_only
             ), "compute_loss can only be set to False when forward_only is set to True"
 
         if static_scheduler:
@@ -2758,12 +2777,12 @@ def backward_async_comm(
             if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(return_micro_batch_loss)
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
-            # else just return all intermediate output tensor for all micro steps
-            train_loss = self.output_tensors
+            # otherwise return the logits without running the loss function
+            train_loss_or_logits = self.output_tensors.pop()
 
         if self._clear_every_step_cache:
             self._p2p_helper.clear_meta_cache()
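To make the rename concrete, a hedged caller-side sketch of the two return modes; `engine` and `eval_data` are again placeholders:

    # compute_loss=True: the loss computed on the last stage is broadcast, so
    # every pipeline rank gets the same loss value back.
    loss = engine.eval_batch(eval_data, compute_loss=True)

    # compute_loss=False: the loss function is skipped and the stored
    # last-stage outputs (logits) are returned instead.
    logits = engine.eval_batch(eval_data, compute_loss=False)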
@@ -2781,7 +2800,7 @@ def backward_async_comm(
             ), "p2p dynamic_cnt should equal to send_recv_meta_list"
             self._p2p_helper._dynamic_cnt = 0
 
-        return train_loss
+        return train_loss_or_logits
 
     def train_batch(
         self,
@@ -2812,13 +2831,16 @@ def train_batch(
         return train_loss
 
-    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
+    def eval_batch(
+        self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False
+    ):
         self.user_hooks_enabled = False
         # reset the virtual pp rank for each run
         self.set_virtual_pipeline_rank(0)
 
         self._layers.eval()
+        origin_compute_loss = self._compute_loss
         self._compute_loss = compute_loss
+        origin_return_host_tensor = self._return_host_tensor
+        self._return_host_tensor = return_host_tensor
 
         # check loss_fn_idx is valid and loss_fn exists
         assert (
@@ -2827,7 +2849,11 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
         ), f"loss function {loss_fn_idx} should exist to compute loss"
         self.loss_fn_idx = loss_fn_idx
 
-        return self.forward_backward_pipeline(data, None, forward_only=True)
+        train_loss_or_logits = self.forward_backward_pipeline(
+            data, None, forward_only=True, compute_loss=compute_loss
+        )
+        self._init_buffers()
+        self._compute_loss = origin_compute_loss
+        self._return_host_tensor = origin_return_host_tensor
+        return train_loss_or_logits
 
     def get_static_scheduler(self):
         return self.forward_backward_pipeline(
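The hunk above restores `_compute_loss` and `_return_host_tensor` by hand after the pipeline call. Purely as a sketch of the design trade-off (not what the patch does), a toy stand-in class shows the same save-and-restore written with try/finally so the flags come back even if the forward pass raises:

    class _FlagDemo:
        # Toy stand-in for the pipeline object; not Paddle code.
        def __init__(self):
            self._compute_loss = True
            self._return_host_tensor = False

        def eval_once(self, compute_loss, return_host_tensor):
            origin = (self._compute_loss, self._return_host_tensor)
            self._compute_loss = compute_loss
            self._return_host_tensor = return_host_tensor
            try:
                return "forward-only pass"  # stands in for the pipeline call
            finally:
                # Restored on success and on exceptions alike.
                self._compute_loss, self._return_host_tensor = origin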
@@ -2918,7 +2944,7 @@ def forward_backward_pipeline(
         get_sync_logger().info("start forward_backward_pipeline")
         if not compute_loss:
             assert (
-                not forward_only
+                forward_only
             ), "compute_loss can only be set to False when forward_only is set to True"
 
         # NOTE(shenliang03): Due to ring_exchange for pipeline with interleave, cache should be enabled
@@ -3065,12 +3091,12 @@ def forward_backward_pipeline(
             if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(return_micro_batch_loss)
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
-            # else just return all intermediate output tensor for all micro steps
-            train_loss = self.output_tensors
+            # otherwise return the logits without running the loss function
+            train_loss_or_logits = self.output_tensors.pop()
 
         if self._clear_every_step_cache:
             self._p2p_helper.clear_meta_cache()
@@ -3081,7 +3107,7 @@ def forward_backward_pipeline(
         get_sync_logger().info("end forward_backward_pipeline")
         self.processed_steps += 1
         self._check_user_hooks_status_at_step_end()
-        return train_loss
+        return train_loss_or_logits
 
 
 class OffloadQueue(queue.Queue):
@@ -3147,7 +3173,7 @@ def forward_backward_pipeline(
         self._reset_user_hooks_status()
         if not compute_loss:
             assert (
-                not forward_only
+                forward_only
             ), "compute_loss can only be set to False when forward_only is set to True"
         assert (
             self._using_cache
@@ -3406,12 +3432,12 @@ def forward_backward_pipeline(
             if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(return_micro_batch_loss)
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
-            # else just return all intermediate output tensor for all micro steps
-            train_loss = self.output_tensors
+            # otherwise return the logits without running the loss function
+            train_loss_or_logits = self.output_tensors.pop()
 
         if self._clear_every_step_cache:
             self._p2p_helper.clear_meta_cache()
@@ -3422,4 +3448,4 @@ def forward_backward_pipeline(
         get_sync_logger().info("end forward_backward_pipeline")
         self.processed_steps += 1
         self._check_user_hooks_status_at_step_end()
-        return train_loss
+        return train_loss_or_logits