
Commit 30a6b56

k-w-w authored and tensorflower-gardener committed
Automated rollback of commit 069f808
PiperOrigin-RevId: 210656847
1 parent 8012cf5 commit 30a6b56

10 files changed: +139 -634 lines changed


tensorflow/contrib/saved_model/BUILD

Lines changed: 3 additions & 14 deletions
@@ -36,7 +36,6 @@ py_library(
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
     deps = [
-        ":keras_saved_model",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:lib",
@@ -102,33 +101,23 @@ py_library(
     tags = ["no_windows"],
     visibility = ["//visibility:public"],
     deps = [
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework_ops",
         "//tensorflow/python:lib",
-        "//tensorflow/python:metrics",
-        "//tensorflow/python:platform",
-        "//tensorflow/python:saver",
         "//tensorflow/python:util",
-        "//tensorflow/python/estimator",
-        "//tensorflow/python/estimator:export",
-        "//tensorflow/python/estimator:keras",
-        "//tensorflow/python/estimator:model_fn",
         "//tensorflow/python/keras:engine",
-        "//tensorflow/python/saved_model",
+        "//tensorflow/python/saved_model:constants",
     ],
 )
 
 py_test(
     name = "keras_saved_model_test",
-    size = "medium",
+    size = "small",
     srcs = ["python/saved_model/keras_saved_model_test.py"],
     srcs_version = "PY2AND3",
     deps = [
-        ":keras_saved_model",
+        ":saved_model_py",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:training",
         "//tensorflow/python/keras",
         "//third_party/py/numpy",
-        "@absl_py//absl/testing:parameterized",
     ],
 )

tensorflow/contrib/saved_model/__init__.py

Lines changed: 2 additions & 5 deletions
@@ -26,13 +26,10 @@
 # pylint: disable=unused-import,wildcard-import,line-too-long
 from tensorflow.contrib.saved_model.python.saved_model.keras_saved_model import *
 from tensorflow.contrib.saved_model.python.saved_model.signature_def_utils import *
-# pylint: enable=unused-import,wildcard-import,line-too-long
+# pylint: enable=unused-import,widcard-import,line-too-long
 
 from tensorflow.python.util.all_util import remove_undocumented
 
-_allowed_symbols = [
-    "get_signature_def_by_key",
-    "load_keras_model",
-    "save_keras_model"]
+_allowed_symbols = ["get_signature_def_by_key", "load_model", "save_model"]
 
 remove_undocumented(__name__, _allowed_symbols)
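
After this rollback, the module's documented surface is limited to the three names in `_allowed_symbols`. A minimal usage sketch of that surface, assuming a TF 1.x build with `tf.contrib` available and a hypothetical, already-compiled graph-network Keras model named `model` (neither is part of this commit):

    # Sketch only: `model` and the target path are illustrative assumptions.
    from tensorflow.contrib import saved_model as contrib_saved_model

    # Public names after remove_undocumented():
    #   get_signature_def_by_key, save_model, load_model
    contrib_saved_model.save_model(model, '/tmp/keras_export')      # write topology + weights
    restored = contrib_saved_model.load_model('/tmp/keras_export')  # rebuild the model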

tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py

Lines changed: 27 additions & 233 deletions
@@ -20,270 +20,64 @@
 
 import os
 
-from tensorflow.python.client import session
-from tensorflow.python.estimator import keras as estimator_keras_util
-from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator.export import export as export_helpers
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.keras import backend as K
-from tensorflow.python.keras import models as models_lib
-from tensorflow.python.keras import optimizers
 from tensorflow.python.keras.models import model_from_json
 from tensorflow.python.lib.io import file_io
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import gfile
-from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.saved_model import builder as saved_model_builder
 from tensorflow.python.saved_model import constants
-from tensorflow.python.saved_model import utils_impl as saved_model_utils
-from tensorflow.python.training import saver as saver_lib
-from tensorflow.python.training.checkpointable import util as checkpointable_utils
 from tensorflow.python.util import compat
 
 
-def save_keras_model(
-    model, saved_model_path, custom_objects=None, as_text=None):
+def save_model(model, saved_model_path):
   """Save a `tf.keras.Model` into Tensorflow SavedModel format.
 
-  `save_model` generates new files/folders under the `saved_model_path` folder:
+  `save_model` generates such files/folders under the `saved_model_path` folder:
   1) an asset folder containing the json string of the model's
-     configuration (topology).
+  configuration(topology).
   2) a checkpoint containing the model weights.
-  3) a saved_model.pb file containing the model's MetaGraphs. The prediction
-     graph is always exported. The evaluaton and training graphs are exported
-     if the following conditions are met:
-     - Evaluation: model loss is defined.
-     - Training: model is compiled with an optimizer defined under `tf.train`.
-       This is because `tf.keras.optimizers.Optimizer` instances cannot be
-       saved to checkpoints.
 
-  Model Requirements:
-  - Model must be a sequential model or functional model. Subclassed models can
-    not be saved via this function, unless you provide an implementation for
-    get_config() and from_config().
-  - All variables must be saveable by the model. In general, this condition is
-    met through the use of layers defined in the keras library. However,
-    there is currently a bug with variables created in Lambda layer functions
-    not being saved correctly (see
-    https://github.com/keras-team/keras/issues/9740).
-
-  Note that each mode is exported in separate graphs, so different modes do not
-  share variables. To use the train graph with evaluation or prediction graphs,
-  create a new checkpoint if variable values have been updated.
+  Note that subclassed models can not be saved via this function, unless you
+  provide an implementation for get_config() and from_config().
+  Also note that `tf.keras.optimizers.Optimizer` instances can not currently be
+  saved to checkpoints. Use optimizers from `tf.train`.
 
   Args:
     model: A `tf.keras.Model` to be saved.
     saved_model_path: a string specifying the path to the SavedModel directory.
-      The SavedModel will be saved to a timestamped folder created within this
-      directory.
-    custom_objects: Optional dictionary mapping string names to custom classes
-      or functions (e.g. custom loss functions).
-    as_text: whether to write the `SavedModel` proto in text format.
-
-  Returns:
-    String path to the SavedModel folder, a subdirectory of `saved_model_path`.
 
   Raises:
     NotImplementedError: If the passed in model is a subclassed model.
   """
   if not model._is_graph_network:
     raise NotImplementedError
 
-  export_dir = export_helpers.get_timestamped_export_dir(saved_model_path)
-  temp_export_dir = export_helpers.get_temp_export_dir(export_dir)
-
-  builder = saved_model_builder.SavedModelBuilder(temp_export_dir)
-
-  # Manually save variables to export them in an object-based checkpoint. This
-  # skips the `builder.add_meta_graph_and_variables()` step, which saves a
-  # named-based checkpoint.
-  # TODO(b/113134168): Add fn to Builder to save with object-based saver.
-  # TODO(b/113178242): This should only export the model json structure. Only
-  # one save is needed once the weights can be copied from the model to clone.
-  checkpoint_path = _export_model_json_and_variables(model, temp_export_dir)
-
-  # Export each mode. Use ModeKeys enums defined for `Estimator` to ensure that
-  # Keras models and `Estimator`s are exported with the same format.
-  # Every time a mode is exported, the code checks to see if new variables have
-  # been created (e.g. optimizer slot variables). If that is the case, the
-  # checkpoint is re-saved to include the new variables.
-  export_args = {'builder': builder,
-                 'model': model,
-                 'custom_objects': custom_objects,
-                 'checkpoint_path': checkpoint_path}
-
-  has_saved_vars = False
-  if model.optimizer:
-    if isinstance(model.optimizer, optimizers.TFOptimizer):
-      _export_mode(model_fn_lib.ModeKeys.TRAIN, has_saved_vars, **export_args)
-      has_saved_vars = True
-      _export_mode(model_fn_lib.ModeKeys.EVAL, has_saved_vars, **export_args)
-    else:
-      logging.warning(
-          'Model was compiled with an optimizer, but the optimizer is not from '
-          '`tf.train` (e.g. `tf.train.AdagradOptimizer`). Only the serving '
-          'graph was exported. The train and evaluate graphs were not added to '
-          'the SavedModel.')
-  _export_mode(model_fn_lib.ModeKeys.PREDICT, has_saved_vars, **export_args)
-
-  builder.save(as_text)
-
-  gfile.Rename(temp_export_dir, export_dir)
-  return export_dir
+  # save model configuration as a json string under assets folder.
+  model_json = model.to_json()
+  assets_destination_dir = os.path.join(
+      compat.as_bytes(saved_model_path),
+      compat.as_bytes(constants.ASSETS_DIRECTORY))
 
+  if not file_io.file_exists(assets_destination_dir):
+    file_io.recursive_create_dir(assets_destination_dir)
 
-def _export_model_json_and_variables(model, saved_model_path):
-  """Save model variables and json structure into SavedModel subdirectories."""
-  # Save model configuration as a json string under assets folder.
-  model_json = model.to_json()
   model_json_filepath = os.path.join(
-      saved_model_utils.get_or_create_assets_dir(saved_model_path),
-      compat.as_text(constants.SAVED_MODEL_FILENAME_JSON))
+      compat.as_bytes(assets_destination_dir),
+      compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON))
   file_io.write_string_to_file(model_json_filepath, model_json)
 
-  # Save model weights in checkpoint format under variables folder.
-  saved_model_utils.get_or_create_variables_dir(saved_model_path)
-  checkpoint_prefix = saved_model_utils.get_variables_path(saved_model_path)
-  model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True)
-  return checkpoint_prefix
-
-
-def _get_var_list(model):
-  """Return list of all checkpointed saveable objects in the model."""
-  return checkpointable_utils.named_saveables(model)
-
-
-def _export_mode(
-    mode, has_saved_vars, builder, model, custom_objects, checkpoint_path):
-  """Export a model, and optionally save new vars from the clone model.
-
-  Args:
-    mode: A `tf.estimator.ModeKeys` string.
-    has_saved_vars: A `boolean` indicating whether the SavedModel has already
-      exported variables.
-    builder: A `SavedModelBuilder` object.
-    model: A `tf.keras.Model` object.
-    custom_objects: A dictionary mapping string names to custom classes
-      or functions.
-    checkpoint_path: String path to checkpoint.
-
-  Raises:
-    ValueError: If the train/eval mode is being exported, but the model does
-      not have an optimizer.
-  """
-  compile_clone = (mode != model_fn_lib.ModeKeys.PREDICT)
-  if compile_clone and not model.optimizer:
-    raise ValueError(
-        'Model does not have an optimizer. Cannot export mode %s' % mode)
-
-  model_graph = ops.get_default_graph()
-  with ops.Graph().as_default() as g:
-
-    K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN)
-
-    # Clone the model into blank graph. This will create placeholders for inputs
-    # and targets.
-    clone = models_lib.clone_and_build_model(
-        model, custom_objects=custom_objects, compile_clone=compile_clone)
-
-    # Make sure that iterations variable is added to the global step collection,
-    # to ensure that, when the SavedModel graph is loaded, the iterations
-    # variable is returned by `tf.train.get_global_step()`. This is required for
-    # compatibility with the SavedModelEstimator.
-    if compile_clone:
-      g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations)
-
-    # Extract update and train ops from train/test/predict functions.
-    if mode == model_fn_lib.ModeKeys.TRAIN:
-      clone._make_train_function()
-      builder._add_train_op(clone.train_function.updates_op)
-    elif mode == model_fn_lib.ModeKeys.EVAL:
-      clone._make_test_function()
-    else:
-      clone._make_predict_function()
-    g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates)
-
-    clone_var_list = checkpointable_utils.named_saveables(clone)
-
-    with session.Session().as_default():
-      if has_saved_vars:
-        # Confirm all variables in the clone have an entry in the checkpoint.
-        status = clone.load_weights(checkpoint_path)
-        status.assert_existing_objects_matched()
-      else:
-        # Confirm that variables between the clone and model match up exactly,
-        # not counting optimizer objects. Optimizer objects are ignored because
-        # if the model has not trained, the slot variables will not have been
-        # created yet.
-        # TODO(b/113179535): Replace with checkpointable equivalence.
-        _assert_same_non_optimizer_objects(model, model_graph, clone, g)
-
-        # TODO(b/113178242): Use value transfer for checkpointable objects.
-        clone.load_weights(checkpoint_path)
-
-        # Add graph and variables to SavedModel.
-        # TODO(b/113134168): Switch to add_meta_graph_and_variables.
-        clone.save_weights(checkpoint_path, save_format='tf', overwrite=True)
-        builder._has_saved_variables = True
-
-    # Add graph to the SavedModel builder.
-    builder.add_meta_graph(
-        model_fn_lib.EXPORT_TAG_MAP[mode],
-        signature_def_map=_create_signature_def_map(clone, mode),
-        saver=saver_lib.Saver(clone_var_list),
-        main_op=variables.local_variables_initializer())
-    return None
-
-
-def _create_signature_def_map(model, mode):
-  """Create a SignatureDef map from a Keras model."""
-  inputs_dict = {name: x for name, x in zip(model.input_names, model.inputs)}
-  if model.optimizer:
-    targets_dict = {x.name.split(':')[0]: x
-                    for x in model.targets if x is not None}
-    inputs_dict.update(targets_dict)
-  outputs_dict = {name: x
-                  for name, x in zip(model.output_names, model.outputs)}
-  export_outputs = model_fn_lib.export_outputs_for_mode(
-      mode,
-      predictions=outputs_dict,
-      loss=model.total_loss if model.optimizer else None,
-      metrics=estimator_keras_util._convert_keras_metrics_to_estimator(model))
-  return export_helpers.build_all_signature_defs(
-      inputs_dict,
-      export_outputs=export_outputs,
-      serving_only=(mode == model_fn_lib.ModeKeys.PREDICT))
-
-
-def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph):
-  """Assert model and clone contain the same checkpointable objects."""
-
-  def get_non_optimizer_objects(m, g):
-    """Gather set of model and optimizer checkpointable objects."""
-    # Set default graph because optimizer.variables() returns optimizer
-    # variables defined in the default graph.
-    with g.as_default():
-      all_objects = set(checkpointable_utils.list_objects(m))
-      optimizer_and_variables = set()
-      for obj in all_objects:
-        if isinstance(obj, optimizers.TFOptimizer):
-          optimizer_and_variables.update(checkpointable_utils.list_objects(obj))
-          optimizer_and_variables.update(set(obj.optimizer.variables()))
-      return all_objects - optimizer_and_variables
+  # save model weights in checkpoint format.
+  checkpoint_destination_dir = os.path.join(
+      compat.as_bytes(saved_model_path),
+      compat.as_bytes(constants.VARIABLES_DIRECTORY))
 
-  model_objects = get_non_optimizer_objects(model, model_graph)
-  clone_objects = get_non_optimizer_objects(clone, clone_graph)
+  if not file_io.file_exists(checkpoint_destination_dir):
+    file_io.recursive_create_dir(checkpoint_destination_dir)
 
-  if len(model_objects) != len(clone_objects):
-    raise errors.InternalError(
-        None, None,
-        'Model and clone must use the same variables.'
-        '\n\tModel variables: %s\n\t Clone variables: %s'
-        % (model_objects, clone_objects))
+  checkpoint_prefix = os.path.join(
+      compat.as_text(checkpoint_destination_dir),
+      compat.as_text(constants.VARIABLES_FILENAME))
+  model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True)
 
 
-def load_keras_model(saved_model_path):
+def load_model(saved_model_path):
   """Load a keras.Model from SavedModel.
 
   load_model reinstantiates model state by:

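To make the rolled-back `save_model` behaviour concrete, here is a hedged sketch of what it is expected to write under `saved_model_path` per the diff above: the model topology as JSON under the assets folder and the weights as a TF-format checkpoint under the variables folder, directly inside the given directory (no timestamped subdirectory and no saved_model.pb in this version). The tiny model below is an illustrative assumption, not part of the commit:

    # Sketch only: assumes a TF 1.x runtime with tf.contrib available.
    import tensorflow as tf
    from tensorflow.contrib import saved_model as contrib_saved_model

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    # Per the docstring, use a tf.train optimizer rather than tf.keras.optimizers.
    model.compile(optimizer=tf.train.GradientDescentOptimizer(0.1), loss='mse')

    contrib_saved_model.save_model(model, '/tmp/rollback_demo')
    # Expected layout (per constants.ASSETS_DIRECTORY / VARIABLES_DIRECTORY):
    #   /tmp/rollback_demo/assets/saved_model.json   - model.to_json() topology
    #   /tmp/rollback_demo/variables/variables*      - weights saved with save_format='tf'

    restored = contrib_saved_model.load_model('/tmp/rollback_demo')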