Skip to content

Commit e05ca84

Browse files
authored
[ONNX] Support Euler schedulers (huggingface#1328)
1 parent 632dace commit e05ca84

File tree

3 files changed

+12
-6
lines changed

3 files changed

+12
-6
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,10 @@ def __call__(
261261
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
262262

263263
# compute the previous noisy sample x_t -> x_t-1
264-
latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
265-
latents = np.array(latents)
264+
scheduler_output = self.scheduler.step(
265+
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
266+
)
267+
latents = scheduler_output.prev_sample.numpy()
266268

267269
# call the callback, if provided
268270
if callback is not None and i % callback_steps == 0:

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -401,8 +401,10 @@ def __call__(
401401
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
402402

403403
# compute the previous noisy sample x_t -> x_t-1
404-
latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
405-
latents = latents.numpy()
404+
scheduler_output = self.scheduler.step(
405+
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
406+
)
407+
latents = scheduler_output.prev_sample.numpy()
406408

407409
# call the callback, if provided
408410
if callback is not None and i % callback_steps == 0:

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -424,8 +424,10 @@ def __call__(
424424
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
425425

426426
# compute the previous noisy sample x_t -> x_t-1
427-
latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
428-
latents = latents.numpy()
427+
scheduler_output = self.scheduler.step(
428+
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
429+
)
430+
latents = scheduler_output.prev_sample.numpy()
429431

430432
# call the callback, if provided
431433
if callback is not None and i % callback_steps == 0:

0 commit comments

Comments
 (0)