@@ -1213,8 +1213,10 @@ def callback(x, *, done_generating=False):
                 print(prof.key_averages().table(sort_by="self_cpu_time_total"))
             elif self.builder_args.device == "cuda":
                 print(prof.key_averages().table(sort_by="self_cuda_time_total"))
-            else:
+            elif self.builder_args.device == "xpu":
                 print(prof.key_averages().table(sort_by="self_xpu_time_total"))
+            elif self.builder_args.device == "npu":
+                print(prof.key_averages().table(sort_by="self_npu_time_total"))
             prof.export_chrome_trace(f"{self.profile}.json")
 
         if start_pos >= max_seq_length:
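The device branches above differ only in which profiler column they sort by: torch.profiler names its per-device self-time columns "self_<device>_time_total". A minimal sketch of that naming pattern, using a hypothetical profiler_sort_key helper that is not part of this diff:

    import torch

    # Hypothetical helper (illustration only): map a device string to the
    # profiler's per-device self-time column, e.g. "cpu" -> "self_cpu_time_total".
    def profiler_sort_key(device: str) -> str:
        return f"self_{device}_time_total"

    with torch.profiler.profile() as prof:
        torch.ones(512, 512) @ torch.ones(512, 512)  # any workload to profile
    print(prof.key_averages().table(sort_by=profiler_sort_key("cpu")))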
@@ -1299,8 +1301,10 @@ def callback(x, *, done_generating=False):
         )
         if torch.cuda.is_available():
             print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
-        if torch.xpu.is_available():
+        elif torch.xpu.is_available():
             print(f"Memory used: {torch.xpu.max_memory_reserved() / 1e9:.02f} GB")
+        elif hasattr(torch, "npu") and torch.npu.is_available():
+            print(f"Memory used: {torch.npu.max_memory_reserved() / 1e9:.02f} GB")
 
 
 
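The memory report uses the same dispatch order as the profiler branches. The hasattr(torch, "npu") guard is needed because the torch.npu namespace only exists when the Ascend torch_npu extension is installed, whereas the diff calls torch.cuda and torch.xpu unguarded. A sketch of an equivalent probe-based lookup; peak_reserved_gb is a hypothetical name, not from this diff:

    import torch

    def peak_reserved_gb() -> float:
        # Probe accelerator namespaces in the same order as the branches above.
        # getattr(..., None) covers backends whose module may be absent (npu).
        for name in ("cuda", "xpu", "npu"):
            mod = getattr(torch, name, None)
            if mod is not None and mod.is_available():
                return mod.max_memory_reserved() / 1e9
        return 0.0

    print(f"Memory used: {peak_reserved_gb():.02f} GB")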
@@ -1595,7 +1599,6 @@ def sample(
 
     return idx_next, probs
 
-
 def run_generator(
     args,
     rank: Optional[int] = None
@@ -1628,8 +1631,10 @@ def run_generator(
     )
     if torch.cuda.is_available():
         torch.cuda.reset_peak_memory_stats()
-    if torch.xpu.is_available():
+    elif torch.xpu.is_available():
         torch.xpu.reset_peak_memory_stats()
+    elif hasattr(torch, "npu") and torch.npu.is_available():
+        torch.npu.reset_peak_memory_stats()
 
     for _ in gen.chat(generator_args):
         pass
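run_generator resets the peak counter before generation so the max_memory_reserved() figure printed earlier in this diff reflects only the gen.chat() run. A usage sketch of that pairing, resolving the backend once with the same probe order (backend stays None on CPU-only setups):

    import torch

    # Resolve the active accelerator module once; None means CPU-only.
    backend = None
    for name in ("cuda", "xpu", "npu"):
        mod = getattr(torch, name, None)
        if mod is not None and mod.is_available():
            backend = mod
            break

    if backend is not None:
        backend.reset_peak_memory_stats()  # zero the high-water mark
    # ... run generation, e.g. the gen.chat(generator_args) loop above ...
    if backend is not None:
        print(f"Memory used: {backend.max_memory_reserved() / 1e9:.02f} GB")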