Skip to content

Commit 8ea90fa

Browse files
Nathan Silbermantensorflower-gardener
authored andcommitted
Adding assign_from_checkpoint_fn and assign_from_values_fn to tf.contrib.
Change: 130962238
1 parent 75edf3f commit 8ea90fa

File tree

3 files changed

+315
-0
lines changed

3 files changed

+315
-0
lines changed

tensorflow/contrib/framework/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@
4242
@@add_model_variable
4343
@@assert_global_step
4444
@@assert_or_get_global_step
45+
@@assign_from_checkpoint
46+
@@assign_from_checkpoint_fn
47+
@@assign_from_values
48+
@@assign_from_values_fn
4549
@@create_global_step
4650
@@get_global_step
4751
@@get_or_create_global_step

tensorflow/contrib/framework/python/ops/variables.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,16 @@
3030
from tensorflow.python.ops import variable_scope
3131
from tensorflow.python.ops import variables
3232
from tensorflow.python.platform import tf_logging as logging
33+
from tensorflow.python.training import saver as tf_saver
3334

3435

3536
__all__ = ['add_model_variable',
3637
'assert_global_step',
3738
'assert_or_get_global_step',
3839
'assign_from_checkpoint',
40+
'assign_from_checkpoint_fn',
3941
'assign_from_values',
42+
'assign_from_values_fn',
4043
'create_global_step',
4144
'get_global_step',
4245
'get_or_create_global_step',
@@ -461,6 +464,28 @@ def assign_from_values(var_names_to_values):
461464
return assign_op, feed_dict
462465

463466

467+
def assign_from_values_fn(var_names_to_values):
468+
"""Returns a function that assigns specific variables from the given values.
469+
470+
This function provides a mechanism for performing assignment of variables
471+
to values in a way that does not fill the graph with large assignment values.
472+
473+
Args:
474+
var_names_to_values: A map from variable names to values.
475+
476+
Returns:
477+
A function that takes a single argument, a `tf.Session`, that applies the
478+
assignment operation.
479+
480+
Raises:
481+
ValueError: if any of the given variable names were not found.
482+
"""
483+
assign_op, feed_dict = assign_from_values(var_names_to_values)
484+
def callback(session):
485+
return session.run(assign_op, feed_dict)
486+
return callback
487+
488+
464489
# TODO(nsilberman): add flag to load exponential moving averages instead
465490
def assign_from_checkpoint(model_path, var_list):
466491
"""Creates an operation to assign specific variables from a checkpoint.
@@ -507,6 +532,50 @@ def assign_from_checkpoint(model_path, var_list):
507532
return assign_op, feed_dict
508533

509534

535+
def assign_from_checkpoint_fn(model_path, var_list, ignore_missing_vars=False,
536+
reshape_variables=False):
537+
"""Returns a function that assigns specific variables from a checkpoint.
538+
539+
Args:
540+
model_path: The full path to the model checkpoint. To get latest checkpoint
541+
use `model_path = tf.train.latest_checkpoint(checkpoint_dir)`
542+
var_list: A list of `Variable` objects or a dictionary mapping names in the
543+
checkpoint to the correspoing variables to initialize. If empty or None,
544+
it would return no_op(), None.
545+
ignore_missing_vars: Boolean, if True it would ignore variables missing in
546+
the checkpoint with a warning instead of failing.
547+
reshape_variables: Boolean, if True it would automatically reshape variables
548+
which are of different shape then the ones stored in the checkpoint but
549+
which have the same number of elements.
550+
551+
Returns:
552+
A function that takes a single argument, a `tf.Session`, that applies the
553+
assignment operation.
554+
555+
Raises:
556+
ValueError: If the checkpoint specified at `model_path` is missing one of
557+
the variables in `var_list`.
558+
"""
559+
if ignore_missing_vars:
560+
reader = pywrap_tensorflow.NewCheckpointReader(model_path)
561+
if isinstance(var_list, dict):
562+
var_dict = var_list
563+
else:
564+
var_dict = {var.op.name: var for var in var_list}
565+
available_vars = {}
566+
for var in var_dict:
567+
if reader.has_tensor(var):
568+
available_vars[var] = var_dict[var]
569+
else:
570+
logging.warning(
571+
'Variable %s missing in checkpoint %s' % (var, model_path))
572+
var_list = available_vars
573+
saver = tf_saver.Saver(var_list, reshape=reshape_variables)
574+
def callback(session):
575+
saver.restore(session, model_path)
576+
return callback
577+
578+
510579
class VariableDeviceChooser(object):
511580
"""Device chooser for variables.
512581

tensorflow/contrib/framework/python/ops/variables_test.py

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,63 @@ def testWithScopes(self):
707707
self.assertAllEqual(init_value1, var1)
708708

709709

710+
class AssignFromValuesFnTest(tf.test.TestCase):
711+
712+
def testNoScopes(self):
713+
init_value0 = np.asarray([1.0, 3.0, 9.0]).reshape((1, 3, 1))
714+
init_value1 = np.asarray([2.0, 4.0, 6.0, 8.0]).reshape((2, 1, 2))
715+
716+
with self.test_session() as sess:
717+
initializer = tf.truncated_normal_initializer(stddev=.1)
718+
var0 = tf.contrib.framework.variable(
719+
'my_var0', shape=[1, 3, 1], initializer=initializer)
720+
var1 = tf.contrib.framework.variable(
721+
'my_var1', shape=[2, 1, 2], initializer=initializer)
722+
723+
var_names_to_values = {'my_var0': init_value0, 'my_var1': init_value1}
724+
init_fn = tf.contrib.framework.assign_from_values_fn(var_names_to_values)
725+
726+
# Initialize the variables.
727+
sess.run(tf.initialize_all_variables())
728+
729+
# Perform the assignment.
730+
init_fn(sess)
731+
732+
# Request and test the variable values:
733+
var0, var1 = sess.run([var0, var1])
734+
self.assertAllEqual(init_value0, var0)
735+
self.assertAllEqual(init_value1, var1)
736+
737+
def testWithScopes(self):
738+
init_value0 = np.asarray([1.0, 3.0, 9.0]).reshape((1, 3, 1))
739+
init_value1 = np.asarray([2.0, 4.0, 6.0, 8.0]).reshape((2, 1, 2))
740+
741+
with self.test_session() as sess:
742+
initializer = tf.truncated_normal_initializer(stddev=.1)
743+
744+
with tf.variable_scope('my_model/my_layer0'):
745+
var0 = tf.contrib.framework.variable(
746+
'my_var0', shape=[1, 3, 1], initializer=initializer)
747+
with tf.variable_scope('my_model/my_layer1'):
748+
var1 = tf.contrib.framework.variable(
749+
'my_var1', shape=[2, 1, 2], initializer=initializer)
750+
751+
var_names_to_values = {'my_model/my_layer0/my_var0': init_value0,
752+
'my_model/my_layer1/my_var1': init_value1}
753+
init_fn = tf.contrib.framework.assign_from_values_fn(var_names_to_values)
754+
755+
# Initialize the variables.
756+
sess.run(tf.initialize_all_variables())
757+
758+
# Perform the assignment.
759+
init_fn(sess)
760+
761+
# Request and test the variable values:
762+
var0, var1 = sess.run([var0, var1])
763+
self.assertAllEqual(init_value0, var0)
764+
self.assertAllEqual(init_value1, var1)
765+
766+
710767
class AssignFromCheckpointTest(tf.test.TestCase):
711768

712769
def create_checkpoint_from_values(self, var_names_to_values, checkpoint_dir,
@@ -810,5 +867,190 @@ def testInitFromCheckpointWithScopes(self):
810867
self.assertAllEqual(init_value0, var0.eval())
811868
self.assertAllEqual(init_value1, var1.eval())
812869

870+
871+
class AssignFromCheckpointFnTest(tf.test.TestCase):
872+
873+
def create_checkpoint_from_values(self, var_names_to_values, checkpoint_dir,
874+
global_step=None):
875+
"""Creates a checkpoint from a mapping of name to values in model_dir.
876+
877+
Args:
878+
var_names_to_values: a map from variable names to values.
879+
checkpoint_dir: the directory where the checkpoint will be saved.
880+
global_step: the global step used to save the checkpoint.
881+
882+
Returns:
883+
the model_path to the checkpoint.
884+
"""
885+
var_list = []
886+
with tf.Session('', graph=tf.Graph()) as sess:
887+
# Create a set of variables to save in the checkpoint.
888+
for var_name in var_names_to_values:
889+
var_value = var_names_to_values[var_name]
890+
var_list.append(tf.Variable(var_value, name=var_name))
891+
saver = tf.train.Saver(var_list)
892+
init_op = tf.initialize_variables(var_list)
893+
sess.run(init_op)
894+
# Save the initialized values in the file at 'checkpoint_dir'
895+
return saver.save(sess, checkpoint_dir, global_step=global_step)
896+
897+
def testLoadExistingVariables(self):
898+
init_value0 = 10.0
899+
init_value1 = 20.0
900+
var_names_to_values = {'v0': init_value0, 'v1': init_value1}
901+
902+
model_dir = os.path.join(self.get_temp_dir(), 'model')
903+
with self.test_session() as sess:
904+
model_path = self.create_checkpoint_from_values(var_names_to_values,
905+
model_dir)
906+
var0 = tf.contrib.framework.variable('my_var0', shape=[])
907+
var1 = tf.contrib.framework.variable('my_var1', shape=[])
908+
909+
vars_to_restore = {'v0': var0, 'v1': var1}
910+
init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
911+
model_path, vars_to_restore)
912+
913+
# Initialize the variables.
914+
sess.run(tf.initialize_all_variables())
915+
916+
# Perform the assignment.
917+
init_fn(sess)
918+
919+
# Request and test the variable values:
920+
self.assertEqual(init_value0, var0.eval())
921+
self.assertEqual(init_value1, var1.eval())
922+
923+
def testLoadExistingVariablesDifferentShapeDefaultDoesNotAllowReshape(self):
924+
init_value0 = [[10.0, 11.0]]
925+
init_value1 = 20.0
926+
var_names_to_values = {'v0': init_value0, 'v1': init_value1}
927+
928+
model_dir = os.path.join(self.get_temp_dir(), 'model')
929+
with self.test_session() as sess:
930+
model_path = self.create_checkpoint_from_values(var_names_to_values,
931+
model_dir)
932+
var0 = tf.contrib.framework.variable('my_var0', shape=[2, 1])
933+
var1 = tf.contrib.framework.variable('my_var1', shape=[])
934+
935+
vars_to_restore = {'v0': var0, 'v1': var1}
936+
init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
937+
model_path, vars_to_restore)
938+
939+
# Initialize the variables.
940+
sess.run(tf.initialize_all_variables())
941+
942+
# Perform the assignment.
943+
with self.assertRaises(tf.errors.InvalidArgumentError):
944+
init_fn(sess)
945+
946+
def testLoadExistingVariablesDifferentShapeAllowReshape(self):
947+
init_value0 = [[10.0, 11.0]]
948+
init_value1 = 20.0
949+
var_names_to_values = {'v0': init_value0, 'v1': init_value1}
950+
951+
model_dir = os.path.join(self.get_temp_dir(), 'model')
952+
with self.test_session() as sess:
953+
model_path = self.create_checkpoint_from_values(var_names_to_values,
954+
model_dir)
955+
var0 = tf.contrib.framework.variable('my_var0', shape=[2, 1])
956+
var1 = tf.contrib.framework.variable('my_var1', shape=[])
957+
958+
vars_to_restore = {'v0': var0, 'v1': var1}
959+
init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
960+
model_path, vars_to_restore, reshape_variables=True)
961+
962+
# Initialize the variables.
963+
sess.run(tf.initialize_all_variables())
964+
965+
# Perform the assignment.
966+
init_fn(sess)
967+
968+
# Request and test the variable values:
969+
self.assertAllEqual(np.transpose(np.array(init_value0)), var0.eval())
970+
self.assertEqual(init_value1, var1.eval())
971+
972+
def testNotFoundError(self):
973+
init_value0 = 10.0
974+
init_value1 = 20.0
975+
var_names_to_values = {'v0': init_value0, 'v1': init_value1}
976+
977+
model_dir = os.path.join(self.get_temp_dir(), 'model')
978+
with self.test_session() as sess:
979+
model_path = self.create_checkpoint_from_values(var_names_to_values,
980+
model_dir)
981+
var0 = tf.contrib.framework.variable('my_var0', shape=[])
982+
var1 = tf.contrib.framework.variable('my_var1', shape=[])
983+
var2 = tf.contrib.framework.variable('my_var2', shape=[])
984+
985+
vars_to_restore = {'v0': var0, 'v1': var1, 'v2': var2}
986+
init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
987+
model_path,
988+
vars_to_restore)
989+
990+
# Initialize the variables.
991+
sess.run(tf.initialize_all_variables())
992+
993+
# Perform the assignment.
994+
with self.assertRaises(tf.errors.NotFoundError):
995+
init_fn(sess)
996+
997+
def testMissingVariablesList(self):
998+
init_value0 = 10.0
999+
init_value1 = 20.0
1000+
var_names_to_values = {'v0': init_value0, 'v1': init_value1}
1001+
1002+
model_dir = os.path.join(self.get_temp_dir(), 'model')
1003+
with self.test_session() as sess:
1004+
model_path = self.create_checkpoint_from_values(var_names_to_values,
1005+
model_dir)
1006+
var0 = tf.contrib.framework.variable('v0', shape=[])
1007+
var1 = tf.contrib.framework.variable('v1', shape=[])
1008+
var2 = tf.contrib.framework.variable('v2', shape=[])
1009+
1010+
vars_to_restore = [var0, var1, var2]
1011+
init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
1012+
model_path,
1013+
vars_to_restore,
1014+
ignore_missing_vars=True)
1015+
1016+
# Initialize the variables.
1017+
sess.run(tf.initialize_all_variables())
1018+
1019+
# Perform the assignment.
1020+
init_fn(sess)
1021+
1022+
# Request and test the variable values:
1023+
self.assertEqual(init_value0, var0.eval())
1024+
self.assertEqual(init_value1, var1.eval())
1025+
1026+
def testMissingVariablesDict(self):
1027+
init_value0 = 10.0
1028+
init_value1 = 20.0
1029+
var_names_to_values = {'v0': init_value0, 'v1': init_value1}
1030+
1031+
model_dir = os.path.join(self.get_temp_dir(), 'model')
1032+
with self.test_session() as sess:
1033+
model_path = self.create_checkpoint_from_values(var_names_to_values,
1034+
model_dir)
1035+
var0 = tf.contrib.framework.variable('my_var0', shape=[])
1036+
var1 = tf.contrib.framework.variable('my_var1', shape=[])
1037+
var2 = tf.contrib.framework.variable('my_var2', shape=[])
1038+
1039+
vars_to_restore = {'v0': var0, 'v1': var1, 'v2': var2}
1040+
init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
1041+
model_path,
1042+
vars_to_restore,
1043+
ignore_missing_vars=True)
1044+
1045+
# Initialize the variables.
1046+
sess.run(tf.initialize_all_variables())
1047+
1048+
# Perform the assignment.
1049+
init_fn(sess)
1050+
1051+
# Request and test the variable values:
1052+
self.assertEqual(init_value0, var0.eval())
1053+
self.assertEqual(init_value1, var1.eval())
1054+
8131055
if __name__ == '__main__':
8141056
tf.test.main()

0 commit comments

Comments
 (0)