
Commit 04d696d

[Core] Add support for CLIP-skip (huggingface#4901)
* add support for clip skip
* fix condition
* fix
* add clip_output_layer_to_default
* expose
* remove the previous functions.
* correct condition.
* apply final layer norm
* address feedback
* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

* refactor clip_skip.
* port to the other pipelines.
* fix copies one more time

---------

Co-authored-by: Patrick von Platen <[email protected]>
1 parent ed50768 commit 04d696d

28 files changed: +782 additions, -186 deletions

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 28 additions & 6 deletions
@@ -231,6 +231,7 @@ def _encode_prompt(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         lora_scale: Optional[float] = None,
+        **kwargs,
     ):
         deprecation_message = (
             "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()`"
@@ -247,6 +248,7 @@ def _encode_prompt(
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=lora_scale,
+            **kwargs,
         )

         # concatenate for backwards comp
@@ -264,6 +266,7 @@ def encode_prompt(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         lora_scale: Optional[float] = None,
+        clip_skip: Optional[int] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -289,7 +292,10 @@ def encode_prompt(
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
             lora_scale (`float`, *optional*):
-                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
         """
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
@@ -337,11 +343,22 @@ def encode_prompt(
             else:
                 attention_mask = None

-            prompt_embeds = self.text_encoder(
-                text_input_ids.to(device),
-                attention_mask=attention_mask,
-            )
-            prompt_embeds = prompt_embeds[0]
+            if clip_skip is None:
+                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+                prompt_embeds = prompt_embeds[0]
+            else:
+                prompt_embeds = self.text_encoder(
+                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+                )
+                # Access the `hidden_states` first, that contains a tuple of
+                # all the hidden states from the encoder layers. Then index into
+                # the tuple to access the hidden states from the desired layer.
+                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+                # We also need to apply the final LayerNorm here to not mess with the
+                # representations. The `last_hidden_states` that we typically use for
+                # obtaining the final prompt representations passes through the LayerNorm
+                # layer.
+                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

         if self.text_encoder is not None:
             prompt_embeds_dtype = self.text_encoder.dtype
@@ -544,6 +561,7 @@ def __call__(
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         guidance_rescale: float = 0.0,
+        clip_skip: Optional[int] = None,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -600,6 +618,9 @@ def __call__(
                 Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
                 Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
                 using zero terminal SNR.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.

         Examples:

@@ -646,6 +667,7 @@ def __call__(
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=text_encoder_lora_scale,
+            clip_skip=clip_skip,
         )
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
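
The `else` branch added above is the core of the feature. Here is a standalone sketch of the same logic against a plain `transformers` CLIP text encoder (my own illustrative restatement, not code from this commit; the checkpoint name is an assumption matching the SD v1.x text encoder):

# Reproduce the clip_skip indexing outside the pipeline.
import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

clip_skip = 2
inputs = tokenizer(
    ["a fantasy landscape"],
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)

with torch.no_grad():
    outputs = text_encoder(inputs.input_ids, output_hidden_states=True)

# `hidden_states` holds the embedding output plus the output of every encoder
# layer; -(clip_skip + 1) selects the layer `clip_skip` steps before the last.
prompt_embeds = outputs.hidden_states[-(clip_skip + 1)]
# Re-apply the final LayerNorm so the skipped-layer states live in the same
# representation space as the usual `last_hidden_state`.
prompt_embeds = text_encoder.text_model.final_layer_norm(prompt_embeds)
print(prompt_embeds.shape)  # torch.Size([1, 77, 768])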

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 28 additions & 7 deletions
@@ -229,6 +229,7 @@ def _encode_prompt(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         lora_scale: Optional[float] = None,
+        **kwargs,
     ):
         deprecation_message = (
             "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()`"
@@ -245,6 +246,7 @@ def _encode_prompt(
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=lora_scale,
+            **kwargs,
         )

         # concatenate for backwards comp
@@ -262,6 +264,7 @@ def encode_prompt(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         lora_scale: Optional[float] = None,
+        clip_skip: Optional[int] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -287,7 +290,10 @@ def encode_prompt(
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
             lora_scale (`float`, *optional*):
-                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
         """
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
@@ -335,11 +341,22 @@ def encode_prompt(
             else:
                 attention_mask = None

-            prompt_embeds = self.text_encoder(
-                text_input_ids.to(device),
-                attention_mask=attention_mask,
-            )
-            prompt_embeds = prompt_embeds[0]
+            if clip_skip is None:
+                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+                prompt_embeds = prompt_embeds[0]
+            else:
+                prompt_embeds = self.text_encoder(
+                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+                )
+                # Access the `hidden_states` first, that contains a tuple of
+                # all the hidden states from the encoder layers. Then index into
+                # the tuple to access the hidden states from the desired layer.
+                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+                # We also need to apply the final LayerNorm here to not mess with the
+                # representations. The `last_hidden_states` that we typically use for
+                # obtaining the final prompt representations passes through the LayerNorm
+                # layer.
+                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

         if self.text_encoder is not None:
             prompt_embeds_dtype = self.text_encoder.dtype
@@ -582,6 +599,7 @@ def __call__(
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        clip_skip: int = None,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -638,7 +656,9 @@ def __call__(
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                 [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
-
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
         Examples:

         Returns:
@@ -677,6 +697,7 @@ def __call__(
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=text_encoder_lora_scale,
+            clip_skip=clip_skip,
         )
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
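
The img2img variant receives the identical change, so `clip_skip` can be passed through an image-to-image call as well. A hypothetical sketch (checkpoint, input image, and prompt are placeholders; it assumes the Stable Diffusion img2img pipeline got the same ported argument):

# Image-to-image call with skipped CLIP layers.
import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

init_image = Image.open("sketch.png").convert("RGB").resize((768, 512))  # placeholder input
image = pipe(
    "a detailed oil painting of a castle",
    image=init_image,
    strength=0.6,
    clip_skip=2,  # forwarded to encode_prompt, as in the diff above
).images[0]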

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 28 additions & 6 deletions
@@ -221,6 +221,7 @@ def _encode_prompt(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         lora_scale: Optional[float] = None,
+        **kwargs,
     ):
         deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
         deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
@@ -234,6 +235,7 @@ def _encode_prompt(
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=lora_scale,
+            **kwargs,
         )

         # concatenate for backwards comp
@@ -252,6 +254,7 @@ def encode_prompt(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         lora_scale: Optional[float] = None,
+        clip_skip: Optional[int] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -277,7 +280,10 @@ def encode_prompt(
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
             lora_scale (`float`, *optional*):
-                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
         """
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
@@ -325,11 +331,22 @@ def encode_prompt(
             else:
                 attention_mask = None

-            prompt_embeds = self.text_encoder(
-                text_input_ids.to(device),
-                attention_mask=attention_mask,
-            )
-            prompt_embeds = prompt_embeds[0]
+            if clip_skip is None:
+                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+                prompt_embeds = prompt_embeds[0]
+            else:
+                prompt_embeds = self.text_encoder(
+                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+                )
+                # Access the `hidden_states` first, that contains a tuple of
+                # all the hidden states from the encoder layers. Then index into
+                # the tuple to access the hidden states from the desired layer.
+                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+                # We also need to apply the final LayerNorm here to not mess with the
+                # representations. The `last_hidden_states` that we typically use for
+                # obtaining the final prompt representations passes through the LayerNorm
+                # layer.
+                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

         if self.text_encoder is not None:
             prompt_embeds_dtype = self.text_encoder.dtype
@@ -697,6 +714,7 @@ def __call__(
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
+        clip_skip: Optional[int] = None,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -768,6 +786,9 @@ def __call__(
                 The percentage of total steps at which the ControlNet starts applying.
             control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
                 The percentage of total steps at which the ControlNet stops applying.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.

         Examples:

@@ -841,6 +862,7 @@ def __call__(
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=text_encoder_lora_scale,
+            clip_skip=clip_skip,
         )
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
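
Since the ControlNet pipeline's `__call__` now forwards `clip_skip` to `encode_prompt`, a ControlNet run can opt into skipped CLIP layers too. A hypothetical sketch (model IDs and the control-image path are placeholders, not taken from this diff):

# ControlNet text-to-image call with clip_skip.
import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

canny_image = Image.open("canny_edges.png")  # placeholder pre-computed edge map

# As in the text-to-image case, clip_skip is simply forwarded to encode_prompt.
image = pipe(
    "a colorful bird, best quality",
    image=canny_image,
    num_inference_steps=25,
    clip_skip=2,
).images[0]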

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 28 additions & 6 deletions
@@ -245,6 +245,7 @@ def _encode_prompt(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         lora_scale: Optional[float] = None,
+        **kwargs,
     ):
         deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
         deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
@@ -258,6 +259,7 @@ def _encode_prompt(
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=lora_scale,
+            **kwargs,
         )

         # concatenate for backwards comp
@@ -276,6 +278,7 @@ def encode_prompt(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         lora_scale: Optional[float] = None,
+        clip_skip: Optional[int] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -301,7 +304,10 @@ def encode_prompt(
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
             lora_scale (`float`, *optional*):
-                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
         """
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
@@ -349,11 +355,22 @@ def encode_prompt(
             else:
                 attention_mask = None

-            prompt_embeds = self.text_encoder(
-                text_input_ids.to(device),
-                attention_mask=attention_mask,
-            )
-            prompt_embeds = prompt_embeds[0]
+            if clip_skip is None:
+                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+                prompt_embeds = prompt_embeds[0]
+            else:
+                prompt_embeds = self.text_encoder(
+                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+                )
+                # Access the `hidden_states` first, that contains a tuple of
+                # all the hidden states from the encoder layers. Then index into
+                # the tuple to access the hidden states from the desired layer.
+                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+                # We also need to apply the final LayerNorm here to not mess with the
+                # representations. The `last_hidden_states` that we typically use for
+                # obtaining the final prompt representations passes through the LayerNorm
+                # layer.
+                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

         if self.text_encoder is not None:
             prompt_embeds_dtype = self.text_encoder.dtype
@@ -769,6 +786,7 @@ def __call__(
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
+        clip_skip: Optional[int] = None,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -844,6 +862,9 @@ def __call__(
                 The percentage of total steps at which the ControlNet starts applying.
             control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
                 The percentage of total steps at which the ControlNet stops applying.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.

         Examples:

@@ -917,6 +938,7 @@ def __call__(
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=text_encoder_lora_scale,
+            clip_skip=clip_skip,
         )
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
