
Commit 11f7d6f

[ONNX] Improve ONNXPipeline scheduler compatibility, fix safety_checker (huggingface#1173)

* [ONNX] Improve ONNX scheduler compatibility, fix safety_checker
* typo

1 parent 555203e

10 files changed (+346 -89 lines)

scripts/convert_stable_diffusion_checkpoint_to_onnx.py (+49 -24)

```diff
@@ -81,6 +81,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
     output_path = Path(output_path)
 
     # TEXT ENCODER
+    num_tokens = pipeline.text_encoder.config.max_position_embeddings
+    text_hidden_size = pipeline.text_encoder.config.hidden_size
     text_input = pipeline.tokenizer(
         "A sample prompt",
         padding="max_length",
```

```diff
@@ -103,13 +105,15 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
     del pipeline.text_encoder
 
     # UNET
+    unet_in_channels = pipeline.unet.config.in_channels
+    unet_sample_size = pipeline.unet.config.sample_size
     unet_path = output_path / "unet" / "model.onnx"
     onnx_export(
         pipeline.unet,
         model_args=(
-            torch.randn(2, pipeline.unet.in_channels, 64, 64).to(device=device, dtype=dtype),
-            torch.LongTensor([0, 1]).to(device=device),
-            torch.randn(2, 77, 768).to(device=device, dtype=dtype),
+            torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype),
+            torch.randn(2).to(device=device, dtype=dtype),
+            torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype),
             False,
         ),
         output_path=unet_path,
```
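
Besides replacing the hardcoded shapes (64, 77, 768) with config-driven values, this hunk changes the dummy timestep from `torch.LongTensor([0, 1])` to `torch.randn(2)`. Since `torch.onnx.export` traces the model with the sample inputs, the exported graph's `timestep` input inherits their dtype, and a float input is what continuous-sigma schedulers such as LMSDiscreteScheduler feed at inference time. A minimal sketch of this behavior (the toy module and file name are illustrative, not part of the commit):

```python
import torch

class TinyUNet(torch.nn.Module):
    def forward(self, sample, timestep):
        # broadcast the per-batch timestep over the sample
        return sample + timestep.reshape(-1, 1, 1, 1)

# The traced "timestep" input becomes tensor(float) because the dummy is float.
torch.onnx.export(
    TinyUNet(),
    (torch.randn(2, 4, 64, 64), torch.randn(2)),
    "tiny_unet.onnx",
    input_names=["sample", "timestep"],
    opset_version=14,
)
```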

```diff
@@ -142,11 +146,16 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
 
     # VAE ENCODER
     vae_encoder = pipeline.vae
+    vae_in_channels = vae_encoder.config.in_channels
+    vae_sample_size = vae_encoder.config.sample_size
     # need to get the raw tensor output (sample) from the encoder
     vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample()
     onnx_export(
         vae_encoder,
-        model_args=(torch.randn(1, 3, 512, 512).to(device=device, dtype=dtype), False),
+        model_args=(
+            torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(device=device, dtype=dtype),
+            False,
+        ),
         output_path=output_path / "vae_encoder" / "model.onnx",
         ordered_input_names=["sample", "return_dict"],
         output_names=["latent_sample"],
```

```diff
@@ -158,11 +167,16 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
 
     # VAE DECODER
     vae_decoder = pipeline.vae
+    vae_latent_channels = vae_decoder.config.latent_channels
+    vae_out_channels = vae_decoder.config.out_channels
     # forward only through the decoder part
     vae_decoder.forward = vae_encoder.decode
     onnx_export(
         vae_decoder,
-        model_args=(torch.randn(1, 4, 64, 64).to(device=device, dtype=dtype), False),
+        model_args=(
+            torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype),
+            False,
+        ),
         output_path=output_path / "vae_decoder" / "model.onnx",
         ordered_input_names=["latent_sample", "return_dict"],
         output_names=["sample"],
```

```diff
@@ -174,24 +188,35 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
     del pipeline.vae
 
     # SAFETY CHECKER
-    safety_checker = pipeline.safety_checker
-    safety_checker.forward = safety_checker.forward_onnx
-    onnx_export(
-        pipeline.safety_checker,
-        model_args=(
-            torch.randn(1, 3, 224, 224).to(device=device, dtype=dtype),
-            torch.randn(1, 512, 512, 3).to(device=device, dtype=dtype),
-        ),
-        output_path=output_path / "safety_checker" / "model.onnx",
-        ordered_input_names=["clip_input", "images"],
-        output_names=["out_images", "has_nsfw_concepts"],
-        dynamic_axes={
-            "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
-            "images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
-        },
-        opset=opset,
-    )
-    del pipeline.safety_checker
+    if pipeline.safety_checker is not None:
+        safety_checker = pipeline.safety_checker
+        clip_num_channels = safety_checker.config.vision_config.num_channels
+        clip_image_size = safety_checker.config.vision_config.image_size
+        safety_checker.forward = safety_checker.forward_onnx
+        onnx_export(
+            pipeline.safety_checker,
+            model_args=(
+                torch.randn(
+                    1,
+                    clip_num_channels,
+                    clip_image_size,
+                    clip_image_size,
+                ).to(device=device, dtype=dtype),
+                torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to(device=device, dtype=dtype),
+            ),
+            output_path=output_path / "safety_checker" / "model.onnx",
+            ordered_input_names=["clip_input", "images"],
+            output_names=["out_images", "has_nsfw_concepts"],
+            dynamic_axes={
+                "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+                "images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
+            },
+            opset=opset,
+        )
+        del pipeline.safety_checker
+        safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker")
+    else:
+        safety_checker = None
 
     onnx_pipeline = OnnxStableDiffusionPipeline(
         vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
```

```diff
@@ -200,7 +225,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
         tokenizer=pipeline.tokenizer,
         unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
         scheduler=pipeline.scheduler,
-        safety_checker=OnnxRuntimeModel.from_pretrained(output_path / "safety_checker"),
+        safety_checker=safety_checker,
         feature_extractor=pipeline.feature_extractor,
     )
 
```
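
With these two hunks, a checkpoint saved without a safety checker converts cleanly: the pipeline is assembled with `safety_checker=None` instead of trying to load a nonexistent ONNX module. A usage sketch for the converted output, following the documented ONNX pipeline API (the local path is illustrative):

```python
from diffusers import OnnxStableDiffusionPipeline

# "./sd_onnx" is whatever --output_path was given to the conversion script
pipe = OnnxStableDiffusionPipeline.from_pretrained("./sd_onnx", provider="CPUExecutionProvider")
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=25).images[0]
image.save("astronaut.png")
```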

src/diffusers/onnx_utils.py (+26 -2)

```diff
@@ -24,7 +24,7 @@
 
 from huggingface_hub import hf_hub_download
 
-from .utils import ONNX_WEIGHTS_NAME, is_onnx_available, logging
+from .utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging
 
 
 if is_onnx_available():
```

```diff
@@ -33,13 +33,28 @@
 
 logger = logging.get_logger(__name__)
 
+ORT_TO_NP_TYPE = {
+    "tensor(bool)": np.bool_,
+    "tensor(int8)": np.int8,
+    "tensor(uint8)": np.uint8,
+    "tensor(int16)": np.int16,
+    "tensor(uint16)": np.uint16,
+    "tensor(int32)": np.int32,
+    "tensor(uint32)": np.uint32,
+    "tensor(int64)": np.int64,
+    "tensor(uint64)": np.uint64,
+    "tensor(float16)": np.float16,
+    "tensor(float)": np.float32,
+    "tensor(double)": np.float64,
+}
+
 
 class OnnxRuntimeModel:
     def __init__(self, model=None, **kwargs):
         logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
         self.model = model
         self.model_save_dir = kwargs.get("model_save_dir", None)
-        self.latest_model_name = kwargs.get("latest_model_name", "model.onnx")
+        self.latest_model_name = kwargs.get("latest_model_name", ONNX_WEIGHTS_NAME)
 
     def __call__(self, **kwargs):
         inputs = {k: np.array(v) for k, v in kwargs.items()}
```
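
`ORT_TO_NP_TYPE` exists because ONNX Runtime reports input types as strings like `"tensor(float16)"`, while the pipelines need a numpy dtype to build feed arrays. A small sketch of the lookup the pipelines below perform (the model path is illustrative):

```python
import numpy as np
import onnxruntime as ort
from diffusers.onnx_utils import ORT_TO_NP_TYPE

sess = ort.InferenceSession("sd_onnx/unet/model.onnx", providers=["CPUExecutionProvider"])

# find the declared type of the "timestep" input, defaulting to float32
ort_type = next((i.type for i in sess.get_inputs() if i.name == "timestep"), "tensor(float)")
np_dtype = ORT_TO_NP_TYPE[ort_type]  # e.g. np.float32 for "tensor(float)"

timestep = np.array([999], dtype=np_dtype)  # ready to feed to sess.run(...)
```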

```diff
@@ -84,6 +99,15 @@ def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
         except shutil.SameFileError:
             pass
 
+        # copy external weights (for models >2GB)
+        src_path = self.model_save_dir.joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
+        if src_path.exists():
+            dst_path = Path(save_directory).joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
+            try:
+                shutil.copyfile(src_path, dst_path)
+            except shutil.SameFileError:
+                pass
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
```
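
The extra copy handles ONNX's 2GB protobuf limit: when a graph's initializers exceed it, the weights live in a companion file (here `weights.pb`) that must travel alongside `model.onnx`. A background sketch using the `onnx` package's external-data API (paths are illustrative):

```python
import onnx

model = onnx.load("sd_onnx/unet/model.onnx")
# store all tensors in one side file next to the graph, as the exporter does
onnx.save_model(
    model,
    "unet_external/model.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="weights.pb",  # matches ONNX_EXTERNAL_WEIGHTS_NAME
)
```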

src/diffusers/pipeline_utils.py (+1 -1)

```diff
@@ -541,7 +541,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
             # if the model is in a pipeline module, then we load it from the pipeline
             if name in passed_class_obj:
                 # 1. check that passed_class_obj has correct parent class
-                if not is_pipeline_module:
+                if not is_pipeline_module and passed_class_obj[name] is not None:
                     library = importlib.import_module(library_name)
                     class_obj = getattr(library, class_name)
                     importable_classes = LOADABLE_CLASSES[library_name]
```
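
The added `is not None` guard lets callers opt out of an optional component without `from_pretrained` trying to type-check `None` against the expected class. For example (a sketch; the model id is the standard SD v1.5 checkpoint):

```python
from diffusers import StableDiffusionPipeline

# previously this tripped the parent-class check; now None is passed through
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    safety_checker=None,
)
```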

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py (+16 -7)

```diff
@@ -2,11 +2,12 @@
 from typing import Callable, List, Optional, Union
 
 import numpy as np
+import torch
 
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
-from ...onnx_utils import OnnxRuntimeModel
+from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import deprecate, logging
```

```diff
@@ -186,7 +187,7 @@ def __call__(
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
 
-        latents = latents * self.scheduler.init_noise_sigma
+        latents = latents * np.float(self.scheduler.init_noise_sigma)
 
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
```
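
The cast matters because `init_noise_sigma` can be a zero-dimensional torch tensor (LMSDiscreteScheduler derives it from its sigmas), and multiplying a numpy array by a torch tensor silently promotes the result to a torch tensor. A quick illustration, shown with the builtin `float`, which `np.float` aliases (the sigma value is the approximate SD v1 default, for illustration):

```python
import numpy as np
import torch

latents = np.ones((1, 4, 64, 64), dtype=np.float32)
sigma = torch.tensor(14.6146)  # what LMSDiscreteScheduler.init_noise_sigma looks like

print(type(latents * sigma))         # <class 'torch.Tensor'> -- silent promotion
print(type(latents * float(sigma)))  # <class 'numpy.ndarray'> -- stays numpy
```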

```diff
@@ -197,15 +198,20 @@
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
+        timestep_dtype = next(
+            (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+        )
+        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+
         for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+            latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+            latent_model_input = latent_model_input.cpu().numpy()
 
             # predict the noise residual
-            noise_pred = self.unet(
-                sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
-            )
+            timestep = np.array([t], dtype=timestep_dtype)
+            noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings)
             noise_pred = noise_pred[0]
 
             # perform guidance
```

```diff
@@ -214,7 +220,7 @@
             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
             latents = np.array(latents)
 
             # call the callback, if provided
```
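
These `torch.from_numpy(...)` / `.numpy()` round trips are the scheduler-compatibility half of the commit: the ONNX pipelines keep latents as numpy arrays for ONNX Runtime, but the schedulers do their math on torch tensors (LMSDiscreteScheduler in particular indexes tensor sigmas). A standalone sketch of the bridge, using the SD v1 scheduler settings for illustration:

```python
import numpy as np
import torch
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
scheduler.set_timesteps(50)
t = scheduler.timesteps[0]

latents = np.random.randn(1, 4, 64, 64).astype(np.float32)
latents = latents * float(scheduler.init_noise_sigma)

# torch for the scheduler call, numpy again for the ONNX Runtime feed
model_input = scheduler.scale_model_input(torch.from_numpy(latents), t)
model_input = model_input.cpu().numpy()
```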

```diff
@@ -235,6 +241,9 @@
             safety_checker_input = self.feature_extractor(
                 self.numpy_to_pil(image), return_tensors="np"
             ).pixel_values.astype(image.dtype)
+
+            image, has_nsfw_concepts = self.safety_checker(clip_input=safety_checker_input, images=image)
+
             # There will throw an error if use safety_checker batchsize>1
             images, has_nsfw_concept = [], []
             for i in range(image.shape[0]):
```

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py (+12 -5)

```diff
@@ -8,7 +8,7 @@
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
-from ...onnx_utils import OnnxRuntimeModel
+from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import deprecate, logging
```

```diff
@@ -338,14 +338,21 @@
         t_start = max(num_inference_steps - init_timestep + offset, 0)
         timesteps = self.scheduler.timesteps[t_start:].numpy()
 
+        timestep_dtype = next(
+            (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+        )
+        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+
         for i, t in enumerate(self.progress_bar(timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+            latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+            latent_model_input = latent_model_input.cpu().numpy()
 
             # predict the noise residual
+            timestep = np.array([t], dtype=timestep_dtype)
             noise_pred = self.unet(
-                sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
+                sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings
             )[0]
 
             # perform guidance
```

```diff
@@ -354,7 +361,7 @@
             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
             latents = latents.numpy()
 
             # call the callback, if provided
```

```diff
@@ -375,7 +382,7 @@
             safety_checker_input = self.feature_extractor(
                 self.numpy_to_pil(image), return_tensors="np"
             ).pixel_values.astype(image.dtype)
-            # There will throw an error if use safety_checker batchsize>1
+            # safety_checker does not support batched inputs yet
             images, has_nsfw_concept = [], []
             for i in range(image.shape[0]):
                 image_i, has_nsfw_concept_i = self.safety_checker(
```

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py (+12 -6)

```diff
@@ -8,7 +8,7 @@
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
-from ...onnx_utils import OnnxRuntimeModel
+from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import deprecate, logging
```

```diff
@@ -352,7 +352,7 @@
         self.scheduler.set_timesteps(num_inference_steps)
 
         # scale the initial noise by the standard deviation required by the scheduler
-        latents = latents * self.scheduler.init_noise_sigma
+        latents = latents * np.float(self.scheduler.init_noise_sigma)
 
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
```

```diff
@@ -363,17 +363,23 @@
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
+        timestep_dtype = next(
+            (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+        )
+        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+
         for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
             # concat latents, mask, masked_image_latnets in the channel dimension
             latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)
             latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
-            latent_model_input = latent_model_input.numpy()
+            latent_model_input = latent_model_input.cpu().numpy()
 
             # predict the noise residual
+            timestep = np.array([t], dtype=timestep_dtype)
             noise_pred = self.unet(
-                sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
+                sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings
             )[0]
 
             # perform guidance
```

```diff
@@ -382,7 +388,7 @@
             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
             latents = latents.numpy()
 
             # call the callback, if provided
```

```diff
@@ -403,7 +409,7 @@
             safety_checker_input = self.feature_extractor(
                 self.numpy_to_pil(image), return_tensors="np"
             ).pixel_values.astype(image.dtype)
-            # There will throw an error if use safety_checker batchsize>1
+            # safety_checker does not support batched inputs yet
             images, has_nsfw_concept = [], []
             for i in range(image.shape[0]):
                 image_i, has_nsfw_concept_i = self.safety_checker(
```

src/diffusers/utils/__init__.py (+1 -0)

```diff
@@ -67,6 +67,7 @@
 WEIGHTS_NAME = "diffusion_pytorch_model.bin"
 FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
 ONNX_WEIGHTS_NAME = "model.onnx"
+ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
 HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
 DIFFUSERS_CACHE = default_cache_path
 DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
```
