Skip to content

Commit 0f5803b

Browse files
authored
Merge pull request tensorflow#3028 from tensorflow/mhyttsten-patch-2
Update blog_custom_estimators.py
2 parents d19587e + 688143e commit 0f5803b

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

samples/outreach/blogs/blog_custom_estimators.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
# ==============================================================================
1515

1616
# This is the complete code for the following blogpost:
17-
# https://developers.googleblog.com/2017/09/introducing-tensorflow-datasets.html
18-
# (https://goo.gl/Ujm2Ep)
17+
# https://developers.googleblog.com/2017/12/creating-custom-estimators-in-tensorflow.html
1918

2019
import tensorflow as tf
2120
import os
@@ -116,7 +115,7 @@ def my_model_fn(
116115
h2 = tf.layers.Dense(10, activation=tf.nn.relu)(h1)
117116

118117
# Output 'logits' layer is three number = probability distribution
119-
# between Iris Sentosa, Versicolor, and Viginica
118+
# between Iris Setosa, Versicolor, and Viginica
120119
logits = tf.layers.Dense(3)(h2)
121120

122121
# class_ids will be the model prediction for the class (Iris flower type)
@@ -206,14 +205,14 @@ def my_model_fn(
206205
tf.logging.info("Prediction on test file")
207206
for prediction in predict_results:
208207
# Will print the predicted class, i.e: 0, 1, or 2 if the prediction
209-
# is Iris Sentosa, Vericolor, Virginica, respectively.
208+
# is Iris Setosa, Vericolor, Virginica, respectively.
210209
tf.logging.info("...{}".format(prediction["class_ids"]))
211210

212211
# Let create a dataset for prediction
213212
# We've taken the first 3 examples in FILE_TEST
214213
prediction_input = [[5.9, 3.0, 4.2, 1.5], # -> 1, Iris Versicolor
215214
[6.9, 3.1, 5.4, 2.1], # -> 2, Iris Virginica
216-
[5.1, 3.3, 1.7, 0.5]] # -> 0, Iris Sentosa
215+
[5.1, 3.3, 1.7, 0.5]] # -> 0, Iris Setosa
217216

218217
def new_input_fn():
219218
def decode(x):
@@ -234,7 +233,7 @@ def decode(x):
234233
for idx, prediction in enumerate(predict_results):
235234
type = prediction["class_ids"] # Get the predicted class (index)
236235
if type == 0:
237-
tf.logging.info("...I think: {}, is Iris Sentosa".format(prediction_input[idx]))
236+
tf.logging.info("...I think: {}, is Iris Setosa".format(prediction_input[idx]))
238237
elif type == 1:
239238
tf.logging.info("...I think: {}, is Iris Versicolor".format(prediction_input[idx]))
240239
else:

0 commit comments

Comments
 (0)