Skip to content

Commit 5ee150c

Browse files
committed
Add total params to cifar
1 parent 3e0e03f commit 5ee150c

File tree

4 files changed

+12
-3
lines changed

4 files changed

+12
-3
lines changed

examples/keras-cifar/cifar-cnn.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737
model.compile(loss='categorical_crossentropy',
3838
optimizer=tf.keras.optimizers.Adam(config.learn_rate),
3939
metrics=['accuracy'])
40+
# log the number of total parameters
41+
config.total_params = model.count_params()
42+
print("Total params: ", config.total_params)
4043

4144
model.fit(X_train, y_train, epochs=10, batch_size=128, validation_data=(X_test, y_test),
4245
callbacks=[wandb.keras.WandbCallback(data_type="image", labels=class_names, save_model=False)])

examples/keras-cifar/cifar-gen.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@
4040
model.compile(loss='categorical_crossentropy',
4141
optimizer="adam",
4242
metrics=['accuracy'])
43-
43+
# log the number of total parameters
44+
config.total_params = model.count_params()
45+
print("Total params: ", config.total_params)
4446
X_train = X_train.astype('float32') / 255.
4547
X_test = X_test.astype('float32') / 255.
4648

examples/keras-cifar/cifar-transfer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
model.compile(loss='categorical_crossentropy',
4646
optimizer=Adam(config.learn_rate),
4747
metrics=['accuracy'])
48-
48+
# log the number of total parameters
49+
config.total_params = model.count_params()
50+
print("Total params: ", config.total_params)
4951
model.fit(X_train, y_train, epochs=10, batch_size=128, validation_data=(X_test, y_test),
5052
callbacks=[WandbCallback(data_type="image", labels=class_names, save_model=False)])

examples/keras-cifar/cifar.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
model.compile(loss='mse',
2525
optimizer=tf.keras.optimizers.Adam(config.learn_rate),
2626
metrics=['accuracy'])
27-
27+
# log the number of total parameters
28+
config.total_params = model.count_params()
29+
print("Total params: ", config.total_params)
2830
model.fit(X_train, y_train, epochs=10, batch_size=128, validation_data=(X_test, y_test),
2931
callbacks=[wandb.keras.WandbCallback(data_type="image", labels=class_names, save_model=False)])

0 commit comments

Comments
 (0)