Skip to content

Commit 940f941

Browse files
yiyixuxuyiyixuxu
andauthored
Add test_full_loop_with_noise tests to all scheduler with add_nosie function (huggingface#5184)
* add fast tests for dpm-multi * add more tests * style --------- Co-authored-by: yiyixuxu <yixu310@gmail,com>
1 parent ad06e51 commit 940f941

16 files changed

+477
-0
lines changed

tests/schedulers/test_scheduler_consistency_model.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,45 @@ def test_full_loop_no_noise_multistep(self):
115115
assert abs(result_sum.item() - 347.6357) < 1e-2
116116
assert abs(result_mean.item() - 0.4527) < 1e-3
117117

118+
def test_full_loop_with_noise(self):
119+
scheduler_class = self.scheduler_classes[0]
120+
scheduler_config = self.get_scheduler_config()
121+
scheduler = scheduler_class(**scheduler_config)
122+
123+
num_inference_steps = 10
124+
t_start = 8
125+
126+
scheduler.set_timesteps(num_inference_steps)
127+
timesteps = scheduler.timesteps
128+
129+
generator = torch.manual_seed(0)
130+
131+
model = self.dummy_model()
132+
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
133+
134+
noise = self.dummy_noise_deter
135+
timesteps = scheduler.timesteps[t_start * scheduler.order :]
136+
137+
sample = scheduler.add_noise(sample, noise, timesteps[:1])
138+
139+
for t in timesteps:
140+
# 1. scale model input
141+
scaled_sample = scheduler.scale_model_input(sample, t)
142+
143+
# 2. predict noise residual
144+
residual = model(scaled_sample, t)
145+
146+
# 3. predict previous sample x_t-1
147+
pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample
148+
149+
sample = pred_prev_sample
150+
151+
result_sum = torch.sum(torch.abs(sample))
152+
result_mean = torch.mean(torch.abs(sample))
153+
154+
assert abs(result_sum.item() - 763.9186) < 1e-2, f" expected result sum 763.9186, but get {result_sum}"
155+
assert abs(result_mean.item() - 0.9947) < 1e-3, f" expected result mean 0.9947, but get {result_mean}"
156+
118157
def test_custom_timesteps_increasing_order(self):
119158
scheduler_class = self.scheduler_classes[0]
120159
scheduler_config = self.get_scheduler_config()

tests/schedulers/test_scheduler_ddim.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,31 @@ def test_full_loop_with_no_set_alpha_to_one(self):
146146

147147
assert abs(result_sum.item() - 149.0784) < 1e-2
148148
assert abs(result_mean.item() - 0.1941) < 1e-3
149+
150+
def test_full_loop_with_noise(self):
151+
scheduler_class = self.scheduler_classes[0]
152+
scheduler_config = self.get_scheduler_config()
153+
scheduler = scheduler_class(**scheduler_config)
154+
155+
num_inference_steps, eta = 10, 0.0
156+
t_start = 8
157+
158+
model = self.dummy_model()
159+
sample = self.dummy_sample_deter
160+
161+
scheduler.set_timesteps(num_inference_steps)
162+
163+
# add noise
164+
noise = self.dummy_noise_deter
165+
timesteps = scheduler.timesteps[t_start * scheduler.order :]
166+
sample = scheduler.add_noise(sample, noise, timesteps[:1])
167+
168+
for t in timesteps:
169+
residual = model(sample, t)
170+
sample = scheduler.step(residual, t, sample, eta).prev_sample
171+
172+
result_sum = torch.sum(torch.abs(sample))
173+
result_mean = torch.mean(torch.abs(sample))
174+
175+
assert abs(result_sum.item() - 354.5418) < 1e-2, f" expected result sum 218.4379, but get {result_sum}"
176+
assert abs(result_mean.item() - 0.4616) < 1e-3, f" expected result mean 0.2844, but get {result_mean}"

tests/schedulers/test_scheduler_ddim_parallel.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,31 @@ def test_full_loop_with_no_set_alpha_to_one(self):
186186

187187
assert abs(result_sum.item() - 149.0784) < 1e-2
188188
assert abs(result_mean.item() - 0.1941) < 1e-3
189+
190+
def test_full_loop_with_noise(self):
191+
scheduler_class = self.scheduler_classes[0]
192+
scheduler_config = self.get_scheduler_config()
193+
scheduler = scheduler_class(**scheduler_config)
194+
195+
num_inference_steps, eta = 10, 0.0
196+
t_start = 8
197+
198+
model = self.dummy_model()
199+
sample = self.dummy_sample_deter
200+
201+
scheduler.set_timesteps(num_inference_steps)
202+
203+
# add noise
204+
noise = self.dummy_noise_deter
205+
timesteps = scheduler.timesteps[t_start * scheduler.order :]
206+
sample = scheduler.add_noise(sample, noise, timesteps[:1])
207+
208+
for t in timesteps:
209+
residual = model(sample, t)
210+
sample = scheduler.step(residual, t, sample, eta).prev_sample
211+
212+
result_sum = torch.sum(torch.abs(sample))
213+
result_mean = torch.mean(torch.abs(sample))
214+
215+
assert abs(result_sum.item() - 354.5418) < 1e-2, f" expected result sum 354.5418, but get {result_sum}"
216+
assert abs(result_mean.item() - 0.4616) < 1e-3, f" expected result mean 0.4616, but get {result_mean}"

tests/schedulers/test_scheduler_ddpm.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,34 @@ def test_custom_timesteps_too_large(self):
185185
msg="`timesteps` must start before `self.config.train_timesteps`: {scheduler.config.num_train_timesteps}}",
186186
):
187187
scheduler.set_timesteps(timesteps=timesteps)
188+
189+
def test_full_loop_with_noise(self):
190+
scheduler_class = self.scheduler_classes[0]
191+
scheduler_config = self.get_scheduler_config()
192+
scheduler = scheduler_class(**scheduler_config)
193+
194+
num_trained_timesteps = len(scheduler)
195+
t_start = num_trained_timesteps - 2
196+
197+
model = self.dummy_model()
198+
sample = self.dummy_sample_deter
199+
generator = torch.manual_seed(0)
200+
201+
# add noise
202+
noise = self.dummy_noise_deter
203+
timesteps = scheduler.timesteps[t_start * scheduler.order :]
204+
sample = scheduler.add_noise(sample, noise, timesteps[:1])
205+
206+
for t in timesteps:
207+
# 1. predict noise residual
208+
residual = model(sample, t)
209+
210+
# 2. predict previous mean of sample x_t-1
211+
pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample
212+
sample = pred_prev_sample
213+
214+
result_sum = torch.sum(torch.abs(sample))
215+
result_mean = torch.mean(torch.abs(sample))
216+
217+
assert abs(result_sum.item() - 387.9466) < 1e-2, f" expected result sum 387.9466, but get {result_sum}"
218+
assert abs(result_mean.item() - 0.5051) < 1e-3, f" expected result mean 0.5051, but get {result_mean}"

