Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions samples/outreach/blogs/blog_custom_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
# ==============================================================================

# This is the complete code for the following blogpost:
# https://developers.googleblog.com/2017/09/introducing-tensorflow-datasets.html
# (https://goo.gl/Ujm2Ep)
# https://developers.googleblog.com/2017/12/creating-custom-estimators-in-tensorflow.html

import tensorflow as tf
import os
Expand Down Expand Up @@ -116,7 +115,7 @@ def my_model_fn(
h2 = tf.layers.Dense(10, activation=tf.nn.relu)(h1)

# Output 'logits' layer is three number = probability distribution
# between Iris Sentosa, Versicolor, and Viginica
# between Iris Setosa, Versicolor, and Viginica
logits = tf.layers.Dense(3)(h2)

# class_ids will be the model prediction for the class (Iris flower type)
Expand Down Expand Up @@ -206,14 +205,14 @@ def my_model_fn(
tf.logging.info("Prediction on test file")
for prediction in predict_results:
# Will print the predicted class, i.e: 0, 1, or 2 if the prediction
# is Iris Sentosa, Vericolor, Virginica, respectively.
# is Iris Setosa, Vericolor, Virginica, respectively.
tf.logging.info("...{}".format(prediction["class_ids"]))

# Let create a dataset for prediction
# We've taken the first 3 examples in FILE_TEST
prediction_input = [[5.9, 3.0, 4.2, 1.5], # -> 1, Iris Versicolor
[6.9, 3.1, 5.4, 2.1], # -> 2, Iris Virginica
[5.1, 3.3, 1.7, 0.5]] # -> 0, Iris Sentosa
[5.1, 3.3, 1.7, 0.5]] # -> 0, Iris Setosa

def new_input_fn():
def decode(x):
Expand All @@ -234,7 +233,7 @@ def decode(x):
for idx, prediction in enumerate(predict_results):
type = prediction["class_ids"] # Get the predicted class (index)
if type == 0:
tf.logging.info("...I think: {}, is Iris Sentosa".format(prediction_input[idx]))
tf.logging.info("...I think: {}, is Iris Setosa".format(prediction_input[idx]))
elif type == 1:
tf.logging.info("...I think: {}, is Iris Versicolor".format(prediction_input[idx]))
else:
Expand Down