@@ -953,7 +953,8 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
 
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def full_loop(self, **config):
-        scheduler_class = self.scheduler_classes[0]
-        scheduler_config = self.get_scheduler_config(**config)
-        scheduler = scheduler_class(**scheduler_config)
+    def full_loop(self, scheduler=None, **config):
+        if scheduler is None:
+            scheduler_class = self.scheduler_classes[0]
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
@@ -973,6 +974,25 @@ def test_timesteps(self):
         for timesteps in [25, 50, 100, 999, 1000]:
             self.check_over_configs(num_train_timesteps=timesteps)
 
+    def test_switch(self):
+        # make sure that iterating over schedulers with same config names gives same results
+        # for defaults
+        scheduler = DPMSolverSinglestepScheduler(**self.get_scheduler_config())
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2791) < 1e-3
+
+        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
+
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2791) < 1e-3
+
     def test_thresholding(self):
         self.check_over_configs(thresholding=False)
         for order in [1, 2, 3]:
@@ -1130,10 +1150,11 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
 
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def full_loop(self, **config):
-        scheduler_class = self.scheduler_classes[0]
-        scheduler_config = self.get_scheduler_config(**config)
-        scheduler = scheduler_class(**scheduler_config)
+    def full_loop(self, scheduler=None, **config):
+        if scheduler is None:
+            scheduler_class = self.scheduler_classes[0]
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
 
         num_inference_steps = 10
         model = self.dummy_model()
@@ -1244,6 +1265,25 @@ def test_full_loop_with_v_prediction(self):
 
         assert abs(result_mean.item() - 0.2251) < 1e-3
 
+    def test_switch(self):
+        # make sure that iterating over schedulers with same config names gives same results
+        # for defaults
+        scheduler = DPMSolverMultistepScheduler(**self.get_scheduler_config())
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.3301) < 1e-3
+
+        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
+        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
+        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.3301) < 1e-3
+
     def test_fp16_support(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
@@ -2543,7 +2583,8 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
 
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def full_loop(self, **config):
-        scheduler_class = self.scheduler_classes[0]
-        scheduler_config = self.get_scheduler_config(**config)
-        scheduler = scheduler_class(**scheduler_config)
+    def full_loop(self, scheduler=None, **config):
+        if scheduler is None:
+            scheduler_class = self.scheduler_classes[0]
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
@@ -2589,6 +2630,25 @@ def test_step_shape(self):
         self.assertEqual(output_0.shape, sample.shape)
         self.assertEqual(output_0.shape, output_1.shape)
 
+    def test_switch(self):
+        # make sure that iterating over schedulers with same config names gives same results
+        # for defaults
+        scheduler = DEISMultistepScheduler(**self.get_scheduler_config())
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.23916) < 1e-3
+
+        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
+        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
+
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.23916) < 1e-3
+
     def test_timesteps(self):
         for timesteps in [25, 50, 100, 999, 1000]:
             self.check_over_configs(num_train_timesteps=timesteps)
@@ -2742,7 +2802,8 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
 
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def full_loop(self, **config):
-        scheduler_class = self.scheduler_classes[0]
-        scheduler_config = self.get_scheduler_config(**config)
-        scheduler = scheduler_class(**scheduler_config)
+    def full_loop(self, scheduler=None, **config):
+        if scheduler is None:
+            scheduler_class = self.scheduler_classes[0]
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
@@ -2788,6 +2849,25 @@ def test_step_shape(self):
         self.assertEqual(output_0.shape, sample.shape)
         self.assertEqual(output_0.shape, output_1.shape)
 
+    def test_switch(self):
+        # make sure that iterating over schedulers with same config names gives same results
+        # for defaults
+        scheduler = UniPCMultistepScheduler(**self.get_scheduler_config())
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2521) < 1e-3
+
+        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
+        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
+
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2521) < 1e-3
+
     def test_timesteps(self):
         for timesteps in [25, 50, 100, 999, 1000]:
             self.check_over_configs(num_train_timesteps=timesteps)
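
The pattern these tests pin down, rebuilding one scheduler class from another's config, is the supported way to switch among the four compatible solvers. Below is a minimal sketch of that round trip, assuming only the diffusers classes and the `from_config`/`.config` API already shown in the diff; the `num_train_timesteps` value and the loop structure are illustrative, not taken from the tests.

from diffusers import (
    DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    UniPCMultistepScheduler,
)

# Start from any of the four solvers; 1000 train timesteps is an
# illustrative choice, not a value from the tests above.
scheduler = DPMSolverMultistepScheduler(num_train_timesteps=1000)

# Each class can rebuild itself from another's config because the four
# schedulers register their shared settings under the same config names;
# entries a target class does not use are ignored by from_config.
for cls in (UniPCMultistepScheduler, DEISMultistepScheduler, DPMSolverSinglestepScheduler):
    scheduler = cls.from_config(scheduler.config)

# Round-trip back to the original class; the shared entries survive intact.
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
assert scheduler.config.num_train_timesteps == 1000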