@@ -707,6 +707,63 @@ def testWithScopes(self):
707707 self .assertAllEqual (init_value1 , var1 )
708708
709709
710+ class AssignFromValuesFnTest (tf .test .TestCase ):
711+
712+ def testNoScopes (self ):
713+ init_value0 = np .asarray ([1.0 , 3.0 , 9.0 ]).reshape ((1 , 3 , 1 ))
714+ init_value1 = np .asarray ([2.0 , 4.0 , 6.0 , 8.0 ]).reshape ((2 , 1 , 2 ))
715+
716+ with self .test_session () as sess :
717+ initializer = tf .truncated_normal_initializer (stddev = .1 )
718+ var0 = tf .contrib .framework .variable (
719+ 'my_var0' , shape = [1 , 3 , 1 ], initializer = initializer )
720+ var1 = tf .contrib .framework .variable (
721+ 'my_var1' , shape = [2 , 1 , 2 ], initializer = initializer )
722+
723+ var_names_to_values = {'my_var0' : init_value0 , 'my_var1' : init_value1 }
724+ init_fn = tf .contrib .framework .assign_from_values_fn (var_names_to_values )
725+
726+ # Initialize the variables.
727+ sess .run (tf .initialize_all_variables ())
728+
729+ # Perform the assignment.
730+ init_fn (sess )
731+
732+ # Request and test the variable values:
733+ var0 , var1 = sess .run ([var0 , var1 ])
734+ self .assertAllEqual (init_value0 , var0 )
735+ self .assertAllEqual (init_value1 , var1 )
736+
737+ def testWithScopes (self ):
738+ init_value0 = np .asarray ([1.0 , 3.0 , 9.0 ]).reshape ((1 , 3 , 1 ))
739+ init_value1 = np .asarray ([2.0 , 4.0 , 6.0 , 8.0 ]).reshape ((2 , 1 , 2 ))
740+
741+ with self .test_session () as sess :
742+ initializer = tf .truncated_normal_initializer (stddev = .1 )
743+
744+ with tf .variable_scope ('my_model/my_layer0' ):
745+ var0 = tf .contrib .framework .variable (
746+ 'my_var0' , shape = [1 , 3 , 1 ], initializer = initializer )
747+ with tf .variable_scope ('my_model/my_layer1' ):
748+ var1 = tf .contrib .framework .variable (
749+ 'my_var1' , shape = [2 , 1 , 2 ], initializer = initializer )
750+
751+ var_names_to_values = {'my_model/my_layer0/my_var0' : init_value0 ,
752+ 'my_model/my_layer1/my_var1' : init_value1 }
753+ init_fn = tf .contrib .framework .assign_from_values_fn (var_names_to_values )
754+
755+ # Initialize the variables.
756+ sess .run (tf .initialize_all_variables ())
757+
758+ # Perform the assignment.
759+ init_fn (sess )
760+
761+ # Request and test the variable values:
762+ var0 , var1 = sess .run ([var0 , var1 ])
763+ self .assertAllEqual (init_value0 , var0 )
764+ self .assertAllEqual (init_value1 , var1 )
765+
766+
710767class AssignFromCheckpointTest (tf .test .TestCase ):
711768
712769 def create_checkpoint_from_values (self , var_names_to_values , checkpoint_dir ,
@@ -810,5 +867,190 @@ def testInitFromCheckpointWithScopes(self):
810867 self .assertAllEqual (init_value0 , var0 .eval ())
811868 self .assertAllEqual (init_value1 , var1 .eval ())
812869
870+
871+ class AssignFromCheckpointFnTest (tf .test .TestCase ):
872+
873+ def create_checkpoint_from_values (self , var_names_to_values , checkpoint_dir ,
874+ global_step = None ):
875+ """Creates a checkpoint from a mapping of name to values in model_dir.
876+
877+ Args:
878+ var_names_to_values: a map from variable names to values.
879+ checkpoint_dir: the directory where the checkpoint will be saved.
880+ global_step: the global step used to save the checkpoint.
881+
882+ Returns:
883+ the model_path to the checkpoint.
884+ """
885+ var_list = []
886+ with tf .Session ('' , graph = tf .Graph ()) as sess :
887+ # Create a set of variables to save in the checkpoint.
888+ for var_name in var_names_to_values :
889+ var_value = var_names_to_values [var_name ]
890+ var_list .append (tf .Variable (var_value , name = var_name ))
891+ saver = tf .train .Saver (var_list )
892+ init_op = tf .initialize_variables (var_list )
893+ sess .run (init_op )
894+ # Save the initialized values in the file at 'checkpoint_dir'
895+ return saver .save (sess , checkpoint_dir , global_step = global_step )
896+
897+ def testLoadExistingVariables (self ):
898+ init_value0 = 10.0
899+ init_value1 = 20.0
900+ var_names_to_values = {'v0' : init_value0 , 'v1' : init_value1 }
901+
902+ model_dir = os .path .join (self .get_temp_dir (), 'model' )
903+ with self .test_session () as sess :
904+ model_path = self .create_checkpoint_from_values (var_names_to_values ,
905+ model_dir )
906+ var0 = tf .contrib .framework .variable ('my_var0' , shape = [])
907+ var1 = tf .contrib .framework .variable ('my_var1' , shape = [])
908+
909+ vars_to_restore = {'v0' : var0 , 'v1' : var1 }
910+ init_fn = tf .contrib .framework .assign_from_checkpoint_fn (
911+ model_path , vars_to_restore )
912+
913+ # Initialize the variables.
914+ sess .run (tf .initialize_all_variables ())
915+
916+ # Perform the assignment.
917+ init_fn (sess )
918+
919+ # Request and test the variable values:
920+ self .assertEqual (init_value0 , var0 .eval ())
921+ self .assertEqual (init_value1 , var1 .eval ())
922+
923+ def testLoadExistingVariablesDifferentShapeDefaultDoesNotAllowReshape (self ):
924+ init_value0 = [[10.0 , 11.0 ]]
925+ init_value1 = 20.0
926+ var_names_to_values = {'v0' : init_value0 , 'v1' : init_value1 }
927+
928+ model_dir = os .path .join (self .get_temp_dir (), 'model' )
929+ with self .test_session () as sess :
930+ model_path = self .create_checkpoint_from_values (var_names_to_values ,
931+ model_dir )
932+ var0 = tf .contrib .framework .variable ('my_var0' , shape = [2 , 1 ])
933+ var1 = tf .contrib .framework .variable ('my_var1' , shape = [])
934+
935+ vars_to_restore = {'v0' : var0 , 'v1' : var1 }
936+ init_fn = tf .contrib .framework .assign_from_checkpoint_fn (
937+ model_path , vars_to_restore )
938+
939+ # Initialize the variables.
940+ sess .run (tf .initialize_all_variables ())
941+
942+ # Perform the assignment.
943+ with self .assertRaises (tf .errors .InvalidArgumentError ):
944+ init_fn (sess )
945+
946+ def testLoadExistingVariablesDifferentShapeAllowReshape (self ):
947+ init_value0 = [[10.0 , 11.0 ]]
948+ init_value1 = 20.0
949+ var_names_to_values = {'v0' : init_value0 , 'v1' : init_value1 }
950+
951+ model_dir = os .path .join (self .get_temp_dir (), 'model' )
952+ with self .test_session () as sess :
953+ model_path = self .create_checkpoint_from_values (var_names_to_values ,
954+ model_dir )
955+ var0 = tf .contrib .framework .variable ('my_var0' , shape = [2 , 1 ])
956+ var1 = tf .contrib .framework .variable ('my_var1' , shape = [])
957+
958+ vars_to_restore = {'v0' : var0 , 'v1' : var1 }
959+ init_fn = tf .contrib .framework .assign_from_checkpoint_fn (
960+ model_path , vars_to_restore , reshape_variables = True )
961+
962+ # Initialize the variables.
963+ sess .run (tf .initialize_all_variables ())
964+
965+ # Perform the assignment.
966+ init_fn (sess )
967+
968+ # Request and test the variable values:
969+ self .assertAllEqual (np .transpose (np .array (init_value0 )), var0 .eval ())
970+ self .assertEqual (init_value1 , var1 .eval ())
971+
972+ def testNotFoundError (self ):
973+ init_value0 = 10.0
974+ init_value1 = 20.0
975+ var_names_to_values = {'v0' : init_value0 , 'v1' : init_value1 }
976+
977+ model_dir = os .path .join (self .get_temp_dir (), 'model' )
978+ with self .test_session () as sess :
979+ model_path = self .create_checkpoint_from_values (var_names_to_values ,
980+ model_dir )
981+ var0 = tf .contrib .framework .variable ('my_var0' , shape = [])
982+ var1 = tf .contrib .framework .variable ('my_var1' , shape = [])
983+ var2 = tf .contrib .framework .variable ('my_var2' , shape = [])
984+
985+ vars_to_restore = {'v0' : var0 , 'v1' : var1 , 'v2' : var2 }
986+ init_fn = tf .contrib .framework .assign_from_checkpoint_fn (
987+ model_path ,
988+ vars_to_restore )
989+
990+ # Initialize the variables.
991+ sess .run (tf .initialize_all_variables ())
992+
993+ # Perform the assignment.
994+ with self .assertRaises (tf .errors .NotFoundError ):
995+ init_fn (sess )
996+
997+ def testMissingVariablesList (self ):
998+ init_value0 = 10.0
999+ init_value1 = 20.0
1000+ var_names_to_values = {'v0' : init_value0 , 'v1' : init_value1 }
1001+
1002+ model_dir = os .path .join (self .get_temp_dir (), 'model' )
1003+ with self .test_session () as sess :
1004+ model_path = self .create_checkpoint_from_values (var_names_to_values ,
1005+ model_dir )
1006+ var0 = tf .contrib .framework .variable ('v0' , shape = [])
1007+ var1 = tf .contrib .framework .variable ('v1' , shape = [])
1008+ var2 = tf .contrib .framework .variable ('v2' , shape = [])
1009+
1010+ vars_to_restore = [var0 , var1 , var2 ]
1011+ init_fn = tf .contrib .framework .assign_from_checkpoint_fn (
1012+ model_path ,
1013+ vars_to_restore ,
1014+ ignore_missing_vars = True )
1015+
1016+ # Initialize the variables.
1017+ sess .run (tf .initialize_all_variables ())
1018+
1019+ # Perform the assignment.
1020+ init_fn (sess )
1021+
1022+ # Request and test the variable values:
1023+ self .assertEqual (init_value0 , var0 .eval ())
1024+ self .assertEqual (init_value1 , var1 .eval ())
1025+
1026+ def testMissingVariablesDict (self ):
1027+ init_value0 = 10.0
1028+ init_value1 = 20.0
1029+ var_names_to_values = {'v0' : init_value0 , 'v1' : init_value1 }
1030+
1031+ model_dir = os .path .join (self .get_temp_dir (), 'model' )
1032+ with self .test_session () as sess :
1033+ model_path = self .create_checkpoint_from_values (var_names_to_values ,
1034+ model_dir )
1035+ var0 = tf .contrib .framework .variable ('my_var0' , shape = [])
1036+ var1 = tf .contrib .framework .variable ('my_var1' , shape = [])
1037+ var2 = tf .contrib .framework .variable ('my_var2' , shape = [])
1038+
1039+ vars_to_restore = {'v0' : var0 , 'v1' : var1 , 'v2' : var2 }
1040+ init_fn = tf .contrib .framework .assign_from_checkpoint_fn (
1041+ model_path ,
1042+ vars_to_restore ,
1043+ ignore_missing_vars = True )
1044+
1045+ # Initialize the variables.
1046+ sess .run (tf .initialize_all_variables ())
1047+
1048+ # Perform the assignment.
1049+ init_fn (sess )
1050+
1051+ # Request and test the variable values:
1052+ self .assertEqual (init_value0 , var0 .eval ())
1053+ self .assertEqual (init_value1 , var1 .eval ())
1054+
8131055if __name__ == '__main__' :
8141056 tf .test .main ()
0 commit comments