Skip to content

Commit edcbb6f

Browse files
authored
[WIP] core: add support for clip skip to SDXL (huggingface#5057)
* core: add support for clip skip to SDXL * add clip_skip support to the rest of the pipeline. * Empty-Commit
1 parent 5a287d3 commit edcbb6f

File tree

7 files changed

+103
-35
lines changed

7 files changed

+103
-35
lines changed

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ def encode_prompt(
263263
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
264264
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
265265
lora_scale: Optional[float] = None,
266+
clip_skip: Optional[int] = None,
266267
):
267268
r"""
268269
Encodes the prompt into text encoder hidden states.
@@ -302,6 +303,9 @@ def encode_prompt(
302303
input argument.
303304
lora_scale (`float`, *optional*):
304305
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
306+
clip_skip (`int`, *optional*):
307+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
308+
the output of the pre-final layer will be used for computing the prompt embeddings.
305309
"""
306310
device = device or self._execution_device
307311

@@ -358,14 +362,15 @@ def encode_prompt(
358362
f" {tokenizer.model_max_length} tokens: {removed_text}"
359363
)
360364

361-
prompt_embeds = text_encoder(
362-
text_input_ids.to(device),
363-
output_hidden_states=True,
364-
)
365+
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
365366

366367
# We are only ALWAYS interested in the pooled output of the final text encoder
367368
pooled_prompt_embeds = prompt_embeds[0]
368-
prompt_embeds = prompt_embeds.hidden_states[-2]
369+
if clip_skip is None:
370+
prompt_embeds = prompt_embeds.hidden_states[-2]
371+
else:
372+
# "2" because SDXL always indexes from the penultimate layer.
373+
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
369374

370375
prompt_embeds_list.append(prompt_embeds)
371376

@@ -971,6 +976,7 @@ def __call__(
971976
target_size: Tuple[int, int] = None,
972977
aesthetic_score: float = 6.0,
973978
negative_aesthetic_score: float = 2.5,
979+
clip_skip: Optional[int] = None,
974980
):
975981
r"""
976982
Function invoked when calling the pipeline for generation.
@@ -1097,6 +1103,9 @@ def __call__(
10971103
Part of SDXL's micro-conditioning as explained in section 2.2 of
10981104
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
10991105
simulate an aesthetic score of the generated image by influencing the negative text condition.
1106+
clip_skip (`int`, *optional*):
1107+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1108+
the output of the pre-final layer will be used for computing the prompt embeddings.
11001109
11011110
Examples:
11021111
@@ -1192,6 +1201,7 @@ def __call__(
11921201
pooled_prompt_embeds=pooled_prompt_embeds,
11931202
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
11941203
lora_scale=text_encoder_lora_scale,
1204+
clip_skip=clip_skip,
11951205
)
11961206

11971207
# 4. set timesteps

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ def encode_prompt(
236236
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
237237
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
238238
lora_scale: Optional[float] = None,
239+
clip_skip: Optional[int] = None,
239240
):
240241
r"""
241242
Encodes the prompt into text encoder hidden states.
@@ -275,6 +276,9 @@ def encode_prompt(
275276
input argument.
276277
lora_scale (`float`, *optional*):
277278
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
279+
clip_skip (`int`, *optional*):
280+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
281+
the output of the pre-final layer will be used for computing the prompt embeddings.
278282
"""
279283
device = device or self._execution_device
280284

@@ -331,14 +335,15 @@ def encode_prompt(
331335
f" {tokenizer.model_max_length} tokens: {removed_text}"
332336
)
333337

334-
prompt_embeds = text_encoder(
335-
text_input_ids.to(device),
336-
output_hidden_states=True,
337-
)
338+
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
338339

339340
# We are only ALWAYS interested in the pooled output of the final text encoder
340341
pooled_prompt_embeds = prompt_embeds[0]
341-
prompt_embeds = prompt_embeds.hidden_states[-2]
342+
if clip_skip is None:
343+
prompt_embeds = prompt_embeds.hidden_states[-2]
344+
else:
345+
# "2" because SDXL always indexes from the penultimate layer.
346+
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
342347

343348
prompt_embeds_list.append(prompt_embeds)
344349

@@ -767,6 +772,7 @@ def __call__(
767772
negative_original_size: Optional[Tuple[int, int]] = None,
768773
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
769774
negative_target_size: Optional[Tuple[int, int]] = None,
775+
clip_skip: Optional[int] = None,
770776
):
771777
r"""
772778
The call function to the pipeline for generation.
@@ -884,6 +890,9 @@ def __call__(
884890
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
885891
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
886892
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
893+
clip_skip (`int`, *optional*):
894+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
895+
the output of the pre-final layer will be used for computing the prompt embeddings.
887896
888897
Examples:
889898
@@ -968,6 +977,7 @@ def __call__(
968977
pooled_prompt_embeds=pooled_prompt_embeds,
969978
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
970979
lora_scale=text_encoder_lora_scale,
980+
clip_skip=clip_skip,
971981
)
972982

973983
# 4. Prepare image

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ def encode_prompt(
274274
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
275275
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
276276
lora_scale: Optional[float] = None,
277+
clip_skip: Optional[int] = None,
277278
):
278279
r"""
279280
Encodes the prompt into text encoder hidden states.
@@ -313,6 +314,9 @@ def encode_prompt(
313314
input argument.
314315
lora_scale (`float`, *optional*):
315316
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
317+
clip_skip (`int`, *optional*):
318+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
319+
the output of the pre-final layer will be used for computing the prompt embeddings.
316320
"""
317321
device = device or self._execution_device
318322

@@ -369,14 +373,15 @@ def encode_prompt(
369373
f" {tokenizer.model_max_length} tokens: {removed_text}"
370374
)
371375

372-
prompt_embeds = text_encoder(
373-
text_input_ids.to(device),
374-
output_hidden_states=True,
375-
)
376+
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
376377

377378
# We are only ALWAYS interested in the pooled output of the final text encoder
378379
pooled_prompt_embeds = prompt_embeds[0]
379-
prompt_embeds = prompt_embeds.hidden_states[-2]
380+
if clip_skip is None:
381+
prompt_embeds = prompt_embeds.hidden_states[-2]
382+
else:
383+
# "2" because SDXL always indexes from the penultimate layer.
384+
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
380385

381386
prompt_embeds_list.append(prompt_embeds)
382387

@@ -914,6 +919,7 @@ def __call__(
914919
negative_target_size: Optional[Tuple[int, int]] = None,
915920
aesthetic_score: float = 6.0,
916921
negative_aesthetic_score: float = 2.5,
922+
clip_skip: Optional[int] = None,
917923
):
918924
r"""
919925
Function invoked when calling the pipeline for generation.
@@ -1057,6 +1063,9 @@ def __call__(
10571063
Part of SDXL's micro-conditioning as explained in section 2.2 of
10581064
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
10591065
simulate an aesthetic score of the generated image by influencing the negative text condition.
1066+
clip_skip (`int`, *optional*):
1067+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1068+
the output of the pre-final layer will be used for computing the prompt embeddings.
10601069
10611070
Examples:
10621071
@@ -1143,6 +1152,7 @@ def __call__(
11431152
pooled_prompt_embeds=pooled_prompt_embeds,
11441153
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
11451154
lora_scale=text_encoder_lora_scale,
1155+
clip_skip=clip_skip,
11461156
)
11471157

11481158
# 4. Prepare image and controlnet_conditioning_image

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def encode_prompt(
212212
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
213213
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
214214
lora_scale: Optional[float] = None,
215+
clip_skip: Optional[int] = None,
215216
):
216217
r"""
217218
Encodes the prompt into text encoder hidden states.
@@ -251,6 +252,9 @@ def encode_prompt(
251252
input argument.
252253
lora_scale (`float`, *optional*):
253254
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
255+
clip_skip (`int`, *optional*):
256+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
257+
the output of the pre-final layer will be used for computing the prompt embeddings.
254258
"""
255259
device = device or self._execution_device
256260

@@ -307,14 +311,15 @@ def encode_prompt(
307311
f" {tokenizer.model_max_length} tokens: {removed_text}"
308312
)
309313

310-
prompt_embeds = text_encoder(
311-
text_input_ids.to(device),
312-
output_hidden_states=True,
313-
)
314+
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
314315

315316
# We are only ALWAYS interested in the pooled output of the final text encoder
316317
pooled_prompt_embeds = prompt_embeds[0]
317-
prompt_embeds = prompt_embeds.hidden_states[-2]
318+
if clip_skip is None:
319+
prompt_embeds = prompt_embeds.hidden_states[-2]
320+
else:
321+
# "2" because SDXL always indexes from the penultimate layer.
322+
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
318323

319324
prompt_embeds_list.append(prompt_embeds)
320325

@@ -577,6 +582,7 @@ def __call__(
577582
negative_original_size: Optional[Tuple[int, int]] = None,
578583
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
579584
negative_target_size: Optional[Tuple[int, int]] = None,
585+
clip_skip: Optional[int] = None,
580586
):
581587
r"""
582588
Function invoked when calling the pipeline for generation.
@@ -764,6 +770,7 @@ def __call__(
764770
pooled_prompt_embeds=pooled_prompt_embeds,
765771
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
766772
lora_scale=text_encoder_lora_scale,
773+
clip_skip=clip_skip,
767774
)
768775

769776
# 4. Prepare timesteps

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ def encode_prompt(
219219
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
220220
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
221221
lora_scale: Optional[float] = None,
222+
clip_skip: Optional[int] = None,
222223
):
223224
r"""
224225
Encodes the prompt into text encoder hidden states.
@@ -258,6 +259,9 @@ def encode_prompt(
258259
input argument.
259260
lora_scale (`float`, *optional*):
260261
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
262+
clip_skip (`int`, *optional*):
263+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
264+
the output of the pre-final layer will be used for computing the prompt embeddings.
261265
"""
262266
device = device or self._execution_device
263267

@@ -314,14 +318,15 @@ def encode_prompt(
314318
f" {tokenizer.model_max_length} tokens: {removed_text}"
315319
)
316320

317-
prompt_embeds = text_encoder(
318-
text_input_ids.to(device),
319-
output_hidden_states=True,
320-
)
321+
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
321322

322323
# We are only ALWAYS interested in the pooled output of the final text encoder
323324
pooled_prompt_embeds = prompt_embeds[0]
324-
prompt_embeds = prompt_embeds.hidden_states[-2]
325+
if clip_skip is None:
326+
prompt_embeds = prompt_embeds.hidden_states[-2]
327+
else:
328+
# "2" because SDXL always indexes from the penultimate layer.
329+
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
325330

326331
prompt_embeds_list.append(prompt_embeds)
327332

@@ -688,6 +693,7 @@ def __call__(
688693
negative_target_size: Optional[Tuple[int, int]] = None,
689694
aesthetic_score: float = 6.0,
690695
negative_aesthetic_score: float = 2.5,
696+
clip_skip: Optional[int] = None,
691697
):
692698
r"""
693699
Function invoked when calling the pipeline for generation.
@@ -823,6 +829,9 @@ def __call__(
823829
Part of SDXL's micro-conditioning as explained in section 2.2 of
824830
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
825831
simulate an aesthetic score of the generated image by influencing the negative text condition.
832+
clip_skip (`int`, *optional*):
833+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
834+
the output of the pre-final layer will be used for computing the prompt embeddings.
826835
827836
Examples:
828837
@@ -881,6 +890,7 @@ def __call__(
881890
pooled_prompt_embeds=pooled_prompt_embeds,
882891
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
883892
lora_scale=text_encoder_lora_scale,
893+
clip_skip=clip_skip,
884894
)
885895

886896
# 4. Preprocess image

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ def encode_prompt(
368368
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
369369
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
370370
lora_scale: Optional[float] = None,
371+
clip_skip: Optional[int] = None,
371372
):
372373
r"""
373374
Encodes the prompt into text encoder hidden states.
@@ -407,6 +408,9 @@ def encode_prompt(
407408
input argument.
408409
lora_scale (`float`, *optional*):
409410
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
411+
clip_skip (`int`, *optional*):
412+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
413+
the output of the pre-final layer will be used for computing the prompt embeddings.
410414
"""
411415
device = device or self._execution_device
412416

@@ -463,14 +467,15 @@ def encode_prompt(
463467
f" {tokenizer.model_max_length} tokens: {removed_text}"
464468
)
465469

466-
prompt_embeds = text_encoder(
467-
text_input_ids.to(device),
468-
output_hidden_states=True,
469-
)
470+
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
470471

471472
# We are only ALWAYS interested in the pooled output of the final text encoder
472473
pooled_prompt_embeds = prompt_embeds[0]
473-
prompt_embeds = prompt_embeds.hidden_states[-2]
474+
if clip_skip is None:
475+
prompt_embeds = prompt_embeds.hidden_states[-2]
476+
else:
477+
# "2" because SDXL always indexes from the penultimate layer.
478+
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
474479

475480
prompt_embeds_list.append(prompt_embeds)
476481

@@ -910,6 +915,7 @@ def __call__(
910915
negative_target_size: Optional[Tuple[int, int]] = None,
911916
aesthetic_score: float = 6.0,
912917
negative_aesthetic_score: float = 2.5,
918+
clip_skip: Optional[int] = None,
913919
):
914920
r"""
915921
Function invoked when calling the pipeline for generation.
@@ -1057,6 +1063,9 @@ def __call__(
10571063
Part of SDXL's micro-conditioning as explained in section 2.2 of
10581064
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
10591065
simulate an aesthetic score of the generated image by influencing the negative text condition.
1066+
clip_skip (`int`, *optional*):
1067+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1068+
the output of the pre-final layer will be used for computing the prompt embeddings.
10601069
10611070
Examples:
10621071
@@ -1120,6 +1129,7 @@ def __call__(
11201129
pooled_prompt_embeds=pooled_prompt_embeds,
11211130
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
11221131
lora_scale=text_encoder_lora_scale,
1132+
clip_skip=clip_skip,
11231133
)
11241134

11251135
# 4. set timesteps

0 commit comments

Comments (0)