Commit 5d28d22

[Wuerstchen] fix combined pipeline's num_images_per_prompt (huggingface#4989)
* fix encode_prompt
* added prompt_embeds and negative_prompt_embeds
* prompt_embeds for the prior only
1 parent 73bf620 commit 5d28d22

File tree

3 files changed: +84 −38 lines

src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
Lines changed: 5 additions & 1 deletion

@@ -330,7 +330,11 @@ def __call__(
 
         # 2. Encode caption
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
-            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+            prompt,
+            device,
+            image_embeddings.size(0) * num_images_per_prompt,
+            do_classifier_free_guidance,
+            negative_prompt,
         )
         text_encoder_hidden_states = (
             torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds
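
Why the repeat count changed: the decoder's text embeddings must line up row-for-row with the image embeddings produced by the prior, which already contain one row per generated image. A minimal shape sketch of that invariant (all shapes below are illustrative assumptions, not values from the diff):

import torch

# Hypothetical shapes: four image embeddings returned by the prior for one
# prompt, and a single encoded text prompt (CLIP-style hidden states).
image_embeddings = torch.randn(4, 16, 24, 24)
prompt_embeds = torch.randn(1, 77, 1024)

num_images_per_prompt = 2
# Mirrors the fixed call above: scale the repeat count by the number of
# incoming image embeddings so every latent row has a matching text row.
repeats = image_embeddings.size(0) * num_images_per_prompt
prompt_embeds = prompt_embeds.repeat_interleave(repeats, dim=0)
assert prompt_embeds.size(0) == image_embeddings.size(0) * num_images_per_prompt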

src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
Lines changed: 14 additions & 3 deletions

@@ -154,6 +154,8 @@ def __call__(
         decoder_timesteps: Optional[List[float]] = None,
         decoder_guidance_scale: float = 0.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         num_images_per_prompt: int = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
@@ -165,10 +167,17 @@ def __call__(
 
         Args:
             prompt (`str` or `List[str]`):
-                The prompt or prompts to guide the image generation.
+                The prompt or prompts to guide the image generation for the prior and decoder.
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `guidance_scale` is less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
+                prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             height (`int`, *optional*, defaults to 512):
@@ -221,13 +230,15 @@ def __call__(
                 otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
         """
         prior_outputs = self.prior_pipe(
-            prompt=prompt,
+            prompt=prompt if prompt_embeds is None else None,
             height=height,
             width=width,
             num_inference_steps=prior_num_inference_steps,
             timesteps=prior_timesteps,
             guidance_scale=prior_guidance_scale,
-            negative_prompt=negative_prompt,
+            negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
             num_images_per_prompt=num_images_per_prompt,
             generator=generator,
             latents=latents,
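
With the two embedding arguments plumbed through, precomputed prior embeddings can be fed to the combined pipeline while `prompt` keeps steering the decoder. A usage sketch, assuming the `warp-ai/wuerstchen` checkpoint id, a CUDA device, and fp16 weights (none of which come from this diff):

import torch
from diffusers import WuerstchenCombinedPipeline

pipe = WuerstchenCombinedPipeline.from_pretrained(
    "warp-ai/wuerstchen", torch_dtype=torch.float16  # assumed checkpoint/dtype
).to("cuda")

# Encode once with the prior's text encoder, then reuse the embeddings.
prompt_embeds, negative_prompt_embeds = pipe.prior_pipe.encode_prompt(
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    prompt="an astronaut riding a horse",
)

images = pipe(
    prompt="an astronaut riding a horse",  # still used by the decoder
    prompt_embeds=prompt_embeds,           # replaces `prompt` for the prior only
    negative_prompt_embeds=negative_prompt_embeds,
    num_images_per_prompt=2,
).images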

src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
Lines changed: 65 additions & 34 deletions

@@ -150,41 +150,57 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
 
     def encode_prompt(
         self,
-        prompt,
         device,
         num_images_per_prompt,
         do_classifier_free_guidance,
+        prompt=None,
         negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
     ):
-        batch_size = len(prompt) if isinstance(prompt, list) else 1
-        # get prompt text embeddings
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        attention_mask = text_inputs.attention_mask
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            # get prompt text embeddings
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            attention_mask = text_inputs.attention_mask
 
-        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+                text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+                attention_mask = attention_mask[:, : self.tokenizer.model_max_length]
 
-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            text_encoder_output = self.text_encoder(
+                text_input_ids.to(device), attention_mask=attention_mask.to(device)
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-            attention_mask = attention_mask[:, : self.tokenizer.model_max_length]
+            prompt_embeds = text_encoder_output.last_hidden_state
 
-        text_encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device))
-        text_encoder_hidden_states = text_encoder_output.last_hidden_state
-        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
 
-        uncond_text_encoder_hidden_states = None
-        if do_classifier_free_guidance:
+        if negative_prompt_embeds is None and do_classifier_free_guidance:
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
@@ -215,17 +231,17 @@ def encode_prompt(
                 uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device)
             )
 
-            uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+            negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.last_hidden_state
 
+        if do_classifier_free_guidance:
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = uncond_text_encoder_hidden_states.shape[1]
-            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
-            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
-                batch_size * num_images_per_prompt, seq_len, -1
-            )
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
             # done duplicates
 
-        return text_encoder_hidden_states, uncond_text_encoder_hidden_states
+        return prompt_embeds, negative_prompt_embeds
 
     def check_inputs(
         self,
@@ -264,13 +280,15 @@ def check_inputs(
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        prompt: Union[str, List[str]] = None,
+        prompt: Optional[Union[str, List[str]]] = None,
         height: int = 1024,
         width: int = 1024,
         num_inference_steps: int = 60,
         timesteps: List[float] = None,
         guidance_scale: float = 8.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
@@ -304,6 +322,13 @@ def __call__(
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `decoder_guidance_scale` is less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -345,7 +370,13 @@ def __call__(
 
         # 2. Encode caption
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
-            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+            prompt=prompt,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
         )
 
         # For classifier free guidance, we need to do two forward passes.
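
The prior pipeline can likewise be driven directly from embeddings. A minimal sketch under the same assumptions (checkpoint id, device, dtype are mine, not the diff's), and assuming the prior's output exposes its result as `image_embeddings`:

import torch
from diffusers import WuerstchenPriorPipeline

prior = WuerstchenPriorPipeline.from_pretrained(
    "warp-ai/wuerstchen-prior", torch_dtype=torch.float16  # assumed checkpoint/dtype
).to("cuda")

prompt_embeds, negative_prompt_embeds = prior.encode_prompt(
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    prompt="a red sports car",
)

# With embeddings supplied, `prompt` can now be omitted entirely.
image_embeddings = prior(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
).image_embeddings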
