|
1 |
| -from keras import backend as K |
2 |
| -from keras.datasets import mnist |
3 |
| -from keras.models import Sequential |
4 |
| -from keras.layers import Conv2D, MaxPooling2D, Dropout, Dense, Flatten |
5 |
| -from keras.utils import np_utils |
6 |
| -from wandb.keras import WandbCallback |
7 |
| -from keras.callbacks import TensorBoard |
8 |
| -import wandb |
9 | 1 | import os
|
10 | 2 | import tensorflow as tf
|
| 3 | +import wandb |
| 4 | +from wandb.keras import WandbCallback |
11 | 5 |
|
12 | 6 | run = wandb.init()
|
13 | 7 | config = run.config
|
|
20 | 14 | config.img_height = 28
|
21 | 15 | config.epochs = 4
|
22 | 16 |
|
23 |
| -(X_train, y_train), (X_test, y_test) = mnist.load_data() |
| 17 | +(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data() |
24 | 18 |
|
25 | 19 | X_train = X_train.astype('float32')
|
26 | 20 | X_train /= 255.
|
|
34 | 28 | X_test.shape[0], config.img_width, config.img_height, 1)
|
35 | 29 |
|
36 | 30 | # one hot encode outputs
|
37 |
| -y_train = np_utils.to_categorical(y_train) |
38 |
| -y_test = np_utils.to_categorical(y_test) |
| 31 | +y_train = tf.keras.utils.to_categorical(y_train) |
| 32 | +y_test = tf.keras.utils.to_categorical(y_test) |
39 | 33 | num_classes = y_test.shape[1]
|
40 | 34 | labels = range(10)
|
41 | 35 |
|
42 | 36 | # build model
|
43 |
| -model = Sequential() |
44 |
| -model.add(Conv2D(32, |
45 |
| - (config.first_layer_conv_width, config.first_layer_conv_height), |
46 |
| - input_shape=(28, 28, 1), |
47 |
| - activation='relu')) |
48 |
| -model.add(MaxPooling2D(pool_size=(2, 2))) |
49 |
| -model.add(Flatten()) |
50 |
| -model.add(Dense(config.dense_layer_size, activation='relu')) |
51 |
| -model.add(Dense(num_classes, activation='softmax')) |
52 |
| - |
| 37 | +model = tf.keras.models.Sequential() |
| 38 | +model.add(tf.keras.layers.Conv2D(32, |
| 39 | + (config.first_layer_conv_width, |
| 40 | + config.first_layer_conv_height), |
| 41 | + input_shape=(28, 28, 1), |
| 42 | + activation='relu')) |
| 43 | +model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2))) |
| 44 | +model.add(tf.keras.layers.Flatten()) |
| 45 | +model.add(tf.keras.layers.Dense(config.dense_layer_size, activation='relu')) |
| 46 | +model.add(tf.keras.layers.Dense(num_classes, activation='softmax')) |
53 | 47 | model.compile(loss='categorical_crossentropy', optimizer='adam',
|
54 |
| - metrics=['accuracy'], weighted_metrics=['accuracy']) |
55 |
| - |
56 |
| - |
| 48 | + metrics=['accuracy']) |
| 49 | +# log the number of total parameters |
| 50 | +config.total_params = model.count_params() |
| 51 | +print("Total params: ", config.total_params) |
57 | 52 | model.fit(X_train, y_train, validation_data=(X_test, y_test),
|
58 | 53 | epochs=config.epochs,
|
59 | 54 | callbacks=[WandbCallback(data_type="image", save_model=False),
|
60 |
| - TensorBoard(log_dir=wandb.run.dir)]) |
61 |
| - |
62 |
| - |
63 |
| -def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True): |
64 |
| - """ |
65 |
| - Freezes the state of a session into a pruned computation graph. |
66 |
| -
|
67 |
| - Creates a new computation graph where variable nodes are replaced by |
68 |
| - constants taking their current value in the session. The new graph will be |
69 |
| - pruned so subgraphs that are not necessary to compute the requested |
70 |
| - outputs are removed. |
71 |
| - @param session The TensorFlow session to be frozen. |
72 |
| - @param keep_var_names A list of variable names that should not be frozen, |
73 |
| - or None to freeze all the variables in the graph. |
74 |
| - @param output_names Names of the relevant graph outputs. |
75 |
| - @param clear_devices Remove the device directives from the graph for better portability. |
76 |
| - @return The frozen graph definition. |
77 |
| - """ |
78 |
| - graph = session.graph |
79 |
| - with graph.as_default(): |
80 |
| - freeze_var_names = list( |
81 |
| - set(v.op.name for v in tf.global_variables()).difference(keep_var_names or [])) |
82 |
| - output_names = output_names or [] |
83 |
| - output_names += [v.op.name for v in tf.global_variables()] |
84 |
| - input_graph_def = graph.as_graph_def() |
85 |
| - if clear_devices: |
86 |
| - for node in input_graph_def.node: |
87 |
| - node.device = "" |
88 |
| - frozen_graph = tf.graph_util.convert_variables_to_constants( |
89 |
| - session, input_graph_def, output_names, freeze_var_names) |
90 |
| - return frozen_graph |
91 |
| - |
92 |
| - |
93 |
| -model.save("cnn.h5") |
94 |
| -saver = tf.train.Saver() |
95 |
| -sess = K.get_session() |
96 |
| -saver.save(sess, './keras_model.ckpt') |
97 |
| -tf.train.write_graph(sess.graph_def, '.', 'keras_model.pbtxt') |
98 |
| - |
99 |
| -frozen_graph = freeze_session(K.get_session(), output_names=[ |
100 |
| - out.op.name for out in model.outputs]) |
101 |
| -tf.train.write_graph(frozen_graph, ".", "cnn.pb", as_text=False) |
| 55 | + tf.keras.callbacks.TensorBoard(profile_batch=3)]) |
| 56 | +model.save('cnn.h5') |
| 57 | + |
| 58 | +# Convert to TensorFlow Lite model. |
| 59 | +converter = tf.lite.TFLiteConverter.from_keras_model_file('cnn.h5') |
| 60 | +converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] |
| 61 | +tflite_model = converter.convert() |
| 62 | +open("cnn.tflite", "wb").write(tflite_model) |
0 commit comments