Skip to content

Commit 7547f9b

Browse files
authored
Fix timestep dtype in legacy inpaint (huggingface#2120)
* Fix timestep dtype in legacy inpaint This matches the structure in the text2img, img2img, and inpaint ONNX pipelines * Fix style in dtype patch
1 parent a87e87f commit 7547f9b

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ...configuration_utils import FrozenDict
1111
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
1212
from ...utils import deprecate, logging
13-
from ..onnx_utils import OnnxRuntimeModel
13+
from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
1414
from ..pipeline_utils import DiffusionPipeline
1515
from . import StableDiffusionPipelineOutput
1616

@@ -391,16 +391,21 @@ def __call__(
391391

392392
t_start = max(num_inference_steps - init_timestep + offset, 0)
393393
timesteps = self.scheduler.timesteps[t_start:].numpy()
394+
timestep_dtype = next(
395+
(input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
396+
)
397+
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
394398

395399
for i, t in enumerate(self.progress_bar(timesteps)):
396400
# expand the latents if we are doing classifier free guidance
397401
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
398402
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
399403

400404
# predict the noise residual
401-
noise_pred = self.unet(
402-
sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=prompt_embeds
403-
)[0]
405+
timestep = np.array([t], dtype=timestep_dtype)
406+
noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[
407+
0
408+
]
404409

405410
# perform guidance
406411
if do_classifier_free_guidance:

0 commit comments

Comments
 (0)