Commit b345c74

Make sure all pipelines can run with batched input (huggingface#1669)
* [SD] Make sure batched input works correctly
* uP
* uP
* up
* up
* uP
* up
* fix mask stuff
* up
* uP
* more up
* up
* uP
* up
* finish
* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
1 parent b417042 commit b345c74

24 files changed: +336 additions, -152 deletions
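For orientation, the behavior these diffs target is that pipelines accept batched input: lists of prompts and lists of images (or pre-batched tensors) instead of only single items. Below is a hedged usage sketch, not taken from the commit itself; the model id, file names, and generation settings are placeholders, and the parameter names mirror the img2img pipeline touched by this diff.

import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

# Hypothetical example: model id, image files, and settings are placeholders.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

prompts = ["a fantasy landscape, detailed", "a watercolor painting of a harbor"]
init_images = [Image.open("a.png").convert("RGB"), Image.open("b.png").convert("RGB")]

# With this change, `image` may be a list of PIL images (or a batched tensor)
# whose length matches the prompt batch.
result = pipe(prompt=prompts, image=init_images, strength=0.75, num_inference_steps=30)
result.images  # a list with one output image per prompt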

src/diffusers/models/unet_1d.py
Lines changed: 1 addition & 0 deletions

@@ -218,6 +218,7 @@ def forward(
         else:
             timestep_embed = timestep_embed[..., None]
             timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
+            timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))

         # 2. down
         down_block_res_samples = ()
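The added broadcast_to makes a timestep embedding built from a scalar timestep follow the sample's batch dimension, so a batch of samples can share one timestep. A small sketch of the shape arithmetic; the shapes here are illustrative only, not taken from the model config.

import torch

sample = torch.randn(4, 14, 32)           # (batch, channels, length) -- illustrative shapes
timestep_embed = torch.randn(1, 8, 32)    # embedding derived from a single scalar timestep

# Expand the leading dimension to the sample's batch size without copying data.
timestep_embed = timestep_embed.broadcast_to(sample.shape[:1] + timestep_embed.shape[1:])
print(timestep_embed.shape)               # torch.Size([4, 8, 32])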

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
Lines changed: 2 additions & 2 deletions

@@ -249,9 +249,9 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

-        if not torch.equal(text_input_ids, untruncated_ids):
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
Lines changed: 22 additions & 12 deletions

@@ -44,13 +44,24 @@

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image)
-    return 2.0 * image - 1.0
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
+    return image


 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker

@@ -81,7 +92,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker"]

     def __init__(
         self,

@@ -246,9 +257,9 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

-        if not torch.equal(text_input_ids, untruncated_ids):
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"

@@ -510,8 +521,7 @@ def __call__(
         )

         # 4. Preprocess image
-        if isinstance(image, PIL.Image.Image):
-            image = preprocess(image)
+        image = preprocess(image)

         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
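The reworked preprocess accepts a single PIL image, a list of PIL images, an already-batched tensor, or a list of tensors, and returns a batched tensor scaled to [-1, 1]. A small sketch of the accepted inputs; the image sizes are arbitrary, and it assumes the preprocess function shown in the diff above is in scope.

import numpy as np
import PIL.Image
import torch

img_a = PIL.Image.fromarray(np.zeros((96, 96, 3), dtype=np.uint8))
img_b = PIL.Image.fromarray(np.full((96, 96, 3), 255, dtype=np.uint8))

batch = preprocess([img_a, img_b])   # list of PIL images -> batched tensor
print(batch.shape)                   # torch.Size([2, 3, 96, 96]), values in [-1.0, 1.0]

already_batched = torch.randn(2, 3, 96, 96)
print(preprocess(already_batched) is already_batched)  # True: tensors pass through unchanged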

src/diffusers/pipelines/ddim/pipeline_ddim.py
Lines changed: 0 additions & 1 deletion

@@ -46,7 +46,6 @@ def __call__(
         use_clipped_model_output: Optional[bool] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         r"""
         Args:

src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
Lines changed: 25 additions & 9 deletions

@@ -109,16 +109,18 @@ def prepare_mask_and_masked_image(image, mask):
             raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
     else:
         if isinstance(image, PIL.Image.Image):
-            image = np.array(image.convert("RGB"))
+            image = [image]

-        image = image[None].transpose(0, 3, 1, 2)
+        image = np.concatenate([np.array(i.convert("RGB"))[None, :] for i in image], axis=0)
+        image = image.transpose(0, 3, 1, 2)
         image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

+        # preprocess mask
         if isinstance(mask, PIL.Image.Image):
-            mask = np.array(mask.convert("L"))
-            mask = mask.astype(np.float32) / 255.0
+            mask = [mask]

-        mask = mask[None, None]
+        mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+        mask = mask.astype(np.float32) / 255.0

         # paint-by-example inverses the mask
         mask = 1 - mask

@@ -159,7 +161,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker"]

     def __init__(
         self,

@@ -323,8 +325,22 @@ def prepare_mask_latents(
         masked_image_latents = 0.18215 * masked_image_latents

         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
-        mask = mask.repeat(batch_size, 1, 1, 1)
-        masked_image_latents = masked_image_latents.repeat(batch_size, 1, 1, 1)
+        if mask.shape[0] < batch_size:
+            if not batch_size % mask.shape[0] == 0:
+                raise ValueError(
+                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+                    " of masks that you pass is divisible by the total requested batch size."
+                )
+            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+        if masked_image_latents.shape[0] < batch_size:
+            if not batch_size % masked_image_latents.shape[0] == 0:
+                raise ValueError(
+                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+                    " Make sure the number of images that you pass is divisible by the total requested batch size."
+                )
+            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)

         mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
         masked_image_latents = (

@@ -351,7 +367,7 @@ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free

         if do_classifier_free_guidance:
             uncond_embeddings = self.image_encoder.uncond_vector
-            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.repeat(1, image_embeddings.shape[0], 1)
             uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, 1, -1)

         # For classifier free guidance, we need to do two forward passes.
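The duplication rule in prepare_mask_latents only repeats along the batch dimension when the requested batch size is a whole multiple of what was passed in, and raises otherwise. Condensed into a hypothetical helper below; expand_to_batch is not part of the pipeline, just an illustration of the rule.

import torch

def expand_to_batch(t: torch.Tensor, batch_size: int, name: str) -> torch.Tensor:
    # Repeat along dim 0 only when batch_size is a whole multiple of the current count.
    if t.shape[0] < batch_size:
        if batch_size % t.shape[0] != 0:
            raise ValueError(f"cannot duplicate {t.shape[0]} {name} to a batch of {batch_size}")
        t = t.repeat(batch_size // t.shape[0], *([1] * (t.dim() - 1)))
    return t

mask = torch.rand(1, 1, 64, 64)
mask = expand_to_batch(mask, batch_size=4, name="masks")
print(mask.shape)  # torch.Size([4, 1, 64, 64])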

src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
Lines changed: 22 additions & 11 deletions

@@ -35,14 +35,26 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image)
-    return 2.0 * image - 1.0
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
+    return image


 def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta):

@@ -279,9 +291,9 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

-        if not torch.equal(text_input_ids, untruncated_ids):
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"

@@ -551,8 +563,7 @@ def __call__(
         )

         # 4. Preprocess image
-        if isinstance(image, PIL.Image.Image):
-            image = preprocess(image)
+        image = preprocess(image)

         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
Lines changed: 21 additions & 9 deletions

@@ -32,13 +32,26 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    return 2.0 * image - 1.0
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
+    return image


 class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):

@@ -77,7 +90,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
     safety_checker: OnnxRuntimeModel
     feature_extractor: CLIPFeatureExtractor

-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker"]

     def __init__(
         self,

@@ -325,8 +338,7 @@ def __call__(
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)

-        if isinstance(image, PIL.Image.Image):
-            image = preprocess(image)
+        image = preprocess(image)

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
Lines changed: 2 additions & 2 deletions

@@ -248,9 +248,9 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

-        if not torch.equal(text_input_ids, untruncated_ids):
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
Lines changed: 29 additions & 19 deletions

@@ -41,14 +41,26 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image)
-    return 2.0 * image - 1.0
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
+    return image


 class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):

@@ -189,9 +201,9 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

-        if not torch.equal(text_input_ids, untruncated_ids):
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"

@@ -366,12 +378,13 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt

     def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_guidance, dtype, device):
         if isinstance(image, PIL.Image.Image):
-            width, height = image.size
-            width, height = map(lambda dim: dim - dim % 32, (width, height))  # resize to integer multiple of 32
-            image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
-            width, height = image.size
+            image = [image]
         else:
             image = [img for img in image]
+
+        if isinstance(image[0], PIL.Image.Image):
+            width, height = image[0].size
+        else:
             width, height = image[0].shape[-2:]

         if depth_map is None:

@@ -493,7 +506,7 @@ def __call__(
             prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
         )

-        # 4. Prepare depth mask
+        # 4. Preprocess image
         depth_mask = self.prepare_depth_map(
             image,
             depth_map,

@@ -503,11 +516,8 @@ def __call__(
             device,
         )

-        # 5. Preprocess image
-        if isinstance(image, PIL.Image.Image):
-            image = preprocess(image)
-        else:
-            image = 2.0 * (image / 255.0) - 1.0
+        # 5. Prepare depth mask
+        image = preprocess(image)

         # 6. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
Lines changed: 1 addition & 1 deletion

@@ -65,7 +65,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker"]

     def __init__(
         self,
