Skip to content

Commit b9ae39c

Browse files
committed
Simply train.py
1 parent db6938a commit b9ae39c

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

mobile/tfjs-emotion/train.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1+
# Import layers
12
from keras.layers import Dense, Flatten
23
from keras.models import Sequential
34
from keras.callbacks import Callback
4-
from keras import backend as K
55
import pandas as pd
66
import numpy as np
77
import cv2
8-
from PIL import Image
98
import keras
109
import subprocess
1110
import os
@@ -17,13 +16,16 @@
1716
run = wandb.init()
1817
config = run.config
1918

19+
# set hyperparameters
2020
config.batch_size = 32
2121
config.num_epochs = 5
2222

2323
input_shape = (48, 48, 1)
2424

2525

2626
class Perf(Callback):
27+
"""Performance callback for logging inference time"""
28+
2729
def __init__(self, testX):
2830
self.testX = testX
2931

@@ -38,6 +40,7 @@ def on_epoch_end(self, epoch, logs):
3840

3941

4042
def load_fer2013():
43+
"""Load the emotion dataset"""
4144
if not os.path.exists("fer2013"):
4245
print("Downloading the face emotion dataset...")
4346
subprocess.check_output(
@@ -64,32 +67,29 @@ def load_fer2013():
6467

6568
return train_faces, train_emotions, val_faces, val_emotions
6669

67-
# loading dataset
68-
6970

71+
# loading dataset
7072
train_faces, train_emotions, val_faces, val_emotions = load_fer2013()
7173
num_samples, num_classes = train_emotions.shape
7274

7375
train_faces /= 255.
7476
val_faces /= 255.
7577

78+
# Define the model here, CHANGEME
7679
model = Sequential()
77-
model.add(keras.layers.Conv2D(
78-
32, (3, 3), input_shape=(48, 48, 1), activation="relu"))
79-
model.add(keras.layers.MaxPooling2D())
80-
model.add(keras.layers.Conv2D(64, (3, 3), activation="relu"))
81-
model.add(keras.layers.MaxPooling2D())
8280
model.add(Flatten(input_shape=input_shape))
8381
model.add(Dense(num_classes, activation="softmax"))
8482
model.compile(optimizer='adam', loss='categorical_crossentropy',
8583
metrics=['accuracy'])
84+
85+
# log the number of total parameters
8686
config.total_params = model.count_params()
87-
print("Total", config.total_params)
8887
model.fit(train_faces, train_emotions, batch_size=config.batch_size,
8988
epochs=config.num_epochs, verbose=1, callbacks=[
9089
Perf(val_faces),
9190
WandbCallback(data_type="image", labels=[
9291
"Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"])
9392
], validation_data=(val_faces, val_emotions))
9493

94+
# save the model
9595
model.save("emotion.h5")

0 commit comments

Comments
 (0)