tests/schedulers/test_scheduler_ddpm_parallel.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,34 @@ def test_custom_timesteps_too_large(self):
214214
msg="`timesteps` must start before `self.config.train_timesteps`: {scheduler.config.num_train_timesteps}}",
215215
):
216216
scheduler.set_timesteps(timesteps=timesteps)
217+
218+
def test_full_loop_with_noise(self):
219+
scheduler_class = self.scheduler_classes[0]
220+
scheduler_config = self.get_scheduler_config()
221+
scheduler = scheduler_class(**scheduler_config)
222+
223+
num_trained_timesteps = len(scheduler)
224+
t_start = num_trained_timesteps - 2
225+
226+
model = self.dummy_model()
227+
sample = self.dummy_sample_deter
228+
generator = torch.manual_seed(0)
229+
230+
# add noise
231+
noise = self.dummy_noise_deter
232+
timesteps = scheduler.timesteps[t_start * scheduler.order :]
233+
sample = scheduler.add_noise(sample, noise, timesteps[:1])
234+
235+
for t in timesteps:
236+
# 1. predict noise residual
237+
residual = model(sample, t)
238+
239+
# 2. predict previous mean of sample x_t-1
240+
pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample
241+
sample = pred_prev_sample
242+
243+
result_sum = torch.sum(torch.abs(sample))
244+
result_mean = torch.mean(torch.abs(sample))
245+
246+
assert abs(result_sum.item() - 387.9466) < 1e-2, f" expected result sum 387.9466, but get {result_sum}"
247+
assert abs(result_mean.item() - 0.5051) < 1e-3, f" expected result mean 0.5051, but get {result_mean}"

