[LoRA] parse metadata from LoRA and save metadata #11324
Conversation
I remember there were debates in the past, but I don't remember what the arguments for and against were. Could you quickly sketch why this is becoming increasingly important?

Implementation-wise, it generally LGTM. I was a bit confused about the `lora_adapter_config` key vs the `lora_metadata` key. What is the difference?
Folks are using different ranks and alphas and finding it to be very effective in practice.

I will rename it to
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for reviving this conversation again @sayakpaul!

I see some TODOs marked, so I will take a better look once you let us know the PR is ready, but the changes look good to me and seem sensible from the user end. Maybe it might make sense to allow the user to also specify the metadata key, or for us to automatically detect some common names (are there any that you know of?).
@BenjaminBossan To provide some more context, my discussion with Sayak involved training "Control" LoRAs.
Expand for my and Sayak's conversation
Hello curious 🐱 In short:
- Expand the input projection layer's (patch embedding) input channels. Say we want to train with Canny conditioning on a 64-input channel model (Flux, for example), we'd expand the input projection linear to 128 channels with zeroed weights.
- For the input latent to match the expanded linear proj, you concatenate the latents with the latents of the control condition.
- In the "normal" latent stream (base layer), since we expanded with zeroed out weights, the control condition is effectively not added.
- In the "lora" latent stream, the noisy latents and control latents are projected to the expected inner dim of the model. The lora latent stream outputs are added to the normal latent stream, effectively adding some amount of conditioning information from the new Canny condition
- Now, in Flux, the linear projection goes from `64 channels -> 4096 channels`. With the new conditioning-related expansion, it is `128 channels -> 4096 channels`. This is the "normal" latent stream.
- In the "lora" latent stream, you have `128 channels -> "rank" channels -> 4096 channels`. If the rank for the input projection layer is less than 128 channels, you might lose some information from the control conditioning (since the lora layer then acts as a 1-layer autoencoder).
- I've noticed from a limited set of experiments that training with a low input projection rank (it does not matter what the rank for the rest of the network is) almost always produces worse results and contains random noise in the generations.
- Due to this, in finetrainers, I made it so that the input projection layer has the same rank as `inner_dim`. So, effectively, in the "lora" latent stream, you have `128 channels -> 4096 channels -> 4096 channels`. This results in much faster convergence (a conclusion drawn from a very limited set of experiments, and it intuitively made sense to me if you compare this low-rank version against a full-rank Control model).
Here are the relevant pieces of code for a quick look:
- Expanding input channels: this
- Creating the LoRA adapter with custom rank for input projection layer vs rest of the network: this
- Channel-wise input concatenation with control latents: this
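To illustrate the custom-rank setup described in the list above (a much higher rank/alpha for the input projection than for the rest of the network), here is a minimal, hypothetical sketch using PEFT's `rank_pattern`/`alpha_pattern` options. The module names and values are assumptions, not the actual finetrainers configuration:

```python
from peft import LoraConfig

# Hypothetical sketch: most layers use rank 64, while the input projection
# ("x_embedder" is an assumed module name) gets a much higher rank/alpha so
# the control conditioning is not bottlenecked.
lora_config = LoraConfig(
    r=64,
    lora_alpha=64,
    target_modules=["to_q", "to_k", "to_v", "to_out.0", "x_embedder"],
    rank_pattern={"x_embedder": 4096},
    alpha_pattern={"x_embedder": 4096},
)
```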
Results (sample images omitted): "Convert to impasto style painting" and "Canny to Image".
These results come with just 2500 training steps on a `rank=64` LoRA (except `rank=4096` for the input projection). For this control LoRA training setup, every `lora_alpha` is the same as `rank`. So, the `lora_alpha` for the input projection is `4096`, while for the remaining layers it is `64`. The size of the LoRA weights is between 300-1000 MB in most cases, which is much less than a full control conditioning model.
The problem comes when trying to load such a LoRA with diffusers. Since we only know the ranks, but not the `lora_alpha`, the diffusers config uses the same `lora_alpha` as the inferred `rank` (which is `4096`, since the input projection layer is also the first LoRA layer in the state dict). As you can imagine, setting an alpha of 4096 on all layers (even the ones that originally had `rank = lora_alpha = 64`) will result in random noise.

This is a more general problem, because a lot of LoRAs on CivitAI and from different trainers are trained with alpha configurations that differ from the rank. So making the assumption that `lora_alpha=rank` (which is our current default behaviour) is incorrect, and having this metadata information will be really helpful.
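To make the failure mode concrete, here is a small sketch of the scaling arithmetic involved; PEFT scales each LoRA update by `lora_alpha / rank`, so a wrongly inferred alpha directly rescales the learned update:

```python
# Layers trained with rank = lora_alpha = 64 expect a scaling of 1.0.
trained_rank, trained_alpha = 64, 64
inferred_alpha = 4096  # taken from the first (input projection) LoRA layer

correct_scale = trained_alpha / trained_rank   # 1.0, what training assumed
applied_scale = inferred_alpha / trained_rank  # 64.0, what gets applied at load time

# The LoRA update W + (alpha / r) * B @ A is amplified 64x, which shows up
# as random noise in the generations.
print(correct_scale, applied_scale)
```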
In order to solve the problem of being able to train and run validation in finetrainers, we just serialize the LoRA adapter config directly into the safetensors metadata. See this and this; we use a custom function to load the LoRA weights. Really big kudos to whoever implemented the `save_function` argument allowing custom ones to be provided!
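For the gist of that approach, here is a rough sketch (assumed names, not the exact finetrainers code) of serializing an adapter config into safetensors metadata; safetensors metadata must be a dict of strings, hence the JSON encoding:

```python
import json

import safetensors.torch


def save_lora_with_metadata(lora_state_dict, adapter_config_dict, filename):
    # safetensors metadata values must be strings, so JSON-encode the config.
    metadata = {
        "format": "pt",
        "lora_adapter_metadata": json.dumps(adapter_config_dict, indent=2, sort_keys=True),
    }
    safetensors.torch.save_file(lora_state_dict, filename, metadata=metadata)
```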
Well, I wanted to check if the changes help solve the issue.

Regarding detecting metadata in non-diffusers checkpoints, we already either infer alphas or directly scale the LoRA weights. So, that should already be covered.
Currently, finetrainers exports LoRAs with a metadata attribute
SGTM, no rush. Thanks for helping.
Thanks for explaining this further @a-r-r-o-w.

There could be better heuristics, like using the most common rank, but I imagine changing that now would be backwards incompatible.

Any idea how other libraries/apps deal with this lack of info? I thought there was a bit of a standard to multiply the scale into the weights before saving, so that the assumption that
@BenjaminBossan many trainers just embed the alpha values in the state dict instead of maintaining separate configuration metadata. See here as an example:
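For context, a hedged sketch of what reading such embedded alphas can look like; kohya-style checkpoints typically store alpha as a scalar tensor under keys ending in `.alpha`, though exact key naming varies between trainers:

```python
import safetensors.torch

# Hypothetical file name; any kohya-style LoRA checkpoint would do.
state_dict = safetensors.torch.load_file("kohya_style_lora.safetensors")

# Collect the per-module alphas embedded directly in the state dict.
alphas = {
    key.removesuffix(".alpha"): float(value)
    for key, value in state_dict.items()
    if key.endswith(".alpha")
}
```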
If there are no objections, I will button up the PR and make it generally available to the other loaders. But LMK if you have anything to share or any concerns over this direction.
Personally, I don't have any issues, just wondering if anything needs to be discussed (again) with the safetensors folks.
Well, previously there was no problem with them IIRC. It was about time and complexity. But now we have enough evidence to ship it while maintaining the single-file format for LoRA checkpoints.
@sayakpaul I've verified it to work, thanks for working on this. My understanding was that the LoRA initialization config would also be automatically saved into the exported safetensors as metadata, but it needs to be passed manually to

I think it might be good if we serialize this info automatically in `save_lora_weights`, if you think there's an easy way to do it.
Thanks for trying! Well, we're serializing the entire

`diffusers/src/diffusers/loaders/lora_pipeline.py`, line 5308 (commit d390d4d)

The test below also confirms that:

`diffusers/tests/lora/test_lora_layers_wan.py`, line 143 (commit d390d4d)

If you could provide a minimal reproducer for the issue, I would be happy to look into it on priority.
Yes, what you mentioned works as expected: we need to pass
Ah I see. I might have an idea. Will update soon. Thanks again for trying it out quickly.
Why can't we add alpha to the state dict like the kohya trainer does?
The code path for
So because "it's harder" we will end up with a solution that requires other tools like ComfyUI to implement special support to read these attributes and set them? I expected we would be adding compatibility points, not more tech debt. Please reconsider just adding alpha to the state dict, or leave the work to the community if it is too hard. This kind of thing really takes an eternity to be implemented by tools like Swarm, Forge, and ComfyUI, while the alpha attribute in the dict key would Just Work even with AUTOMATIC1111. Please tell me you see the value in this.
Also, it wasn't arrow who brought this up first, it was me, because simpletuner has the option to set rank differently from alpha, which peft has always supported but diffusers does not. Always setting alpha to 1 would allow learning rates to stay the same across every rank. The initial request was just to write the alpha into the state dict key. This proposal is heavy-handed and requires too much effort from everybody.
* add lora_alpha and lora_dropout
* Apply style fixes
* add lora_alpha and lora_dropout
* Apply style fixes
* revert lora_alpha until #11324 is merged
* Apply style fixes
* empty commit

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
```py
        save_function (`Callable`):
            The function to use to save the state dictionary. Useful during distributed training when you need to
            replace `torch.save` with another method. Can be configured with the environment variable
            `DIFFUSERS_SAVE_MODE`.
```
Because it's not used.
```diff
-    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+    # Inject framework format.
+    metadata = {"format": "pt"}
+    if lora_adapter_metadata is not None:
+        for key, value in lora_adapter_metadata.items():
+            if isinstance(value, set):
+                lora_adapter_metadata[key] = list(value)
+        metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
```
Automatically serialize the metadata when available. Cc: @a-r-r-o-w
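For reference, a small sketch (the file name is just an example) of how metadata serialized this way can be read back with safetensors' standard metadata API:

```python
import json

from safetensors import safe_open

with safe_open("pytorch_lora_weights.safetensors", framework="pt", device="cpu") as f:
    raw_metadata = f.metadata() or {}

lora_adapter_metadata = None
if "lora_adapter_metadata" in raw_metadata:
    lora_adapter_metadata = json.loads(raw_metadata["lora_adapter_metadata"])
```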
@DN6 @a-r-r-o-w this is ready for a review.

Cc @linoytsaban once this is merged (I think it's close), let's apply the

@sayakpaul sounds good!

@a-r-r-o-w @DN6 a gentle ping on this one.
```diff
@@ -1376,13 +1425,23 @@ def save_lora_weights(
         if text_encoder_2_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
 
+        if transformer_lora_adapter_metadata is not None:
+            lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
```
Please correct me if I'm wrong but this looks to me like we're saving the state dict in the metadata?
The metadata should only contain the peft initialization config, no?
No, we're not saving the state dict in the metadata.

The `state_dict` is separate: `state_dict=state_dict,`

The metadata is separate: `lora_adapter_metadata=lora_adapter_metadata,`
Oh okay, my mistake was in thinking `cls.pack_weights` returns the state dict. It only does so when `layers` is a torch module.
```py
@staticmethod
def pack_weights(layers, prefix):
    layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
    layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
    return layers_state_dict
```
It's a bit confusing that a function whose name implies it's for packing weights into a specific format is being used here (because unless you look at the internal implementation, it looks like a dict of str/int/float/list is being passed for weight packing somehow). Anyway, it's an internal implementation detail, so it can be refactored in the future so that the naming makes sense without having to look at the actual implementation.
Makes sense. I will do something like `_pack_with_prefix` and reuse it.
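Purely as an illustration of that idea (the name comes from the comment above; the eventual implementation may differ), such a helper could look like:

```python
import torch


def _pack_with_prefix(layers, prefix):
    # Accept either a torch module (use its state dict) or a plain dict
    # (e.g. adapter metadata), and prefix every key.
    layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
    return {f"{prefix}.{key}": value for key, value in layers_weights.items()}
```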
```diff
@@ -646,6 +660,9 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
+        from .lora_base import LORA_ADAPTER_METADATA_KEY
+
+        print(f"{LORA_ADAPTER_METADATA_KEY in state_dict=} before UNet")
```
nit: prints to be removed
```py
if prefix is not None:
    state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
if metadata is not None:
    state_dict[LORA_ADAPTER_METADATA_KEY] = metadata
```
Having a little trouble understanding this. I was under the impression that the metadata would be stored with the safetensors metadata feature and not the state dict. If it's stored in the state dict, that would be equivalent to creating a new format, which I don't think we want to do here
When serializing the state dict, we are injecting the metadata into the state dict. Its keys or values are not changing. So, we're not changing the format in any way.
```py
metadata = None
if LORA_ADAPTER_METADATA_KEY in state_dict:
    metadata = state_dict[LORA_ADAPTER_METADATA_KEY]
if prefix is not None:
    state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
if metadata is not None:
    state_dict[LORA_ADAPTER_METADATA_KEY] = metadata
```
Assuming we still need this code in light of the previous comment, I feel like this is a no-op. We are checking if the adapter key is part of the state dict and, if so, we set the `metadata` variable to the value from the state dict. Then, we check if `metadata` is not None and assign it back into the state dict.
metadata["lora_adapter_metadata"] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True) | ||
|
||
return safetensors.torch.save_file(weights, filename, metadata=metadata) |
In light of the previous comment, I see that the metadata is being added to the state dict as well as to the safetensors metadata. A little confused as to what the intention is with adding the metadata to the state dict.
```diff
-        lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
+        lora_config_kwargs = get_peft_kwargs(
+            rank,
+            network_alpha_dict=network_alphas,
+            peft_state_dict=state_dict,
+            prefix=prefix,
+        )
```
Trying to understand better why `prefix` was not needed before but is needed now. Could you explain?
```py
try:
    lora_config = LoraConfig(**lora_config_kwargs)
except TypeError as e:
    raise TypeError(f"`LoraConfig` class could not be instantiated:\n{e}.")
```
Same here, trying to wrap my mind around why this was not needed before but needs to be under try-except now, given that we are not changing anything related to LoRA loading in this PR.
This is a precaution in case the loaded metadata gets corrupted or something. Removing it isn't a big deal.
```py
def get_peft_kwargs(
    rank_dict,
    network_alpha_dict,
    peft_state_dict,
    is_unet=True,
    prefix=None,
):
    from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY

    if LORA_ADAPTER_METADATA_KEY in peft_state_dict:
        metadata = peft_state_dict[LORA_ADAPTER_METADATA_KEY]
    else:
        metadata = None
    if metadata:
        if prefix is not None:
            metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}
        return metadata
```
Okay, now I understand why `prefix` is passed. However, this looks to me like it is because the state dict is now also in the metadata. As I understand it, the metadata should only contain the peft initialization config, so we don't really need to remove any prefixes, because that's a model-weights related thing, right?
There's no state dict in the metadata.
```py
metadata_key = LORA_ADAPTER_METADATA_KEY
with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
    if hasattr(f, "metadata"):
```
This is a safetensors feature, no? So why do we have to check with `hasattr` whether the method exists?
Indeed. Will remove.
```py
metadata_keys = list(metadata.keys())
if not (len(metadata_keys) == 1 and metadata_keys[0] == "format"):
    peft_metadata = {k: v for k, v in metadata.items() if k != "format"}
    state_dict["lora_adapter_metadata"] = json.loads(peft_metadata[metadata_key])
```
I don't think we should mix the state dict and peft init metadata here. We should limit the state dict to being just weights, because otherwise some parts of the code feel like they magically extract some variable.

Instead, we should probably return the state dict and metadata as a tuple, and propagate them to other methods as required.
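As a sketch of that suggestion (hypothetical function name, not the final API), the loading path could return the weights and the metadata separately:

```python
import json

import safetensors.torch
from safetensors import safe_open


def load_lora_state_dict_and_metadata(model_file, metadata_key="lora_adapter_metadata"):
    # Keep weights and PEFT init metadata separate instead of mixing them into one dict.
    state_dict = safetensors.torch.load_file(model_file, device="cpu")
    with safe_open(model_file, framework="pt", device="cpu") as f:
        raw = f.metadata() or {}
    metadata = json.loads(raw[metadata_key]) if metadata_key in raw else None
    return state_dict, metadata
```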
Agreed. I think this caused most of the confusion and the questions.

@a-r-r-o-w we're not serializing the state dict in the metadata or anything like that. We inject the metadata into the safetensors file that holds the state dict, so there's that separation. After loading the state dict and processing it, we make the metadata a part of the state dict (standard keys and values) to ease the rest of the process. I think this will answer most of your doubts.

I wanted to avoid that as I felt it might introduce breaking changes, but this makes sense as it will make things less confusing. I will work on the changes accordingly.
What does this PR do?
I know we have revisited this over and over again but this is becoming increasingly important. So, we should consider this on priority.
@a-r-r-o-w brought this issue to me while we were debugging something in Wan LoRA training. So, I started by just modifying the Wan LoRA loader (eventually, the changes will be propagated to other loaders too). Aryan, could you check if this change fixes the issue we were facing?
Admittedly, the PR can be cleaned and modularized a bit but I wanted to get something up quickly to get feedback on the direction.
TODOs