Skip to content

Commit 8e74efa

Browse files
Add Singlestep DPM-Solver (singlestep high-order schedulers) (huggingface#1442)
* add singlestep dpmsolver * fix a style typo * fix a style typo * add docs * finish Co-authored-by: Patrick von Platen <[email protected]>
1 parent 6a7f1f0 commit 8e74efa

File tree

7 files changed

+800
-0
lines changed

7 files changed

+800
-0
lines changed

docs/source/api/schedulers.mdx

+6
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ Original paper can be found [here](https://arxiv.org/abs/2010.02502).
7070

7171
[[autodoc]] DDPMScheduler
7272

73+
#### Singlestep DPM-Solver
74+
75+
Original paper can be found [here](https://arxiv.org/abs/2206.00927) and the [improved version](https://arxiv.org/abs/2211.01095). The original implementation can be found [here](https://github.com/LuChengTHU/dpm-solver).
76+
77+
[[autodoc]] DPMSolverSinglestepScheduler
78+
7379
#### Multistep DPM-Solver
7480

7581
Original paper can be found [here](https://arxiv.org/abs/2206.00927) and the [improved version](https://arxiv.org/abs/2211.01095). The original implementation can be found [here](https://github.com/LuChengTHU/dpm-solver).

src/diffusers/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
DDIMScheduler,
4545
DDPMScheduler,
4646
DPMSolverMultistepScheduler,
47+
DPMSolverSinglestepScheduler,
4748
EulerAncestralDiscreteScheduler,
4849
EulerDiscreteScheduler,
4950
HeunDiscreteScheduler,

src/diffusers/schedulers/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from .scheduling_ddim import DDIMScheduler
2121
from .scheduling_ddpm import DDPMScheduler
2222
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
23+
from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
2324
from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
2425
from .scheduling_euler_discrete import EulerDiscreteScheduler
2526
from .scheduling_heun_discrete import HeunDiscreteScheduler

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

+599
Large diffs are not rendered by default.

src/diffusers/utils/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
"HeunDiscreteScheduler",
9191
"EulerAncestralDiscreteScheduler",
9292
"DPMSolverMultistepScheduler",
93+
"DPMSolverSinglestepScheduler",
9394
]
9495

9596

src/diffusers/utils/dummy_pt_objects.py

+15
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,21 @@ def from_pretrained(cls, *args, **kwargs):
362362
requires_backends(cls, ["torch"])
363363

364364

365+
class DPMSolverSinglestepScheduler(metaclass=DummyObject):
    # Dummy stand-in used when `torch` is not installed: any attempt to
    # construct or load this scheduler raises an informative backend error
    # instead of an ImportError at import time.
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
378+
379+
365380
class EulerAncestralDiscreteScheduler(metaclass=DummyObject):
366381
_backends = ["torch"]
367382

tests/test_scheduler.py

+177
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
DDIMScheduler,
2929
DDPMScheduler,
3030
DPMSolverMultistepScheduler,
31+
DPMSolverSinglestepScheduler,
3132
EulerAncestralDiscreteScheduler,
3233
EulerDiscreteScheduler,
3334
HeunDiscreteScheduler,
@@ -870,6 +871,182 @@ def test_full_loop_with_no_set_alpha_to_one(self):
870871
assert abs(result_mean.item() - 0.1941) < 1e-3
871872

872873

874+
class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
    """Tests for `DPMSolverSinglestepScheduler`.

    Covers config save/reload round-trips, all solver orders and types,
    dynamic thresholding, prediction types, and full denoising loops
    (including fp16 and v-prediction).
    """

    scheduler_classes = (DPMSolverSinglestepScheduler,)
    forward_default_kwargs = (("num_inference_steps", 25),)

    def get_scheduler_config(self, **kwargs):
        """Return the default scheduler config, with overrides from **kwargs."""
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "solver_order": 2,
            "prediction_type": "epsilon",
            "thresholding": False,
            "sample_max_value": 1.0,
            "algorithm_type": "dpmsolver++",
            "solver_type": "midpoint",
        }

        config.update(**kwargs)
        return config

    def check_over_configs(self, time_step=0, **config):
        """Assert a scheduler built with `config` overrides produces identical
        step outputs after a save_config/from_pretrained round-trip."""
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(num_inference_steps)
            # copy over dummy past residuals (must be after setting timesteps)
            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
                new_scheduler.set_timesteps(num_inference_steps)
                # copy over dummy past residuals (must be after setting timesteps)
                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]

            output, new_output = sample, sample
            # Step through solver_order + 1 timesteps so every single-step
            # branch (1st/2nd/3rd order update) is exercised at least once.
            for t in range(time_step, time_step + scheduler.config.solver_order + 1):
                output = scheduler.step(residual, t, output, **kwargs).prev_sample
                new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_from_save_pretrained(self):
        # Save/reload equivalence is already covered (with the required
        # model_outputs bookkeeping) by check_over_configs/check_over_forward.
        pass

    def check_over_forward(self, time_step=0, **forward_kwargs):
        """Assert step outputs survive a save/reload round-trip for the given
        forward kwargs (e.g. a specific num_inference_steps)."""
        kwargs = dict(self.forward_default_kwargs)
        # BUGFIX: forward_kwargs was previously dropped on the floor, so e.g.
        # test_inference_steps always ran with the default of 25 steps no
        # matter what value it passed in.
        kwargs.update(forward_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(num_inference_steps)

            # copy over dummy past residuals (must be after setting timesteps)
            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
                # copy over dummy past residuals
                new_scheduler.set_timesteps(num_inference_steps)

                # copy over dummy past residual (must be after setting timesteps)
                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]

            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def full_loop(self, **config):
        """Run a full 10-step denoising loop with the dummy model and return
        the final sample tensor."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(**config)
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter
        scheduler.set_timesteps(num_inference_steps)

        for i, t in enumerate(scheduler.timesteps):
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample

        return sample

    def test_timesteps(self):
        for timesteps in [25, 50, 100, 999, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_thresholding(self):
        self.check_over_configs(thresholding=False)
        # Thresholding is only defined for dpmsolver++ with epsilon/sample
        # prediction; sweep orders, solver types, and clamp thresholds.
        for order in [1, 2, 3]:
            for solver_type in ["midpoint", "heun"]:
                for threshold in [0.5, 1.0, 2.0]:
                    for prediction_type in ["epsilon", "sample"]:
                        self.check_over_configs(
                            thresholding=True,
                            prediction_type=prediction_type,
                            sample_max_value=threshold,
                            algorithm_type="dpmsolver++",
                            solver_order=order,
                            solver_type=solver_type,
                        )

    def test_prediction_type(self):
        for prediction_type in ["epsilon", "v_prediction"]:
            self.check_over_configs(prediction_type=prediction_type)

    def test_solver_order_and_type(self):
        # Every combination must both round-trip and produce finite samples.
        for algorithm_type in ["dpmsolver", "dpmsolver++"]:
            for solver_type in ["midpoint", "heun"]:
                for order in [1, 2, 3]:
                    for prediction_type in ["epsilon", "sample"]:
                        self.check_over_configs(
                            solver_order=order,
                            solver_type=solver_type,
                            prediction_type=prediction_type,
                            algorithm_type=algorithm_type,
                        )
                        sample = self.full_loop(
                            solver_order=order,
                            solver_type=solver_type,
                            prediction_type=prediction_type,
                            algorithm_type=algorithm_type,
                        )
                        assert not torch.isnan(sample).any(), "Samples have nan numbers"

    def test_lower_order_final(self):
        self.check_over_configs(lower_order_final=True)
        self.check_over_configs(lower_order_final=False)

    def test_inference_steps(self):
        for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
            self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)

    def test_full_loop_no_noise(self):
        sample = self.full_loop()
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.2791) < 1e-3

    def test_full_loop_with_v_prediction(self):
        sample = self.full_loop(prediction_type="v_prediction")
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.1453) < 1e-3

    def test_fp16_support(self):
        """The scheduler must keep half-precision samples in float16 end to
        end (dynamic_thresholding_ratio=0 keeps thresholding numerically
        trivial so fp16 does not overflow)."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter.half()
        scheduler.set_timesteps(num_inference_steps)

        for i, t in enumerate(scheduler.timesteps):
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample

        assert sample.dtype == torch.float16
1048+
1049+
8731050
class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
8741051
scheduler_classes = (DPMSolverMultistepScheduler,)
8751052
forward_default_kwargs = (("num_inference_steps", 25),)

0 commit comments

Comments
 (0)