Skip to content

Commit 17528af

Browse files
Fix styling issues (huggingface#5699)
* up
* up
* up
* Empty-Commit
* fix keyword argument call.

---------

Co-authored-by: Sayak Paul <[email protected]>
1 parent 78be400 commit 17528af

File tree

3 files changed

+24
-34
lines changed

3 files changed

+24
-34
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py

Lines changed: 9 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -206,17 +206,15 @@ def _encode_prior_prompt(
206206
prior_text_encoder_output = self.prior_text_encoder(text_input_ids.to(device))
207207

208208
prompt_embeds = prior_text_encoder_output.text_embeds
209-
prior_text_encoder_hidden_states = prior_text_encoder_output.last_hidden_state
209+
text_enc_hid_states = prior_text_encoder_output.last_hidden_state
210210

211211
else:
212212
batch_size = text_model_output[0].shape[0]
213-
prompt_embeds, prior_text_encoder_hidden_states = text_model_output[0], text_model_output[1]
213+
prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1]
214214
text_mask = text_attention_mask
215215

216216
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
217-
prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.repeat_interleave(
218-
num_images_per_prompt, dim=0
219-
)
217+
text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0)
220218
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
221219

222220
if do_classifier_free_guidance:
@@ -235,21 +233,17 @@ def _encode_prior_prompt(
235233
)
236234

237235
negative_prompt_embeds = negative_prompt_embeds_prior_text_encoder_output.text_embeds
238-
uncond_prior_text_encoder_hidden_states = (
239-
negative_prompt_embeds_prior_text_encoder_output.last_hidden_state
240-
)
236+
uncond_text_enc_hid_states = negative_prompt_embeds_prior_text_encoder_output.last_hidden_state
241237

242238
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
243239

244240
seq_len = negative_prompt_embeds.shape[1]
245241
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
246242
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
247243

248-
seq_len = uncond_prior_text_encoder_hidden_states.shape[1]
249-
uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.repeat(
250-
1, num_images_per_prompt, 1
251-
)
252-
uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.view(
244+
seq_len = uncond_text_enc_hid_states.shape[1]
245+
uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1)
246+
uncond_text_enc_hid_states = uncond_text_enc_hid_states.view(
253247
batch_size * num_images_per_prompt, seq_len, -1
254248
)
255249
uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
@@ -260,13 +254,11 @@ def _encode_prior_prompt(
260254
# Here we concatenate the unconditional and text embeddings into a single batch
261255
# to avoid doing two forward passes
262256
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
263-
prior_text_encoder_hidden_states = torch.cat(
264-
[uncond_prior_text_encoder_hidden_states, prior_text_encoder_hidden_states]
265-
)
257+
text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states])
266258

267259
text_mask = torch.cat([uncond_text_mask, text_mask])
268260

269-
return prompt_embeds, prior_text_encoder_hidden_states, text_mask
261+
return prompt_embeds, text_enc_hid_states, text_mask
270262

271263
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
272264
def _encode_prompt(

src/diffusers/pipelines/unclip/pipeline_unclip.py

Lines changed: 15 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -156,15 +156,15 @@ def _encode_prompt(
156156
text_encoder_output = self.text_encoder(text_input_ids.to(device))
157157

158158
prompt_embeds = text_encoder_output.text_embeds
159-
text_encoder_hidden_states = text_encoder_output.last_hidden_state
159+
text_enc_hid_states = text_encoder_output.last_hidden_state
160160

161161
else:
162162
batch_size = text_model_output[0].shape[0]
163-
prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
163+
prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1]
164164
text_mask = text_attention_mask
165165

166166
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
167-
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
167+
text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0)
168168
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
169169

170170
if do_classifier_free_guidance:
@@ -181,17 +181,17 @@ def _encode_prompt(
181181
negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
182182

183183
negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
184-
uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
184+
uncond_text_enc_hid_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
185185

186186
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
187187

188188
seq_len = negative_prompt_embeds.shape[1]
189189
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
190190
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
191191

192-
seq_len = uncond_text_encoder_hidden_states.shape[1]
193-
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
194-
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
192+
seq_len = uncond_text_enc_hid_states.shape[1]
193+
uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1)
194+
uncond_text_enc_hid_states = uncond_text_enc_hid_states.view(
195195
batch_size * num_images_per_prompt, seq_len, -1
196196
)
197197
uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
@@ -202,11 +202,11 @@ def _encode_prompt(
202202
# Here we concatenate the unconditional and text embeddings into a single batch
203203
# to avoid doing two forward passes
204204
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
205-
text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
205+
text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states])
206206

207207
text_mask = torch.cat([uncond_text_mask, text_mask])
208208

209-
return prompt_embeds, text_encoder_hidden_states, text_mask
209+
return prompt_embeds, text_enc_hid_states, text_mask
210210

211211
@torch.no_grad()
212212
def __call__(
@@ -293,7 +293,7 @@ def __call__(
293293

294294
do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
295295

296-
prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
296+
prompt_embeds, text_enc_hid_states, text_mask = self._encode_prompt(
297297
prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
298298
)
299299

@@ -321,7 +321,7 @@ def __call__(
321321
latent_model_input,
322322
timestep=t,
323323
proj_embedding=prompt_embeds,
324-
encoder_hidden_states=text_encoder_hidden_states,
324+
encoder_hidden_states=text_enc_hid_states,
325325
attention_mask=text_mask,
326326
).predicted_image_embedding
327327

@@ -352,10 +352,10 @@ def __call__(
352352

353353
# decoder
354354

355-
text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
355+
text_enc_hid_states, additive_clip_time_embeddings = self.text_proj(
356356
image_embeddings=image_embeddings,
357357
prompt_embeds=prompt_embeds,
358-
text_encoder_hidden_states=text_encoder_hidden_states,
358+
text_encoder_hidden_states=text_enc_hid_states,
359359
do_classifier_free_guidance=do_classifier_free_guidance,
360360
)
361361

@@ -377,7 +377,7 @@ def __call__(
377377

378378
decoder_latents = self.prepare_latents(
379379
(batch_size, num_channels_latents, height, width),
380-
text_encoder_hidden_states.dtype,
380+
text_enc_hid_states.dtype,
381381
device,
382382
generator,
383383
decoder_latents,
@@ -391,7 +391,7 @@ def __call__(
391391
noise_pred = self.decoder(
392392
sample=latent_model_input,
393393
timestep=t,
394-
encoder_hidden_states=text_encoder_hidden_states,
394+
encoder_hidden_states=text_enc_hid_states,
395395
class_labels=additive_clip_time_embeddings,
396396
attention_mask=decoder_text_mask,
397397
).sample

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1494,7 +1494,6 @@ def forward(self, input_tensor, temb):
14941494
return output_tensor
14951495

14961496

1497-
# Copied from diffusers.models.unet_2d_blocks.DownBlock2D with DownBlock2D->DownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
14981497
class DownBlockFlat(nn.Module):
14991498
def __init__(
15001499
self,
@@ -1583,7 +1582,6 @@ def custom_forward(*inputs):
15831582
return hidden_states, output_states
15841583

15851584

1586-
# Copied from diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D with CrossAttnDownBlock2D->CrossAttnDownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
15871585
class CrossAttnDownBlockFlat(nn.Module):
15881586
def __init__(
15891587
self,

0 commit comments

Comments
 (0)