2525import safetensors
2626import torch
2727import torch .nn .functional as F
28- from huggingface_hub import hf_hub_download
28+ from huggingface_hub import hf_hub_download , model_info
2929from torch import nn
3030
3131from .utils import (
@@ -1021,6 +1021,13 @@ def lora_state_dict(
10211021 weight_name is not None and weight_name .endswith (".safetensors" )
10221022 ):
10231023 try :
1024+ # Here we're relaxing the loading check to enable more Inference API
1025+ # friendliness where sometimes, it's not at all possible to automatically
1026+ # determine `weight_name`.
1027+ if weight_name is None :
1028+ weight_name = cls ._best_guess_weight_name (
1029+ pretrained_model_name_or_path_or_dict , file_extension = ".safetensors"
1030+ )
10241031 model_file = _get_model_file (
10251032 pretrained_model_name_or_path_or_dict ,
10261033 weights_name = weight_name or LORA_WEIGHT_NAME_SAFE ,
@@ -1041,7 +1048,12 @@ def lora_state_dict(
10411048 # try loading non-safetensors weights
10421049 model_file = None
10431050 pass
1051+
10441052 if model_file is None :
1053+ if weight_name is None :
1054+ weight_name = cls ._best_guess_weight_name (
1055+ pretrained_model_name_or_path_or_dict , file_extension = ".bin"
1056+ )
10451057 model_file = _get_model_file (
10461058 pretrained_model_name_or_path_or_dict ,
10471059 weights_name = weight_name or LORA_WEIGHT_NAME ,
@@ -1077,6 +1089,31 @@ def lora_state_dict(
10771089
10781090 return state_dict , network_alphas
10791091
1092+ @classmethod
1093+ def _best_guess_weight_name (cls , pretrained_model_name_or_path_or_dict , file_extension = ".safetensors" ):
1094+ targeted_files = []
1095+
1096+ if os .path .isfile (pretrained_model_name_or_path_or_dict ):
1097+ return
1098+ elif os .path .isdir (pretrained_model_name_or_path_or_dict ):
1099+ targeted_files = [
1100+ f for f in os .listdir (pretrained_model_name_or_path_or_dict ) if f .endswith (file_extension )
1101+ ]
1102+ else :
1103+ files_in_repo = model_info (pretrained_model_name_or_path_or_dict ).siblings
1104+ targeted_files = [f .rfilename for f in files_in_repo if f .rfilename .endswith (file_extension )]
1105+
1106+ if len (targeted_files ) == 0 :
1107+ return
1108+
1109+ targeted_files = list (filter (lambda x : "scheduler" not in x and "optimizer" not in x , targeted_files ))
1110+ if len (targeted_files ) > 1 :
1111+ raise ValueError (
1112+ f"Provided path contains more than one weights file in the { file_extension } format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in { pretrained_model_name_or_path_or_dict } ."
1113+ )
1114+ weight_name = targeted_files [0 ]
1115+ return weight_name
1116+
10801117 @classmethod
10811118 def _map_sgm_blocks_to_diffusers (cls , state_dict , unet_config , delimiter = "_" , block_slice_pos = 5 ):
10821119 is_all_unet = all (k .startswith ("lora_unet" ) for k in state_dict )
0 commit comments