
Commit 5222294

[LoRA] relax lora loading logic (huggingface#4610)
* relax lora loading logic.
* cater to the other cases too.
* fix: variable name
* bring the chaos down.
* check
* deal with checkpointed files.
* Apply suggestions from code review

Co-authored-by: apolinário <[email protected]>

* style

--------

Co-authored-by: apolinário <[email protected]>
1 parent c25c461 · commit 5222294
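To illustrate the user-facing effect of the change: a minimal sketch, assuming a pipeline class that mixes in LoraLoaderMixin and a hypothetical Hub repo "someuser/some-lora" containing exactly one .safetensors file.

import torch
from diffusers import DiffusionPipeline

# Any pipeline that inherits LoraLoaderMixin exposes load_lora_weights().
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# "someuser/some-lora" is a hypothetical repo id used for illustration.
# With this commit, omitting `weight_name` no longer requires the file to be
# named pytorch_lora_weights.safetensors: when no name is given, the loader
# first tries to best-guess the single matching file in the repo.
pipe.load_lora_weights("someuser/some-lora")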

File tree

1 file changed: +38 -1 lines changed


src/diffusers/loaders.py

Lines changed: 38 additions & 1 deletion
@@ -25,7 +25,7 @@
 import safetensors
 import torch
 import torch.nn.functional as F
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, model_info
 from torch import nn
 
 from .utils import (
@@ -1021,6 +1021,13 @@ def lora_state_dict(
             weight_name is not None and weight_name.endswith(".safetensors")
         ):
             try:
+                # Here we're relaxing the loading check to enable more Inference API
+                # friendliness where sometimes, it's not at all possible to automatically
+                # determine `weight_name`.
+                if weight_name is None:
+                    weight_name = cls._best_guess_weight_name(
+                        pretrained_model_name_or_path_or_dict, file_extension=".safetensors"
+                    )
                 model_file = _get_model_file(
                     pretrained_model_name_or_path_or_dict,
                     weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
@@ -1041,7 +1048,12 @@ def lora_state_dict(
                 # try loading non-safetensors weights
                 model_file = None
                 pass
+
         if model_file is None:
+            if weight_name is None:
+                weight_name = cls._best_guess_weight_name(
+                    pretrained_model_name_or_path_or_dict, file_extension=".bin"
+                )
             model_file = _get_model_file(
                 pretrained_model_name_or_path_or_dict,
                 weights_name=weight_name or LORA_WEIGHT_NAME,
@@ -1077,6 +1089,31 @@ def lora_state_dict(
 
         return state_dict, network_alphas
 
+    @classmethod
+    def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors"):
+        targeted_files = []
+
+        if os.path.isfile(pretrained_model_name_or_path_or_dict):
+            return
+        elif os.path.isdir(pretrained_model_name_or_path_or_dict):
+            targeted_files = [
+                f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
+            ]
+        else:
+            files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
+            targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
+
+        if len(targeted_files) == 0:
+            return
+
+        targeted_files = list(filter(lambda x: "scheduler" not in x and "optimizer" not in x, targeted_files))
+        if len(targeted_files) > 1:
+            raise ValueError(
+                f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
+            )
+        weight_name = targeted_files[0]
+        return weight_name
+
     @classmethod
     def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
         is_all_unet = all(k.startswith("lora_unet") for k in state_dict)
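The selection rule the new helper encodes is small enough to restate on its own. A simplified sketch of that rule for the local-directory branch only; the name guess_weight_name and the ordering of the emptiness check are ours, not the commit's.

import os

def guess_weight_name(path, file_extension=".safetensors"):
    # Candidate files with the target extension.
    candidates = [f for f in os.listdir(path) if f.endswith(file_extension)]
    # Skip optimizer/scheduler states that training runs often leave
    # alongside the LoRA weights (the commit's filter).
    candidates = [f for f in candidates if "scheduler" not in f and "optimizer" not in f]
    if not candidates:
        # Nothing to guess; the caller falls back to the default
        # LORA_WEIGHT_NAME / LORA_WEIGHT_NAME_SAFE filenames.
        return None
    if len(candidates) > 1:
        # Refuse to guess among several files rather than silently loading
        # the wrong checkpoint; the caller must pass `weight_name`.
        raise ValueError(f"More than one {file_extension} file in {path}; pass `weight_name` explicitly.")
    return candidates[0]

Refusing ambiguity is what keeps the relaxed check safe: the loader only guesses when the guess is unique.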
