
[LoRA] Improve warning messages when LoRA loading becomes a no-op #10187


Merged: 36 commits merged into main on Mar 10, 2025.

Changes from 1 commit of 36.

Commits:
cd88a4b  updates (sayakpaul, Dec 11, 2024)
b694ca4  updates (sayakpaul, Dec 11, 2024)
1db7503  updates (sayakpaul, Dec 11, 2024)
6134491  updates (sayakpaul, Dec 11, 2024)
db827b5  Merge branch 'main' into improve-lora-warning-msg (sayakpaul, Dec 11, 2024)
ac29785  notebooks revert (sayakpaul, Dec 11, 2024)
3f4a3fc  resolve conflicts (sayakpaul, Dec 15, 2024)
b6db978  Merge branch 'main' into improve-lora-warning-msg (sayakpaul, Dec 19, 2024)
c44f7a3  Merge branch 'main' into improve-lora-warning-msg (sayakpaul, Dec 20, 2024)
876132a  fix-copies. (sayakpaul, Dec 20, 2024)
bb7c09a  Merge branch 'main' into improve-lora-warning-msg (sayakpaul, Dec 25, 2024)
e6043a0  seeing (sayakpaul, Dec 25, 2024)
7ca7493  fix (sayakpaul, Dec 25, 2024)
ec44f9a  revert (sayakpaul, Dec 25, 2024)
343b2d2  Merge branch 'main' into improve-lora-warning-msg (sayakpaul, Dec 25, 2024)
615e372  fixes (sayakpaul, Dec 25, 2024)
e2e3ea0  fixes (sayakpaul, Dec 25, 2024)
f9dd64c  fixes (sayakpaul, Dec 25, 2024)
a01cb45  remove print (sayakpaul, Dec 25, 2024)
da96621  fix (sayakpaul, Dec 25, 2024)
a91138d  Merge branch 'main' into improve-lora-warning-msg (sayakpaul, Dec 27, 2024)
83ad82b  Merge branch 'main' into improve-lora-warning-msg (sayakpaul, Jan 2, 2025)
be187da  Merge branch 'main' into improve-lora-warning-msg (sayakpaul, Jan 5, 2025)
726e492  Merge branch 'main' into improve-lora-warning-msg (sayakpaul, Jan 7, 2025)
3efdc58  fix conflicts (sayakpaul, Jan 13, 2025)
cf50148  conflicts ii. (sayakpaul, Jan 13, 2025)
b2afc10  updates (sayakpaul, Jan 13, 2025)
96eced3  fixes (sayakpaul, Jan 13, 2025)
b4be719  Merge branch 'main' into improve-lora-warning-msg (sayakpaul, Jan 13, 2025)
8bf1173  Merge branch 'main' into improve-lora-warning-msg (sayakpaul, Feb 9, 2025)
0e43b55  Merge branch 'main' into improve-lora-warning-msg (hlky, Mar 6, 2025)
1e4dbbc  Merge branch 'main' into improve-lora-warning-msg (sayakpaul, Mar 9, 2025)
9eb460f  better filtering of prefix. (sayakpaul, Mar 10, 2025)
279ee91  Merge branch 'main' into improve-lora-warning-msg (sayakpaul, Mar 10, 2025)
6240876  Merge branch 'main' into improve-lora-warning-msg (sayakpaul, Mar 10, 2025)
cf9027a  Merge branch 'main' into improve-lora-warning-msg (sayakpaul, Mar 10, 2025)
fixes
sayakpaul committed Dec 25, 2024
commit e2e3ea09f98ce1ace7940511b6c23c406e82d356
82 changes: 34 additions & 48 deletions src/diffusers/loaders/lora_pipeline.py
@@ -297,19 +297,15 @@ def load_lora_into_unet(
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
         # their prefixes.
-        keys = list(state_dict.keys())
-        only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
-        if not only_text_encoder:
-            # Load the layers corresponding to UNet.
-            logger.info(f"Loading {cls.unet_name}.")
-            unet.load_lora_adapter(
-                state_dict,
-                prefix=cls.unet_name,
-                network_alphas=network_alphas,
-                adapter_name=adapter_name,
-                _pipeline=_pipeline,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        logger.info(f"Loading {cls.unet_name}.")
+        unet.load_lora_adapter(
+            state_dict,
+            prefix=cls.unet_name,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
Comment on lines +301 to +309
@sayakpaul (Member, Author) commented on Dec 25, 2024:
In case `prefix` is not None and no prefix-matched state-dict keys are found, we log from the `load_lora_adapter()` method.

This way, we cover both `load_lora_weights()`, which is pipeline-level, and `load_lora_adapter()`, which is model-level.
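As a minimal sketch of what this looks like from the pipeline side (the LoRA checkpoint name is hypothetical, and the log line is paraphrased from the peft.py hunk further down):

# A minimal sketch, assuming a hypothetical checkpoint
# "some-user/text-encoder-only-lora" whose LoRA keys all start with "text_encoder.".
from diffusers import StableDiffusionPipeline
from diffusers.utils import logging

logging.set_verbosity_info()  # the no-op message is emitted at INFO level

pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
pipe.load_lora_weights("some-user/text-encoder-only-lora")
# The pipeline still calls unet.load_lora_adapter(..., prefix="unet"); with no
# matching keys it now logs instead of silently doing nothing, e.g.:
#   No LoRA keys associated to UNet2DConditionModel found with the prefix='unet'. ...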


     @classmethod
     def load_lora_into_text_encoder(
@@ -828,19 +824,15 @@ def load_lora_into_unet(
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
         # their prefixes.
-        keys = list(state_dict.keys())
-        only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
-        if not only_text_encoder:
-            # Load the layers corresponding to UNet.
-            logger.info(f"Loading {cls.unet_name}.")
-            unet.load_lora_adapter(
-                state_dict,
-                prefix=cls.unet_name,
-                network_alphas=network_alphas,
-                adapter_name=adapter_name,
-                _pipeline=_pipeline,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        logger.info(f"Loading {cls.unet_name}.")
+        unet.load_lora_adapter(
+            state_dict,
+            prefix=cls.unet_name,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )

     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -1900,17 +1892,14 @@ def load_lora_into_transformer(
         )
 
         # Load the layers corresponding to transformer.
-        keys = list(state_dict.keys())
-        transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
-        if transformer_present:
-            logger.info(f"Loading {cls.transformer_name}.")
-            transformer.load_lora_adapter(
-                state_dict,
-                network_alphas=network_alphas,
-                adapter_name=adapter_name,
-                _pipeline=_pipeline,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )

     @classmethod
     def _load_norm_into_transformer(
@@ -2495,17 +2484,14 @@ def load_lora_into_transformer(
         )
 
         # Load the layers corresponding to transformer.
-        keys = list(state_dict.keys())
-        transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
-        if transformer_present:
-            logger.info(f"Loading {cls.transformer_name}.")
-            transformer.load_lora_adapter(
-                state_dict,
-                network_alphas=network_alphas,
-                adapter_name=adapter_name,
-                _pipeline=_pipeline,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )

     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
10 changes: 6 additions & 4 deletions src/diffusers/loaders/peft.py
@@ -253,10 +253,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
             raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
 
         if prefix is not None:
-            keys = list(state_dict.keys())
-            model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
-            if len(model_keys) > 0:
-                state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}
+            state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}

         if len(state_dict) > 0:
             if adapter_name in getattr(self, "peft_config", {}):
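To see what the rewritten comprehension changes, here is a standalone sketch (the dict values stand in for tensors): when nothing matches the prefix, the result is now an empty dict, whereas the old `if len(model_keys) > 0` guard silently left the full, unfiltered state dict in place.

# Standalone sketch of the prefix filtering above; values stand in for tensors.
state_dict = {
    "unet.down_blocks.0.lora_A.weight": "A",
    "text_encoder.layers.0.lora_B.weight": "B",
}

prefix = "unet"
filtered = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
print(filtered)  # {'down_blocks.0.lora_A.weight': 'A'}

prefix = "transformer"
filtered = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
print(filtered)  # {} -- skips `if len(state_dict) > 0` and reaches the new log below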
@@ -369,6 +366,11 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
                 _pipeline.enable_sequential_cpu_offload()
             # Unsafe code />
 
+        if prefix is not None and not state_dict:
+            logger.info(
+                f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new"
+            )
+

     def save_lora_adapter(
         self,
         save_directory,
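One practical note, sketched below: the new message is emitted via `logger.info()`, and diffusers defaults to WARNING verbosity, so callers need to raise the verbosity to actually see it (usage sketch, not part of this diff).

from diffusers.utils import logging

# diffusers logs at WARNING by default; raise verbosity so the INFO message shows up.
logging.set_verbosity_info()

# A prefix miss in load_lora_adapter() then prints a line of the form:
#   No LoRA keys associated to <ModelClass> found with the prefix='transformer'.
#   Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new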