Skip to content

Commit 5a1faff

Browse files
snnnk-w-w
authored andcommitted
official/mnist: support savedmodel (tensorflow#2967)
With examples, and updates to the README
1 parent f40184c commit 5a1faff

File tree

5 files changed

+50
-1
lines changed

5 files changed

+50
-1
lines changed

official/mnist/README.md

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,36 @@ python mnist.py
2020

2121
The model will begin training and will automatically evaluate itself on the
2222
validation data.
23+
24+
## Exporting the model
25+
26+
You can export the model into Tensorflow [SavedModel](https://www.tensorflow.org/programmers_guide/saved_model) format by using the argument `--export_dir`:
27+
28+
```
29+
python mnist.py --export_dir /tmp/mnist_saved_model
30+
```
31+
32+
The SavedModel will be saved in a timestamped directory under `/tmp/mnist_saved_model/` (e.g. `/tmp/mnist_saved_model/1513630966/`).
33+
34+
**Getting predictions with SavedModel**
35+
Use [`saved_model_cli`](https://www.tensorflow.org/programmers_guide/saved_model#cli_to_inspect_and_execute_savedmodel) to inspect and execute the SavedModel.
36+
37+
```
38+
saved_model_cli run --dir /tmp/mnist_saved_model/TIMESTAMP --tag_set serve --signature_def classify --inputs image_raw=examples.npy
39+
```
40+
41+
`examples.npy` contains the data from `example5.png` and `example3.png` in a numpy array, in that order. The array values are normalized to values between 0 and 1.
42+
43+
The output should look similar to below:
44+
```
45+
Result for output key classes:
46+
[5 3]
47+
Result for output key probabilities:
48+
[[ 1.53558474e-07 1.95694142e-13 1.31193523e-09 5.47467265e-03
49+
5.85711526e-22 9.94520664e-01 3.48423509e-06 2.65365645e-17
50+
9.78631419e-07 3.15522470e-08]
51+
[ 1.22413359e-04 5.87615965e-08 1.72251271e-06 9.39960718e-01
52+
3.30306928e-11 2.87386645e-02 2.82353517e-02 8.21146413e-18
53+
2.52568233e-03 4.15460236e-04]]
54+
```
55+

official/mnist/example3.png

368 Bytes
Loading

official/mnist/example5.png

367 Bytes
Loading

official/mnist/examples.npy

12.3 KB
Binary file not shown.

official/mnist/mnist.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@
5858
'with CPU. If left unspecified, the data format will be chosen '
5959
'automatically based on whether TensorFlow was built for CPU or GPU.')
6060

61+
parser.add_argument(
62+
'--export_dir',
63+
type=str,
64+
help='The directory where the exported SavedModel will be stored.')
6165

6266
def train_dataset(data_dir):
6367
"""Returns a tf.data.Dataset yielding (image, label) pairs for training."""
@@ -152,6 +156,9 @@ def mnist_model(inputs, mode, data_format):
152156

153157
def mnist_model_fn(features, labels, mode, params):
154158
"""Model function for MNIST."""
159+
if mode == tf.estimator.ModeKeys.PREDICT and isinstance(features,dict):
160+
features = features['image_raw']
161+
155162
logits = mnist_model(features, mode, params['data_format'])
156163

157164
predictions = {
@@ -160,7 +167,9 @@ def mnist_model_fn(features, labels, mode, params):
160167
}
161168

162169
if mode == tf.estimator.ModeKeys.PREDICT:
163-
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
170+
export_outputs={'classify': tf.estimator.export.PredictOutput(predictions)}
171+
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions,
172+
export_outputs=export_outputs)
164173

165174
loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
166175

@@ -222,6 +231,13 @@ def eval_input_fn():
222231
print()
223232
print('Evaluation results:\n\t%s' % eval_results)
224233

234+
# Export the model
235+
if FLAGS.export_dir is not None:
236+
image = tf.placeholder(tf.float32,[None, 28, 28])
237+
serving_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
238+
{"image_raw":image})
239+
mnist_classifier.export_savedmodel(FLAGS.export_dir, serving_input_fn)
240+
225241

226242
if __name__ == '__main__':
227243
tf.logging.set_verbosity(tf.logging.INFO)

0 commit comments

Comments
 (0)