Skip to content

Commit 0de35e4

Browse files
[Tests] Tighten up LoRA loading relaxation (huggingface#4787)
* debugging * better logic for filtering. * Update src/diffusers/loaders.py Co-authored-by: Patrick von Platen <[email protected]> --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent 0d81e54 commit 0de35e4

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

src/diffusers/loaders.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,11 +1102,17 @@ def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_ext
11021102
else:
11031103
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
11041104
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
1105-
11061105
if len(targeted_files) == 0:
11071106
return
11081107

1109-
targeted_files = list(filter(lambda x: "scheduler" not in x and "optimizer" not in x, targeted_files))
1108+
# "scheduler" does not correspond to a LoRA checkpoint.
1109+
# "optimizer" does not correspond to a LoRA checkpoint
1110+
# only top-level checkpoints are considered and not the other ones, hence "checkpoint".
1111+
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
1112+
targeted_files = list(
1113+
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
1114+
)
1115+
11101116
if len(targeted_files) > 1:
11111117
raise ValueError(
11121118
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."

0 commit comments

Comments
 (0)