Skip to content

Commit 905a3c4

Browse files
committed
Attempt to try tf2
1 parent b80aa0c commit 905a3c4

File tree

2 files changed

+35
-66
lines changed

2 files changed

+35
-66
lines changed

examples/keras-perf/cnn.py

Lines changed: 27 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,7 @@
1-
from keras import backend as K
2-
from keras.datasets import mnist
3-
from keras.models import Sequential
4-
from keras.layers import Conv2D, MaxPooling2D, Dropout, Dense, Flatten
5-
from keras.utils import np_utils
6-
from wandb.keras import WandbCallback
7-
from keras.callbacks import TensorBoard
8-
import wandb
91
import os
102
import tensorflow as tf
3+
import wandb
4+
from wandb.keras import WandbCallback
115

126
run = wandb.init()
137
config = run.config
@@ -20,7 +14,7 @@
2014
config.img_height = 28
2115
config.epochs = 4
2216

23-
(X_train, y_train), (X_test, y_test) = mnist.load_data()
17+
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
2418

2519
X_train = X_train.astype('float32')
2620
X_train /= 255.
@@ -34,68 +28,35 @@
3428
X_test.shape[0], config.img_width, config.img_height, 1)
3529

3630
# one hot encode outputs
37-
y_train = np_utils.to_categorical(y_train)
38-
y_test = np_utils.to_categorical(y_test)
31+
y_train = tf.keras.utils.to_categorical(y_train)
32+
y_test = tf.keras.utils.to_categorical(y_test)
3933
num_classes = y_test.shape[1]
4034
labels = range(10)
4135

4236
# build model
43-
model = Sequential()
44-
model.add(Conv2D(32,
45-
(config.first_layer_conv_width, config.first_layer_conv_height),
46-
input_shape=(28, 28, 1),
47-
activation='relu'))
48-
model.add(MaxPooling2D(pool_size=(2, 2)))
49-
model.add(Flatten())
50-
model.add(Dense(config.dense_layer_size, activation='relu'))
51-
model.add(Dense(num_classes, activation='softmax'))
52-
37+
model = tf.keras.models.Sequential()
38+
model.add(tf.keras.layers.Conv2D(32,
39+
(config.first_layer_conv_width,
40+
config.first_layer_conv_height),
41+
input_shape=(28, 28, 1),
42+
activation='relu'))
43+
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
44+
model.add(tf.keras.layers.Flatten())
45+
model.add(tf.keras.layers.Dense(config.dense_layer_size, activation='relu'))
46+
model.add(tf.keras.layers.Dense(num_classes, activation='softmax'))
5347
model.compile(loss='categorical_crossentropy', optimizer='adam',
54-
metrics=['accuracy'], weighted_metrics=['accuracy'])
55-
56-
48+
metrics=['accuracy'])
49+
# log the number of total parameters
50+
config.total_params = model.count_params()
51+
print("Total params: ", config.total_params)
5752
model.fit(X_train, y_train, validation_data=(X_test, y_test),
5853
epochs=config.epochs,
5954
callbacks=[WandbCallback(data_type="image", save_model=False),
60-
TensorBoard(log_dir=wandb.run.dir)])
61-
62-
63-
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
64-
"""
65-
Freezes the state of a session into a pruned computation graph.
66-
67-
Creates a new computation graph where variable nodes are replaced by
68-
constants taking their current value in the session. The new graph will be
69-
pruned so subgraphs that are not necessary to compute the requested
70-
outputs are removed.
71-
@param session The TensorFlow session to be frozen.
72-
@param keep_var_names A list of variable names that should not be frozen,
73-
or None to freeze all the variables in the graph.
74-
@param output_names Names of the relevant graph outputs.
75-
@param clear_devices Remove the device directives from the graph for better portability.
76-
@return The frozen graph definition.
77-
"""
78-
graph = session.graph
79-
with graph.as_default():
80-
freeze_var_names = list(
81-
set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
82-
output_names = output_names or []
83-
output_names += [v.op.name for v in tf.global_variables()]
84-
input_graph_def = graph.as_graph_def()
85-
if clear_devices:
86-
for node in input_graph_def.node:
87-
node.device = ""
88-
frozen_graph = tf.graph_util.convert_variables_to_constants(
89-
session, input_graph_def, output_names, freeze_var_names)
90-
return frozen_graph
91-
92-
93-
model.save("cnn.h5")
94-
saver = tf.train.Saver()
95-
sess = K.get_session()
96-
saver.save(sess, './keras_model.ckpt')
97-
tf.train.write_graph(sess.graph_def, '.', 'keras_model.pbtxt')
98-
99-
frozen_graph = freeze_session(K.get_session(), output_names=[
100-
out.op.name for out in model.outputs])
101-
tf.train.write_graph(frozen_graph, ".", "cnn.pb", as_text=False)
55+
tf.keras.callbacks.TensorBoard(profile_batch=3)])
56+
model.save('cnn.h5')
57+
58+
# Convert to TensorFlow Lite model.
59+
converter = tf.lite.TFLiteConverter.from_keras_model_file('cnn.h5')
60+
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
61+
tflite_model = converter.convert()
62+
open("cnn.tflite", "wb").write(tflite_model)

examples/keras-perf/train.sh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#!/bin/bash
2+
virtualenv --no-site-packages venv
3+
. venv/bin/activate
4+
pip install tensorflow-gpu==2.0.0b1 wandb pillow
5+
6+
python cnn.py
7+
8+
deactivate

0 commit comments

Comments
 (0)