tests/schedulers/test_scheduler_deis.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,3 +236,30 @@ def test_fp16_support(self):
236236
sample = scheduler.step(residual, t, sample).prev_sample
237237

238238
assert sample.dtype == torch.float16
239+
240+
def test_full_loop_with_noise(self):
241+
scheduler_class = self.scheduler_classes[0]
242+
scheduler_config = self.get_scheduler_config()
243+
scheduler = scheduler_class(**scheduler_config)
244+
245+
num_inference_steps = 10
246+
t_start = 8
247+
248+
model = self.dummy_model()
249+
sample = self.dummy_sample_deter
250+
scheduler.set_timesteps(num_inference_steps)
251+
252+
# add noise
253+
noise = self.dummy_noise_deter
254+
timesteps = scheduler.timesteps[t_start * scheduler.order :]
255+
sample = scheduler.add_noise(sample, noise, timesteps[:1])
256+
257+
for i, t in enumerate(timesteps):
258+
residual = model(sample, t)
259+
sample = scheduler.step(residual, t, sample).prev_sample
260+
261+
result_sum = torch.sum(torch.abs(sample))
262+
result_mean = torch.mean(torch.abs(sample))
263+
264+
assert abs(result_sum.item() - 315.3016) < 1e-2, f" expected result sum 315.3016, but get {result_sum}"
265+
assert abs(result_mean.item() - 0.41054) < 1e-3, f" expected result mean 0.41054, but get {result_mean}"

tests/schedulers/test_scheduler_dpm_multi.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,33 @@ def test_full_loop_no_noise(self):
213213

214214
assert abs(result_mean.item() - 0.3301) < 1e-3
215215

216+
def test_full_loop_with_noise(self):
217+
scheduler_class = self.scheduler_classes[0]
218+
scheduler_config = self.get_scheduler_config()
219+
scheduler = scheduler_class(**scheduler_config)
220+
221+
num_inference_steps = 10
222+
t_start = 5
223+
224+
model = self.dummy_model()
225+
sample = self.dummy_sample_deter
226+
scheduler.set_timesteps(num_inference_steps)
227+
228+
# add noise
229+
noise = self.dummy_noise_deter
230+
timesteps = scheduler.timesteps[t_start * scheduler.order :]
231+
sample = scheduler.add_noise(sample, noise, timesteps[:1])
232+
233+
for i, t in enumerate(timesteps):
234+
residual = model(sample, t)
235+
sample = scheduler.step(residual, t, sample).prev_sample
236+
237+
result_sum = torch.sum(torch.abs(sample))
238+
result_mean = torch.mean(torch.abs(sample))
239+
240+
assert abs(result_sum.item() - 318.4111) < 1e-2, f" expected result sum 318.4111, but get {result_sum}"
241+
assert abs(result_mean.item() - 0.4146) < 1e-3, f" expected result mean 0.4146, but get {result_mean}"
242+
216243
def test_full_loop_no_noise_thres(self):
217244
sample = self.full_loop(thresholding=True, dynamic_thresholding_ratio=0.87, sample_max_value=0.5)
218245
result_mean = torch.mean(torch.abs(sample))

tests/schedulers/test_scheduler_dpm_single.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,3 +279,30 @@ def test_step_shape(self):
279279

