Skip to content

Commit 07c0fe4

Browse files
Use pipeline tests mixin for UnCLIP pipeline tests + unCLIP MPS fixes (huggingface#1908)
Re: huggingface#1857. We relax some of the checks to deal with unCLIP reproducibility issues, mainly by checking the average pixel difference (measured within 0-255) instead of the max pixel difference (measured within 0-1). - [x] Add mixin to UnCLIPPipelineFastTests - [x] Add mixin to UnCLIPImageVariationPipelineFastTests - [x] Move UnCLIPPipeline flags in mixin to base class - [x] Small MPS fixes for F.pad and F.interpolate - [x] Make the test unCLIP model's dimensions smaller to run tests faster
1 parent 1e651ca commit 07c0fe4

File tree

9 files changed

+380
-231
lines changed

9 files changed

+380
-231
lines changed

src/diffusers/models/cross_attention.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,14 @@ def prepare_attention_mask(self, attention_mask, target_length):
208208
return attention_mask
209209

210210
if attention_mask.shape[-1] != target_length:
211-
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
211+
if attention_mask.device.type == "mps":
212+
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
213+
# Instead, we can manually construct the padding tensor.
214+
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
215+
padding = torch.zeros(padding_shape, device=attention_mask.device)
216+
attention_mask = torch.concat([attention_mask, padding], dim=2)
217+
else:
218+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
212219
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
213220
return attention_mask
214221

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ def __call__(
452452
eta (`float`, *optional*, defaults to 0.0):
453453
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
454454
[`schedulers.DDIMScheduler`], will be ignored for others.
455-
generator (`torch.Generator`, *optional*):
455+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
456456
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
457457
to make generation deterministic.
458458
latents (`torch.FloatTensor`, *optional*):

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ def __call__(
449449
eta (`float`, *optional*, defaults to 0.0):
450450
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
451451
[`schedulers.DDIMScheduler`], will be ignored for others.
452-
generator (`torch.Generator`, *optional*):
452+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
453453
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
454454
to make generation deterministic.
455455
latents (`torch.FloatTensor`, *optional*):

src/diffusers/pipelines/unclip/pipeline_unclip.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
2323

2424
from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
25-
from ...pipelines import DiffusionPipeline, ImagePipelineOutput
25+
from ...pipelines import DiffusionPipeline
26+
from ...pipelines.pipeline_utils import ImagePipelineOutput
2627
from ...schedulers import UnCLIPScheduler
2728
from ...utils import is_accelerate_available, logging, randn_tensor
2829
from .text_proj import UnCLIPTextProjModel
@@ -130,13 +131,20 @@ def _encode_prompt(
130131
prompt,
131132
padding="max_length",
132133
max_length=self.tokenizer.model_max_length,
134+
truncation=True,
133135
return_tensors="pt",
134136
)
135137
text_input_ids = text_inputs.input_ids
136138
text_mask = text_inputs.attention_mask.bool().to(device)
137139

138-
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
139-
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
140+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
141+
142+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
143+
text_input_ids, untruncated_ids
144+
):
145+
removed_text = self.tokenizer.batch_decode(
146+
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
147+
)
140148
logger.warning(
141149
"The following part of your input was truncated because CLIP can only handle sequences up to"
142150
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
@@ -249,7 +257,7 @@ def __call__(
249257
prior_num_inference_steps: int = 25,
250258
decoder_num_inference_steps: int = 25,
251259
super_res_num_inference_steps: int = 7,
252-
generator: Optional[torch.Generator] = None,
260+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
253261
prior_latents: Optional[torch.FloatTensor] = None,
254262
decoder_latents: Optional[torch.FloatTensor] = None,
255263
super_res_latents: Optional[torch.FloatTensor] = None,
@@ -278,7 +286,7 @@ def __call__(
278286
super_res_num_inference_steps (`int`, *optional*, defaults to 7):
279287
The number of denoising steps for super resolution. More denoising steps usually lead to a higher
280288
quality image at the expense of slower inference.
281-
generator (`torch.Generator`, *optional*):
289+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
282290
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
283291
to make generation deterministic.
284292
prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*):
@@ -394,7 +402,14 @@ def __call__(
394402
do_classifier_free_guidance=do_classifier_free_guidance,
395403
)
396404

397-
decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
405+
if device.type == "mps":
406+
# HACK: MPS: There is a panic when padding bool tensors,
407+
# so cast to int tensor for the pad and back to bool afterwards
408+
text_mask = text_mask.type(torch.int)
409+
decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
410+
decoder_text_mask = decoder_text_mask.type(torch.bool)
411+
else:
412+
decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)
398413

399414
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
400415
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
@@ -465,13 +480,17 @@ def __call__(
465480
self.super_res_scheduler,
466481
)
467482

468-
interpolate_antialias = {}
469-
if "antialias" in inspect.signature(F.interpolate).parameters:
470-
interpolate_antialias["antialias"] = True
483+
if device.type == "mps":
484+
# MPS does not support many interpolations
485+
image_upscaled = F.interpolate(image_small, size=[height, width])
486+
else:
487+
interpolate_antialias = {}
488+
if "antialias" in inspect.signature(F.interpolate).parameters:
489+
interpolate_antialias["antialias"] = True
471490

472-
image_upscaled = F.interpolate(
473-
image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
474-
)
491+
image_upscaled = F.interpolate(
492+
image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
493+
)
475494

476495
for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
477496
# no classifier free guidance

src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,14 @@ def __call__(
328328
do_classifier_free_guidance=do_classifier_free_guidance,
329329
)
330330

331-
decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
331+
if device.type == "mps":
332+
# HACK: MPS: There is a panic when padding bool tensors,
333+
# so cast to int tensor for the pad and back to bool afterwards
334+
text_mask = text_mask.type(torch.int)
335+
decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
336+
decoder_text_mask = decoder_text_mask.type(torch.bool)
337+
else:
338+
decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)
332339

333340
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
334341
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
@@ -401,13 +408,17 @@ def __call__(
401408
self.super_res_scheduler,
402409
)
403410

404-
interpolate_antialias = {}
405-
if "antialias" in inspect.signature(F.interpolate).parameters:
406-
interpolate_antialias["antialias"] = True
411+
if device.type == "mps":
412+
# MPS does not support many interpolations
413+
image_upscaled = F.interpolate(image_small, size=[height, width])
414+
else:
415+
interpolate_antialias = {}
416+
if "antialias" in inspect.signature(F.interpolate).parameters:
417+
interpolate_antialias["antialias"] = True
407418

408-
image_upscaled = F.interpolate(
409-
image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
410-
)
419+
image_upscaled = F.interpolate(
420+
image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
421+
)
411422

412423
for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
413424
# no classifier free guidance

src/diffusers/schedulers/scheduling_unclip.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,6 @@ def step(
219219
returning a tuple, the first element is the sample tensor.
220220
221221
"""
222-
223222
t = timestep
224223

225224
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type == "learned_range":

0 commit comments

Comments
 (0)