     DDIMScheduler,
     DDPMScheduler,
     DPMSolverMultistepScheduler,
+    DPMSolverSinglestepScheduler,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
     HeunDiscreteScheduler,
@@ -870,6 +871,182 @@ def test_full_loop_with_no_set_alpha_to_one(self):
         assert abs(result_mean.item() - 0.1941) < 1e-3


+class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
+    scheduler_classes = (DPMSolverSinglestepScheduler,)
+    forward_default_kwargs = (("num_inference_steps", 25),)
+
+    def get_scheduler_config(self, **kwargs):
+        config = {
+            "num_train_timesteps": 1000,
+            "beta_start": 0.0001,
+            "beta_end": 0.02,
+            "beta_schedule": "linear",
+            "solver_order": 2,
+            "prediction_type": "epsilon",
+            "thresholding": False,
+            "sample_max_value": 1.0,
+            "algorithm_type": "dpmsolver++",
+            "solver_type": "midpoint",
+        }
+
+        config.update(**kwargs)
+        return config
+
+    def check_over_configs(self, time_step=0, **config):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+        sample = self.dummy_sample
+        residual = 0.1 * sample
+        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+            scheduler.set_timesteps(num_inference_steps)
+            # copy over dummy past residuals
+            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_config(tmpdirname)
+                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
+                new_scheduler.set_timesteps(num_inference_steps)
+                # copy over dummy past residuals
+                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
+
+            output, new_output = sample, sample
+            for t in range(time_step, time_step + scheduler.config.solver_order + 1):
+                output = scheduler.step(residual, t, output, **kwargs).prev_sample
+                new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
+
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def test_from_save_pretrained(self):
+        pass
+
+    def check_over_forward(self, time_step=0, **forward_kwargs):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+        sample = self.dummy_sample
+        residual = 0.1 * sample
+        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+            scheduler.set_timesteps(num_inference_steps)
+
+            # copy over dummy past residuals (must be after setting timesteps)
+            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_config(tmpdirname)
+                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
+                # timesteps must be set on the reloaded scheduler first
+                new_scheduler.set_timesteps(num_inference_steps)
+
+                # copy over dummy past residuals (must be after setting timesteps)
+                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
+
+            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
+            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
+
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def full_loop(self, **config):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(**config)
+        scheduler = scheduler_class(**scheduler_config)
+
+        num_inference_steps = 10
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter
+        scheduler.set_timesteps(num_inference_steps)
+
+        for t in scheduler.timesteps:
+            residual = model(sample, t)
+            sample = scheduler.step(residual, t, sample).prev_sample
+
+        return sample
+
+    def test_timesteps(self):
+        for timesteps in [25, 50, 100, 999, 1000]:
+            self.check_over_configs(num_train_timesteps=timesteps)
+
+    def test_thresholding(self):
+        self.check_over_configs(thresholding=False)
+        for order in [1, 2, 3]:
+            for solver_type in ["midpoint", "heun"]:
+                for threshold in [0.5, 1.0, 2.0]:
+                    for prediction_type in ["epsilon", "sample"]:
+                        self.check_over_configs(
+                            thresholding=True,
+                            prediction_type=prediction_type,
+                            sample_max_value=threshold,
+                            algorithm_type="dpmsolver++",
+                            solver_order=order,
+                            solver_type=solver_type,
+                        )
+
+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "v_prediction"]:
+            self.check_over_configs(prediction_type=prediction_type)
+
+    def test_solver_order_and_type(self):
+        for algorithm_type in ["dpmsolver", "dpmsolver++"]:
+            for solver_type in ["midpoint", "heun"]:
+                for order in [1, 2, 3]:
+                    for prediction_type in ["epsilon", "sample"]:
+                        self.check_over_configs(
+                            solver_order=order,
+                            solver_type=solver_type,
+                            prediction_type=prediction_type,
+                            algorithm_type=algorithm_type,
+                        )
+                        sample = self.full_loop(
+                            solver_order=order,
+                            solver_type=solver_type,
+                            prediction_type=prediction_type,
+                            algorithm_type=algorithm_type,
+                        )
+                        assert not torch.isnan(sample).any(), "Samples contain NaNs"
+
+    def test_lower_order_final(self):
+        self.check_over_configs(lower_order_final=True)
+        self.check_over_configs(lower_order_final=False)
+
+    def test_inference_steps(self):
+        for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
+            self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)
+
+    def test_full_loop_no_noise(self):
+        sample = self.full_loop()
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2791) < 1e-3
+
+    def test_full_loop_with_v_prediction(self):
+        sample = self.full_loop(prediction_type="v_prediction")
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.1453) < 1e-3
+
+    def test_fp16_support(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
+        scheduler = scheduler_class(**scheduler_config)
+
+        num_inference_steps = 10
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter.half()
+        scheduler.set_timesteps(num_inference_steps)
+
+        for t in scheduler.timesteps:
+            residual = model(sample, t)
+            sample = scheduler.step(residual, t, sample).prev_sample
+
+        assert sample.dtype == torch.float16
+
+
 class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (DPMSolverMultistepScheduler,)
     forward_default_kwargs = (("num_inference_steps", 25),)
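
For readers unfamiliar with the scheduler API these tests exercise, the sketch below mirrors the full_loop helper as a standalone script. It is a minimal illustration, not part of the PR: the zero tensor stands in for a real denoising model's epsilon prediction, and the config values are simply the test defaults from get_scheduler_config.

import torch
from diffusers import DPMSolverSinglestepScheduler

# Same defaults as get_scheduler_config() in the new test class.
scheduler = DPMSolverSinglestepScheduler(
    num_train_timesteps=1000,
    beta_start=0.0001,
    beta_end=0.02,
    beta_schedule="linear",
    solver_order=2,
    algorithm_type="dpmsolver++",
    solver_type="midpoint",
)

scheduler.set_timesteps(10)       # 10 inference steps, as in full_loop
sample = torch.randn(1, 3, 8, 8)  # stand-in for an initial noise latent

for t in scheduler.timesteps:
    residual = torch.zeros_like(sample)  # stand-in for a model's noise prediction
    sample = scheduler.step(residual, t, sample).prev_sample

In practice the scheduler is more commonly swapped into an existing pipeline via the ConfigMixin API, e.g. pipe.scheduler = DPMSolverSinglestepScheduler.from_config(pipe.scheduler.config).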