|
11 | 11 | from sklearn import metrics
|
12 | 12 | import tensorflow as tf
|
13 | 13 | from tensorflow.contrib.session_bundle import exporter
|
| 14 | +from tensorflow.python.saved_model import builder as saved_model_builder |
| 15 | +from tensorflow.python.saved_model import signature_constants |
| 16 | +from tensorflow.python.saved_model import signature_def_utils |
| 17 | +from tensorflow.python.saved_model import tag_constants |
| 18 | +from tensorflow.python.saved_model import utils |
| 19 | +from tensorflow.python.util import compat |
14 | 20 |
|
15 | 21 | # Define hyperparameters
|
16 | 22 | flags = tf.app.flags
|
|
49 | 55 | flags.DEFINE_integer("steps_to_validate", 10,
|
50 | 56 | "Steps to validate and print state")
|
51 | 57 | flags.DEFINE_string("mode", "train", "Support train, export, inference")
|
| 58 | +flags.DEFINE_string("saved_model_path", "./sparse_saved_model/", |
| 59 | + "The path of the saved model") |
52 | 60 | flags.DEFINE_string("model_path", "./sparse_model/", "The path of the model")
|
53 | 61 | flags.DEFINE_integer("model_version", 1, "The version of the model")
|
54 | 62 | flags.DEFINE_string("inference_test_file", "./data/a8a_test.libsvm",
|
@@ -391,6 +399,50 @@ def inference(sparse_ids, sparse_values, is_train=True):
|
391 | 399 | export_model(sess, saver, model_signature, FLAGS.model_path,
|
392 | 400 | FLAGS.model_version)
|
393 | 401 |
|
| 402 | + elif MODE == "savedmodel": |
| 403 | + if not restore_session_from_checkpoint(sess, saver, LATEST_CHECKPOINT): |
| 404 | + logging.error("No checkpoint found, exit now") |
| 405 | + exit(1) |
| 406 | + |
| 407 | + logging.info("Export the saved model to {}".format( |
| 408 | + FLAGS.saved_model_path)) |
| 409 | + export_path_base = FLAGS.saved_model_path |
| 410 | + export_path = os.path.join( |
| 411 | + compat.as_bytes(export_path_base), |
| 412 | + compat.as_bytes(str(FLAGS.model_version))) |
| 413 | + |
| 414 | + model_signature = signature_def_utils.build_signature_def( |
| 415 | + inputs={ |
| 416 | + "keys": utils.build_tensor_info(keys_placeholder), |
| 417 | + "indexs": utils.build_tensor_info(sparse_index), |
| 418 | + "ids": utils.build_tensor_info(sparse_ids), |
| 419 | + "values": utils.build_tensor_info(sparse_values), |
| 420 | + "shape": utils.build_tensor_info(sparse_shape) |
| 421 | + }, |
| 422 | + outputs={ |
| 423 | + "keys": utils.build_tensor_info(keys), |
| 424 | + "softmax": utils.build_tensor_info(inference_softmax), |
| 425 | + "prediction": utils.build_tensor_info(inference_op) |
| 426 | + }, |
| 427 | + method_name=signature_constants.PREDICT_METHOD_NAME) |
| 428 | + |
| 429 | + try: |
| 430 | + builder = saved_model_builder.SavedModelBuilder(export_path) |
| 431 | + builder.add_meta_graph_and_variables( |
| 432 | + sess, |
| 433 | + [tag_constants.SERVING], |
| 434 | + signature_def_map={ |
| 435 | + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: |
| 436 | + model_signature, |
| 437 | + }, |
| 438 | + #legacy_init_op=legacy_init_op) |
| 439 | + legacy_init_op=tf.group(tf.initialize_all_tables(), |
| 440 | + name="legacy_init_op")) |
| 441 | + |
| 442 | + builder.save() |
| 443 | + except Exception as e: |
| 444 | + logging.error("Fail to export saved model, exception: {}".format(e)) |
| 445 | + |
394 | 446 | elif MODE == "inference":
|
395 | 447 | if not restore_session_from_checkpoint(sess, saver, LATEST_CHECKPOINT):
|
396 | 448 | logging.error("No checkpoint found, exit now")
|
@@ -446,7 +498,7 @@ def inference(sparse_ids, sparse_values, is_train=True):
|
446 | 498 | end_time - start_time, accuracy, auc))
|
447 | 499 |
|
448 | 500 | # Save result into the file
|
449 |
| - np.savetxt(inference_result_file_name, prediction, delimiter=",") |
| 501 | + np.savetxt(inference_result_file_name, prediction_softmax, delimiter=",") |
450 | 502 | logging.info("Save result to file: {}".format(
|
451 | 503 | inference_result_file_name))
|
452 | 504 |
|
|
0 commit comments