
Commit cc92332

[PEFT / LoRA ] Fix text encoder scaling (huggingface#5204)
* move text encoder changes
* fix
* add comment.
* fix tests
* Update src/diffusers/utils/peft_utils.py

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 9cfd4ef commit cc92332

38 files changed: +385 -125 lines
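The change is the same in every touched pipeline: `encode_prompt` no longer passes `use_peft_backend` into `adjust_lora_scale_text_encoder`; instead it branches itself, scaling the text encoder's LoRA layers with `scale_lora_layers` before encoding and restoring them with `unscale_lora_layers` afterwards when the PEFT backend is active. A condensed sketch of that pattern follows; the helper name `encode_with_lora_scale` and the bare call into the text encoder are illustrative simplifications, not code from the commit.

    # Condensed sketch of the encode_prompt pattern this commit applies across
    # the pipelines below; the real methods also handle tokenization, devices,
    # negative prompts, CLIP skip, etc.
    from diffusers.loaders import LoraLoaderMixin
    from diffusers.models.lora import adjust_lora_scale_text_encoder
    from diffusers.utils import scale_lora_layers, unscale_lora_layers


    def encode_with_lora_scale(pipe, input_ids, lora_scale=None):
        if lora_scale is not None and isinstance(pipe, LoraLoaderMixin):
            # dynamically adjust the LoRA scale
            if not pipe.use_peft_backend:
                adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
            else:
                scale_lora_layers(pipe.text_encoder, lora_scale)

        prompt_embeds = pipe.text_encoder(input_ids)[0]  # simplified

        if isinstance(pipe, LoraLoaderMixin) and pipe.use_peft_backend:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(pipe.text_encoder)

        return prompt_embeds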

src/diffusers/models/lora.py

Lines changed: 13 additions & 16 deletions
@@ -19,27 +19,24 @@
 from torch import nn

 from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
-from ..utils import logging, scale_lora_layers
+from ..utils import logging


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False):
-    if use_peft_backend:
-        scale_lora_layers(text_encoder, weight=lora_scale)
-    else:
-        for _, attn_module in text_encoder_attn_modules(text_encoder):
-            if isinstance(attn_module.q_proj, PatchedLoraProjection):
-                attn_module.q_proj.lora_scale = lora_scale
-                attn_module.k_proj.lora_scale = lora_scale
-                attn_module.v_proj.lora_scale = lora_scale
-                attn_module.out_proj.lora_scale = lora_scale
-
-        for _, mlp_module in text_encoder_mlp_modules(text_encoder):
-            if isinstance(mlp_module.fc1, PatchedLoraProjection):
-                mlp_module.fc1.lora_scale = lora_scale
-                mlp_module.fc2.lora_scale = lora_scale
+def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
+    for _, attn_module in text_encoder_attn_modules(text_encoder):
+        if isinstance(attn_module.q_proj, PatchedLoraProjection):
+            attn_module.q_proj.lora_scale = lora_scale
+            attn_module.k_proj.lora_scale = lora_scale
+            attn_module.v_proj.lora_scale = lora_scale
+            attn_module.out_proj.lora_scale = lora_scale
+
+    for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+        if isinstance(mlp_module.fc1, PatchedLoraProjection):
+            mlp_module.fc1.lora_scale = lora_scale
+            mlp_module.fc2.lora_scale = lora_scale


 class LoRALinearLayer(nn.Module):
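Note that the `scale_lora_layers` import is dropped here because the PEFT branch moves out of `adjust_lora_scale_text_encoder` and into the pipelines below. The commit message also mentions an update to `src/diffusers/utils/peft_utils.py`, which is not shown in this excerpt; as a rough, assumed sketch, the two helpers the pipelines now call behave approximately like this (based on PEFT's `BaseTunerLayer` API, so treat the exact bodies as an approximation rather than the actual implementation).

    # Assumed sketch of the helpers used by the PEFT code path; the real
    # implementations live in src/diffusers/utils/peft_utils.py.
    from peft.tuners.tuners_utils import BaseTunerLayer


    def scale_lora_layers(model, weight):
        # Multiply the scaling of every PEFT LoRA layer in `model` by `weight`.
        for module in model.modules():
            if isinstance(module, BaseTunerLayer):
                module.scale_layer(weight)


    def unscale_lora_layers(model):
        # Undo the scaling so the LoRA layers return to their original scale.
        for module in model.modules():
            if isinstance(module, BaseTunerLayer):
                module.unscale_layer()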

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 9 additions & 2 deletions
@@ -25,7 +25,7 @@
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring
+from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -304,7 +304,10 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
+            if not self.use_peft_backend:
+                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            else:
+                scale_lora_layers(self.text_encoder, lora_scale)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -429,6 +432,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

+        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(self.text_encoder)
+
         return prompt_embeds, negative_prompt_embeds

     def run_safety_checker(self, image, device, dtype):
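Caller-facing behavior is unchanged: the LoRA scale still comes in through `cross_attention_kwargs` and reaches `encode_prompt` as `lora_scale`; the new `unscale_lora_layers` call simply returns the text encoder to its original scale after each prompt encoding. A minimal usage sketch, assuming a Stable Diffusion checkpoint and a placeholder LoRA path (the same `encode_prompt` pattern is shared by the pipelines in this commit):

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    pipe.load_lora_weights("path/to/lora")  # placeholder LoRA weights

    # "scale" is forwarded to encode_prompt as lora_scale; with the PEFT backend the
    # text encoder is scaled before encoding and unscaled again right afterwards.
    image = pipe(
        "a photo of an astronaut riding a horse",
        cross_attention_kwargs={"scale": 0.5},
    ).images[0]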

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 16 additions & 2 deletions
@@ -27,7 +27,14 @@
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring
+from ...utils import (
+    PIL_INTERPOLATION,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -302,7 +309,10 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
+            if not self.use_peft_backend:
+                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            else:
+                scale_lora_layers(self.text_encoder, lora_scale)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -427,6 +437,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

+        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(self.text_encoder)
+
         return prompt_embeds, negative_prompt_embeds

     def run_safety_checker(self, image, device, dtype):

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 9 additions & 6 deletions
@@ -27,11 +27,7 @@
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import (
-    deprecate,
-    logging,
-    replace_example_docstring,
-)
+from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
@@ -291,7 +287,10 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
+            if not self.use_peft_backend:
+                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            else:
+                scale_lora_layers(self.text_encoder, lora_scale)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -416,6 +415,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

+        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(self.text_encoder)
+
         return prompt_embeds, negative_prompt_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 10 additions & 1 deletion
@@ -30,6 +30,8 @@
     deprecate,
     logging,
     replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
 )
 from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
@@ -315,7 +317,10 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
+            if not self.use_peft_backend:
+                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            else:
+                scale_lora_layers(self.text_encoder, lora_scale)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -440,6 +445,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

+        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(self.text_encoder)
+
         return prompt_embeds, negative_prompt_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Lines changed: 9 additions & 6 deletions
@@ -28,11 +28,7 @@
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import (
-    deprecate,
-    logging,
-    replace_example_docstring,
-)
+from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion import StableDiffusionPipelineOutput
@@ -442,7 +438,10 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
+            if not self.use_peft_backend:
+                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            else:
+                scale_lora_layers(self.text_encoder, lora_scale)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -567,6 +566,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

+        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(self.text_encoder)
+
         return prompt_embeds, negative_prompt_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py

Lines changed: 13 additions & 2 deletions
@@ -36,6 +36,8 @@
     is_invisible_watermark_available,
     logging,
     replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
 )
 from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
@@ -314,8 +316,12 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
-            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
+            if not self.use_peft_backend:
+                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+                adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+            else:
+                scale_lora_layers(self.text_encoder, lora_scale)
+                scale_lora_layers(self.text_encoder_2, lora_scale)

         prompt = [prompt] if isinstance(prompt, str) else prompt

@@ -452,6 +458,11 @@ def encode_prompt(
                 bs_embed * num_images_per_prompt, -1
             )

+        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(self.text_encoder)
+            unscale_lora_layers(self.text_encoder_2)
+
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

Lines changed: 12 additions & 6 deletions
@@ -35,10 +35,7 @@
 )
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import (
-    logging,
-    replace_example_docstring,
-)
+from ...utils import logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -288,8 +285,12 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
-            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
+            if not self.use_peft_backend:
+                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+                adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+            else:
+                scale_lora_layers(self.text_encoder, lora_scale)
+                scale_lora_layers(self.text_encoder_2, lora_scale)

         prompt = [prompt] if isinstance(prompt, str) else prompt

@@ -426,6 +427,11 @@ def encode_prompt(
                 bs_embed * num_images_per_prompt, -1
             )

+        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(self.text_encoder)
+            unscale_lora_layers(self.text_encoder_2)
+
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

Lines changed: 13 additions & 2 deletions
@@ -38,6 +38,8 @@
 from ...utils import (
     logging,
     replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
 )
 from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
@@ -326,8 +328,12 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
-            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
+            if not self.use_peft_backend:
+                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+                adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+            else:
+                scale_lora_layers(self.text_encoder, lora_scale)
+                scale_lora_layers(self.text_encoder_2, lora_scale)

         prompt = [prompt] if isinstance(prompt, str) else prompt

@@ -464,6 +470,11 @@ def encode_prompt(
                 bs_embed * num_images_per_prompt, -1
             )

+        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(self.text_encoder)
+            unscale_lora_layers(self.text_encoder_2)
+
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs

src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py

Lines changed: 9 additions & 2 deletions
@@ -27,7 +27,7 @@
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import DDIMScheduler
-from ...utils import PIL_INTERPOLATION, deprecate, logging
+from ...utils import PIL_INTERPOLATION, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import StableDiffusionPipelineOutput
@@ -308,7 +308,10 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
+            if not self.use_peft_backend:
+                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            else:
+                scale_lora_layers(self.text_encoder, lora_scale)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -433,6 +436,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

+        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(self.text_encoder)
+
         return prompt_embeds, negative_prompt_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