280280
self.assertEqual(output_0.shape, sample.shape)
281281
self.assertEqual(output_0.shape, output_1.shape)
282+
283+
def test_full_loop_with_noise(self):
284+
scheduler_class = self.scheduler_classes[0]
285+
scheduler_config = self.get_scheduler_config()
286+
scheduler = scheduler_class(**scheduler_config)
287+
288+
num_inference_steps = 10
289+
t_start = 5
290+
291+
model = self.dummy_model()
292+
sample = self.dummy_sample_deter
293+
scheduler.set_timesteps(num_inference_steps)
294+
295+
# add noise
296+
noise = self.dummy_noise_deter
297+
timesteps = scheduler.timesteps[t_start * scheduler.order :]
298+
sample = scheduler.add_noise(sample, noise, timesteps[:1])
299+
300+
for i, t in enumerate(timesteps):
301+
residual = model(sample, t)
302+
sample = scheduler.step(residual, t, sample).prev_sample
303+
304+
result_sum = torch.sum(torch.abs(sample))
305+
result_mean = torch.mean(torch.abs(sample))
306+
307+
assert abs(result_sum.item() - 269.2187) < 1e-2, f" expected result sum 269.2187, but get {result_sum}"
308+
assert abs(result_mean.item() - 0.3505) < 1e-3, f" expected result mean 0.3505, but get {result_mean}"

tests/schedulers/test_scheduler_euler.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,36 @@ def test_full_loop_device_karras_sigmas(self):
144144

145145
assert abs(result_sum.item() - 124.52299499511719) < 1e-2
146146
assert abs(result_mean.item() - 0.16213932633399963) < 1e-3
147+
148+
def test_full_loop_with_noise(self):
149+
scheduler_class = self.scheduler_classes[0]
150+
scheduler_config = self.get_scheduler_config()
151+
scheduler = scheduler_class(**scheduler_config)
152+
153+
scheduler.set_timesteps(self.num_inference_steps)
154+
155+
generator = torch.manual_seed(0)
156+
157+
model = self.dummy_model()
158+
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
159+
160+
# add noise
161+
t_start = self.num_inference_steps - 2
162+
noise = self.dummy_noise_deter
163+
noise = noise.to(sample.device)
164+
timesteps = scheduler.timesteps[t_start * scheduler.order :]
165+
sample = scheduler.add_noise(sample, noise, timesteps[:1])
166+
167+
for i, t in enumerate(timesteps):
168+
sample = scheduler.scale_model_input(sample, t)
169+
170+
model_output = model(sample, t)
171+
172+
output = scheduler.step(model_output, t, sample, generator=generator)
173+
sample = output.prev_sample
174+
175+
result_sum = torch.sum(torch.abs(sample))
176+
result_mean = torch.mean(torch.abs(sample))
177+
178+
assert abs(result_sum.item() - 57062.9297) < 1e-2, f" expected result sum 57062.9297, but get {result_sum}"
179+
assert abs(result_mean.item() - 74.3007) < 1e-3, f" expected result mean 74.3007, but get {result_mean}"

tests/schedulers/test_scheduler_euler_ancestral.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,37 @@ def test_full_loop_device(self):
116116

117117
assert abs(result_sum.item() - 152.3192) < 1e-2
118118
assert abs(result_mean.item() - 0.1983) < 1e-3
119+
120+
def test_full_loop_with_noise(self):
121+
scheduler_class = self.scheduler_classes[0]
122+
scheduler_config = self.get_scheduler_config()
123+
scheduler = scheduler_class(**scheduler_config)
124+
125+
t_start = self.num_inference_steps - 2
126+
127+
scheduler.set_timesteps(self.num_inference_steps)
128+
129+
generator = torch.manual_seed(0)
130+
131+
model = self.dummy_model()
132+
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
133+
134+
# add noise
135+
noise = self.dummy_noise_deter
136+
noise = noise.to(sample.device)
137+
timesteps = scheduler.timesteps[t_start * scheduler.order :]
138+
sample = scheduler.add_noise(sample, noise, timesteps[:1])
139+
140+
for i, t in enumerate(timesteps):
141+
sample = scheduler.scale_model_input(sample, t)
142+
143+
model_output = model(sample, t)
144+
145+
output = scheduler.step(model_output, t, sample, generator=generator)
146+
sample = output.prev_sample
147+
148+
result_sum = torch.sum(torch.abs(sample))
149+
result_mean = torch.mean(torch.abs(sample))
150+
151+
assert abs(result_sum.item() - 56163.0508) < 1e-2, f" expected result sum 56163.0508, but get {result_sum}"
152+
assert abs(result_mean.item() - 73.1290) < 1e-3, f" expected result mean 73.1290, but get {result_mean}"

0 commit comments

Comments
 (0)