
Commit c8b0f0e

Update UniPC to support 1D diffusion. (huggingface#5199)

* Update UniPC einsum to support 1D and 3D diffusion.
* Add unittest
* Update unittest & edge case
* Fix unittest
* Fix testing_utils.py
* Fix unittest file

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 7a4324c commit c8b0f0e

File tree: 2 files changed, +117 −7 lines

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 7 additions & 7 deletions
@@ -282,13 +282,13 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
         https://arxiv.org/abs/2205.11487
         """
         dtype = sample.dtype
-        batch_size, channels, height, width = sample.shape
+        batch_size, channels, *remaining_dims = sample.shape

         if dtype not in (torch.float32, torch.float64):
             sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

         # Flatten sample for doing quantile calculation along each image
-        sample = sample.reshape(batch_size, channels * height * width)
+        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))

         abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

@@ -300,7 +300,7 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
         s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
         sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

-        sample = sample.reshape(batch_size, channels, height, width)
+        sample = sample.reshape(batch_size, channels, *remaining_dims)
         sample = sample.to(dtype)

         return sample
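For reference, a minimal standalone sketch of the flatten-and-restore pattern used above (the helper name is made up for illustration; this is not the scheduler method itself): unpacking (batch, channels, *remaining_dims) and flattening with np.prod lets the same quantile/clamp logic handle 1D, 2D, and 3D samples.

    import numpy as np
    import torch

    def flatten_and_restore(sample: torch.Tensor) -> torch.Tensor:
        # Capture however many spatial dims follow (batch, channels).
        batch_size, channels, *remaining_dims = sample.shape
        # Flatten everything after the batch dim into one axis.
        flat = sample.reshape(batch_size, channels * np.prod(remaining_dims))
        # ... per-sample quantile and clamp would operate on `flat` here ...
        # Restore the original layout.
        return flat.reshape(batch_size, channels, *remaining_dims)

    # Shape is preserved for 1D, 2D, and 3D samples alike.
    for shape in [(4, 3, 8), (4, 3, 8, 8), (4, 3, 4, 8, 8)]:
        x = torch.rand(shape)
        assert flatten_and_restore(x).shape == x.shape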
@@ -534,14 +534,14 @@ def multistep_uni_p_bh_update(
         if self.predict_x0:
             x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
             if D1s is not None:
-                pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
+                pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
             else:
                 pred_res = 0
             x_t = x_t_ - alpha_t * B_h * pred_res
         else:
             x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
             if D1s is not None:
-                pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
+                pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
             else:
                 pred_res = 0
             x_t = x_t_ - sigma_t * B_h * pred_res

@@ -670,15 +670,15 @@ def multistep_uni_c_bh_update(
         if self.predict_x0:
             x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
             if D1s is not None:
-                corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
+                corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
             else:
                 corr_res = 0
             D1_t = model_t - m0
             x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
         else:
             x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
             if D1s is not None:
-                corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
+                corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
             else:
                 corr_res = 0
             D1_t = model_t - m0

tests/schedulers/test_scheduler_unipc.py

Lines changed: 110 additions & 0 deletions
@@ -269,3 +269,113 @@ def test_full_loop_with_noise(self):

         assert abs(result_sum.item() - 315.5757) < 1e-2, f" expected result sum 315.5757, but get {result_sum}"
         assert abs(result_mean.item() - 0.4109) < 1e-3, f" expected result mean 0.4109, but get {result_mean}"
+
+
+class UniPCMultistepScheduler1DTest(UniPCMultistepSchedulerTest):
+    @property
+    def dummy_sample(self):
+        batch_size = 4
+        num_channels = 3
+        width = 8
+
+        sample = torch.rand((batch_size, num_channels, width))
+
+        return sample
+
+    @property
+    def dummy_noise_deter(self):
+        batch_size = 4
+        num_channels = 3
+        width = 8
+
+        num_elems = batch_size * num_channels * width
+        sample = torch.arange(num_elems).flip(-1)
+        sample = sample.reshape(num_channels, width, batch_size)
+        sample = sample / num_elems
+        sample = sample.permute(2, 0, 1)
+
+        return sample
+
+    @property
+    def dummy_sample_deter(self):
+        batch_size = 4
+        num_channels = 3
+        width = 8
+
+        num_elems = batch_size * num_channels * width
+        sample = torch.arange(num_elems)
+        sample = sample.reshape(num_channels, width, batch_size)
+        sample = sample / num_elems
+        sample = sample.permute(2, 0, 1)
+
+        return sample
+
+    def test_switch(self):
+        # make sure that iterating over schedulers with same config names gives same results
+        # for defaults
+        scheduler = UniPCMultistepScheduler(**self.get_scheduler_config())
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2441) < 1e-3
+
+        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
+        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
+
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2441) < 1e-3
+
+    def test_full_loop_no_noise(self):
+        sample = self.full_loop()
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2441) < 1e-3
+
+    def test_full_loop_with_karras(self):
+        sample = self.full_loop(use_karras_sigmas=True)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2898) < 1e-3
+
+    def test_full_loop_with_v_prediction(self):
+        sample = self.full_loop(prediction_type="v_prediction")
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.1014) < 1e-3
+
+    def test_full_loop_with_karras_and_v_prediction(self):
+        sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.1944) < 1e-3
+
+    def test_full_loop_with_noise(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        num_inference_steps = 10
+        t_start = 8
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter
+        scheduler.set_timesteps(num_inference_steps)
+
+        # add noise
+        noise = self.dummy_noise_deter
+        timesteps = scheduler.timesteps[t_start * scheduler.order :]
+        sample = scheduler.add_noise(sample, noise, timesteps[:1])
+
+        for i, t in enumerate(timesteps):
+            residual = model(sample, t)
+            sample = scheduler.step(residual, t, sample).prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 39.0870) < 1e-2, f" expected result sum 39.0870, but get {result_sum}"
+        assert abs(result_mean.item() - 0.4072) < 1e-3, f" expected result mean 0.4072, but get {result_mean}"
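Outside the test suite, a quick hedged smoke check (a sketch assuming the public diffusers API: UniPCMultistepScheduler, set_timesteps, and step(...).prev_sample): with this change, a default-configured scheduler should step a (batch, channels, width) sample without tripping over the old 4D-only einsum.

    import torch
    from diffusers import UniPCMultistepScheduler

    scheduler = UniPCMultistepScheduler()          # default config
    scheduler.set_timesteps(num_inference_steps=10)

    sample = torch.rand(4, 3, 8)                   # 1D "signal" batch: (batch, channels, width)
    for t in scheduler.timesteps:
        model_output = torch.rand_like(sample)     # stand-in for a real model call
        sample = scheduler.step(model_output, t, sample).prev_sample

    # The multistep updates preserve the 1D sample shape.
    assert sample.shape == (4, 3, 8)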
