Skip to content

Commit 27d7a9f

Browse files
committed
Implement savedmodel in sparse classifier
1 parent 0ca00b4 commit 27d7a9f

File tree

2 files changed

+59
-7
lines changed

2 files changed

+59
-7
lines changed

dense_classifier.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -528,17 +528,17 @@ def inference(inputs, is_train=True):
528528
accuracy = float(correct_label_number) / label_number
529529

530530
# Compute auc
531-
expected_labels = np.array(inference_data_labels)
532-
predict_labels = prediction_softmax[:, 0]
533-
fpr, tpr, thresholds = metrics.roc_curve(expected_labels,
534-
predict_labels,
535-
pos_label=0)
531+
y_true = np.array(inference_data_labels)
532+
y_score = prediction_softmax[:, 1]
533+
fpr, tpr, thresholds = metrics.roc_curve(y_true,
534+
y_score,
535+
pos_label=1)
536536
auc = metrics.auc(fpr, tpr)
537537
logging.info("[{}] Inference accuracy: {}, auc: {}".format(
538538
end_time - start_time, accuracy, auc))
539539

540540
# Save result into the file
541-
np.savetxt(inference_result_file_name, prediction, delimiter=",")
541+
np.savetxt(inference_result_file_name, prediction_softmax, delimiter=",")
542542
logging.info("Save result to file: {}".format(
543543
inference_result_file_name))
544544

sparse_classifier.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@
1111
from sklearn import metrics
1212
import tensorflow as tf
1313
from tensorflow.contrib.session_bundle import exporter
14+
from tensorflow.python.saved_model import builder as saved_model_builder
15+
from tensorflow.python.saved_model import signature_constants
16+
from tensorflow.python.saved_model import signature_def_utils
17+
from tensorflow.python.saved_model import tag_constants
18+
from tensorflow.python.saved_model import utils
19+
from tensorflow.python.util import compat
1420

1521
# Define hyperparameters
1622
flags = tf.app.flags
@@ -49,6 +55,8 @@
4955
flags.DEFINE_integer("steps_to_validate", 10,
5056
"Steps to validate and print state")
5157
flags.DEFINE_string("mode", "train", "Support train, export, inference")
58+
flags.DEFINE_string("saved_model_path", "./sparse_saved_model/",
59+
"The path of the saved model")
5260
flags.DEFINE_string("model_path", "./sparse_model/", "The path of the model")
5361
flags.DEFINE_integer("model_version", 1, "The version of the model")
5462
flags.DEFINE_string("inference_test_file", "./data/a8a_test.libsvm",
@@ -391,6 +399,50 @@ def inference(sparse_ids, sparse_values, is_train=True):
391399
export_model(sess, saver, model_signature, FLAGS.model_path,
392400
FLAGS.model_version)
393401

402+
elif MODE == "savedmodel":
403+
if not restore_session_from_checkpoint(sess, saver, LATEST_CHECKPOINT):
404+
logging.error("No checkpoint found, exit now")
405+
exit(1)
406+
407+
logging.info("Export the saved model to {}".format(
408+
FLAGS.saved_model_path))
409+
export_path_base = FLAGS.saved_model_path
410+
export_path = os.path.join(
411+
compat.as_bytes(export_path_base),
412+
compat.as_bytes(str(FLAGS.model_version)))
413+
414+
model_signature = signature_def_utils.build_signature_def(
415+
inputs={
416+
"keys": utils.build_tensor_info(keys_placeholder),
417+
"indexs": utils.build_tensor_info(sparse_index),
418+
"ids": utils.build_tensor_info(sparse_ids),
419+
"values": utils.build_tensor_info(sparse_values),
420+
"shape": utils.build_tensor_info(sparse_shape)
421+
},
422+
outputs={
423+
"keys": utils.build_tensor_info(keys),
424+
"softmax": utils.build_tensor_info(inference_softmax),
425+
"prediction": utils.build_tensor_info(inference_op)
426+
},
427+
method_name=signature_constants.PREDICT_METHOD_NAME)
428+
429+
try:
430+
builder = saved_model_builder.SavedModelBuilder(export_path)
431+
builder.add_meta_graph_and_variables(
432+
sess,
433+
[tag_constants.SERVING],
434+
signature_def_map={
435+
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
436+
model_signature,
437+
},
438+
#legacy_init_op=legacy_init_op)
439+
legacy_init_op=tf.group(tf.initialize_all_tables(),
440+
name="legacy_init_op"))
441+
442+
builder.save()
443+
except Exception as e:
444+
logging.error("Fail to export saved model, exception: {}".format(e))
445+
394446
elif MODE == "inference":
395447
if not restore_session_from_checkpoint(sess, saver, LATEST_CHECKPOINT):
396448
logging.error("No checkpoint found, exit now")
@@ -446,7 +498,7 @@ def inference(sparse_ids, sparse_values, is_train=True):
446498
end_time - start_time, accuracy, auc))
447499

448500
# Save result into the file
449-
np.savetxt(inference_result_file_name, prediction, delimiter=",")
501+
np.savetxt(inference_result_file_name, prediction_softmax, delimiter=",")
450502
logging.info("Save result to file: {}".format(
451503
inference_result_file_name))
452504

0 commit comments

Comments
 (0)