
Commit 5d550cf

Make sure that DEIS, DPM and UniPC can correctly be switched in & out (huggingface#2595)

* [Schedulers] Correct config changing
* uP
* add tests

1 parent: 24d624a

5 files changed (+107, −14 lines)
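diffusers schedulers are designed to be swapped in and out of a pipeline through their shared config, which is what this commit makes reliable for DEIS, DPM-Solver and UniPC. A minimal sketch of the switching pattern, assuming an illustrative model id:

# Sketch of scheduler switching via the shared config; the model id below is
# an illustrative assumption, any diffusers pipeline works the same way.
from diffusers import (
    DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    DiffusionPipeline,
    UniPCMultistepScheduler,
)

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Each swap re-reads the previous scheduler's config. Before this commit, a
# scheduler handed a sibling's setting (e.g. algorithm_type="dpmsolver++"
# arriving at DEIS) converted it only in a local variable, so the unsupported
# value stayed in .config and resurfaced on the next swap.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)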

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 2 additions & 2 deletions
@@ -154,13 +154,13 @@ def __init__(
         # settings for DEIS
         if algorithm_type not in ["deis"]:
             if algorithm_type in ["dpmsolver", "dpmsolver++"]:
-                algorithm_type = "deis"
+                self.register_to_config(algorithm_type="deis")
             else:
                 raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")

         if solver_type not in ["logrho"]:
             if solver_type in ["midpoint", "heun", "bh1", "bh2"]:
-                solver_type = "logrho"
+                self.register_to_config(solver_type="logrho")
             else:
                 raise NotImplementedError(f"solver type {solver_type} does is not implemented for {self.__class__}")
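The one-line change above is the heart of the fix. Constructor arguments are captured into self.config by the @register_to_config decorator, so reassigning the local algorithm_type variable changed the value __init__ worked with but left the incompatible one in the config, where the next from_config call picked it up again. Calling self.register_to_config(algorithm_type="deis") persists the conversion. A simplified sketch of the difference, using a toy config store rather than the real ConfigMixin:

# Toy sketch of the bug; this is not the real diffusers ConfigMixin.
class SketchScheduler:
    def __init__(self, algorithm_type="dpmsolver++"):
        # Stand-in for what @register_to_config does with the raw argument:
        self.config = {"algorithm_type": algorithm_type}

        if algorithm_type not in ["deis"]:
            # Old behavior: only the local name changes, so self.config still
            # holds "dpmsolver++" and leaks into the next from_config call.
            algorithm_type = "deis"
            # New behavior: write the converted value back into the config.
            self.register_to_config(algorithm_type="deis")

    def register_to_config(self, **kwargs):
        self.config.update(kwargs)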

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 3 additions & 2 deletions
@@ -165,12 +165,13 @@ def __init__(
         # settings for DPM-Solver
         if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
             if algorithm_type == "deis":
-                algorithm_type = "dpmsolver++"
+                self.register_to_config(algorithm_type="dpmsolver++")
             else:
                 raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
+
         if solver_type not in ["midpoint", "heun"]:
             if solver_type in ["logrho", "bh1", "bh2"]:
-                solver_type = "midpoint"
+                self.register_to_config(solver_type="midpoint")
             else:
                 raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 2 additions & 2 deletions
@@ -164,12 +164,12 @@ def __init__(
         # settings for DPM-Solver
         if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
             if algorithm_type == "deis":
-                algorithm_type = "dpmsolver++"
+                self.register_to_config(algorithm_type="dpmsolver++")
             else:
                 raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
         if solver_type not in ["midpoint", "heun"]:
             if solver_type in ["logrho", "bh1", "bh2"]:
-                solver_type = "midpoint"
+                self.register_to_config(solver_type="midpoint")
             else:
                 raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 1 addition & 1 deletion
@@ -168,7 +168,7 @@ def __init__(

         if solver_type not in ["bh1", "bh2"]:
             if solver_type in ["midpoint", "heun", "logrho"]:
-                solver_type = "bh1"
+                self.register_to_config(solver_type="bh1")
             else:
                 raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
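Taken together, the four patches make each scheduler coerce a sibling's settings to values it supports and store that coercion, so chaining from_config across scheduler classes becomes stable. A sketch of the round-trip behavior the diffs above imply:

# Round-trip sketch implied by the patches above.
from diffusers import DEISMultistepScheduler, UniPCMultistepScheduler

deis = DEISMultistepScheduler()  # solver_type defaults to "logrho"
unipc = UniPCMultistepScheduler.from_config(deis.config)
assert unipc.config.solver_type == "bh1"  # "logrho" coerced and persisted
deis_again = DEISMultistepScheduler.from_config(unipc.config)
assert deis_again.config.solver_type == "logrho"  # "bh1" coerced back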

tests/test_scheduler.py

Lines changed: 99 additions & 7 deletions
@@ -953,7 +953,12 @@ def check_over_forward(self, time_step=0, **forward_kwargs):

         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

-    def full_loop(self, **config):
+    def full_loop(self, scheduler=None, **config):
+        if scheduler is None:
+            scheduler_class = self.scheduler_classes[0]
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config(**config)
         scheduler = scheduler_class(**scheduler_config)

@@ -973,6 +978,25 @@ def test_timesteps(self):
         for timesteps in [25, 50, 100, 999, 1000]:
             self.check_over_configs(num_train_timesteps=timesteps)

+    def test_switch(self):
+        # make sure that iterating over schedulers with same config names gives same results
+        # for defaults
+        scheduler = DPMSolverSinglestepScheduler(**self.get_scheduler_config())
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2791) < 1e-3
+
+        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
+
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2791) < 1e-3
+
     def test_thresholding(self):
         self.check_over_configs(thresholding=False)
         for order in [1, 2, 3]:

@@ -1130,10 +1154,11 @@ def check_over_forward(self, time_step=0, **forward_kwargs):

         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

-    def full_loop(self, **config):
-        scheduler_class = self.scheduler_classes[0]
-        scheduler_config = self.get_scheduler_config(**config)
-        scheduler = scheduler_class(**scheduler_config)
+    def full_loop(self, scheduler=None, **config):
+        if scheduler is None:
+            scheduler_class = self.scheduler_classes[0]
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)

         num_inference_steps = 10
         model = self.dummy_model()

@@ -1244,6 +1269,25 @@ def test_full_loop_with_v_prediction(self):

         assert abs(result_mean.item() - 0.2251) < 1e-3

+    def test_switch(self):
+        # make sure that iterating over schedulers with same config names gives same results
+        # for defaults
+        scheduler = DPMSolverMultistepScheduler(**self.get_scheduler_config())
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.3301) < 1e-3
+
+        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
+        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
+        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.3301) < 1e-3
+
     def test_fp16_support(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)

@@ -2543,7 +2587,12 @@ def check_over_forward(self, time_step=0, **forward_kwargs):

         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

-    def full_loop(self, **config):
+    def full_loop(self, scheduler=None, **config):
+        if scheduler is None:
+            scheduler_class = self.scheduler_classes[0]
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config(**config)
         scheduler = scheduler_class(**scheduler_config)

@@ -2589,6 +2638,25 @@ def test_step_shape(self):
         self.assertEqual(output_0.shape, sample.shape)
         self.assertEqual(output_0.shape, output_1.shape)

+    def test_switch(self):
+        # make sure that iterating over schedulers with same config names gives same results
+        # for defaults
+        scheduler = DEISMultistepScheduler(**self.get_scheduler_config())
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.23916) < 1e-3
+
+        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
+        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
+
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.23916) < 1e-3
+
     def test_timesteps(self):
         for timesteps in [25, 50, 100, 999, 1000]:
             self.check_over_configs(num_train_timesteps=timesteps)

@@ -2742,7 +2810,12 @@ def check_over_forward(self, time_step=0, **forward_kwargs):

         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

-    def full_loop(self, **config):
+    def full_loop(self, scheduler=None, **config):
+        if scheduler is None:
+            scheduler_class = self.scheduler_classes[0]
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config(**config)
         scheduler = scheduler_class(**scheduler_config)

@@ -2788,6 +2861,25 @@ def test_step_shape(self):
         self.assertEqual(output_0.shape, sample.shape)
         self.assertEqual(output_0.shape, output_1.shape)

+    def test_switch(self):
+        # make sure that iterating over schedulers with same config names gives same results
+        # for defaults
+        scheduler = UniPCMultistepScheduler(**self.get_scheduler_config())
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2521) < 1e-3
+
+        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
+        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
+
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2521) < 1e-3
+
     def test_timesteps(self):
         for timesteps in [25, 50, 100, 999, 1000]:
             self.check_over_configs(num_train_timesteps=timesteps)
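Each new test_switch case runs a full denoising loop with one scheduler, cycles the config through all four scheduler classes via from_config, runs the loop again, and asserts that the mean of the resulting sample is unchanged. The cases can be run in isolation with, e.g., pytest tests/test_scheduler.py -k test_switch.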
