Skip to content

Commit 0e91841

Browse files
committed
Update training params
1 parent 2136e02 commit 0e91841

File tree

1 file changed

+6
-14
lines changed

1 file changed

+6
-14
lines changed

examples/vision/image_classification_from_scratch.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -84,20 +84,12 @@
8484
"""
8585

8686
image_size = (180, 180)
87-
batch_size = 32
87+
batch_size = 128
8888

89-
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
89+
train_ds, val_ds = tf.keras.preprocessing.image_dataset_from_directory(
9090
"PetImages",
9191
validation_split=0.2,
92-
subset="training",
93-
seed=1337,
94-
image_size=image_size,
95-
batch_size=batch_size,
96-
)
97-
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
98-
"PetImages",
99-
validation_split=0.2,
100-
subset="validation",
92+
subset="both",
10193
seed=1337,
10294
image_size=image_size,
10395
batch_size=batch_size,
@@ -297,7 +289,7 @@ def make_model(input_shape, num_classes):
297289
## Train the model
298290
"""
299291

300-
epochs = 50
292+
epochs = 25
301293

302294
callbacks = [
303295
keras.callbacks.ModelCheckpoint("save_at_{epoch}.keras"),
@@ -306,7 +298,7 @@ def make_model(input_shape, num_classes):
306298
optimizer=keras.optimizers.Adam(1e-3),
307299
loss="binary_crossentropy",
308300
metrics=["accuracy"],
309-
jit_compile=True, # Enable XLA compilation
301+
jit_compile=True, # Enable XLA compilation for faster execution
310302
)
311303
model.fit(
312304
train_ds,
@@ -316,7 +308,7 @@ def make_model(input_shape, num_classes):
316308
)
317309

318310
"""
319-
We get to ~96% validation accuracy after training for 50 epochs on the full dataset.
311+
We get to ~96% validation accuracy after training for 25 epochs on the full dataset.
320312
"""
321313

322314
"""

0 commit comments

Comments
 (0)