Skip to content

Commit d2cd996

Browse files
authored
Fix the model parameter reporting to report the number of model parameters. (#139)
1 parent 0e10eeb commit d2cd996

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

training/dist_clm_train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def train_loop(args, pipe, device, train_data_loader, test_data_loader):
106106
token_count = train_data_loader.dataset.get_dataset_token_count()
107107

108108
# Get the number of model parameters
109-
param_count = sum(p.numel() for p in pipe.model.parameters() if p.requires_grad)
109+
param_count = sum(p.numel() for p in pipe.model.parameters())
110110

111111
# Report training start
112112
event_reporter.report(object=EventReporter.OBJECT_FINE_TUNE,

0 commit comments

Comments
 (0)