 
 from .models.cross_attention import LoRACrossAttnProcessor
 from .models.modeling_utils import _get_model_file
-from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, logging
+from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, is_safetensors_available, logging
+
+
+if is_safetensors_available():
+    import safetensors

 
 logger = logging.get_logger(__name__)
 
 
 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
+LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
 
 
 class AttnProcsLayers(torch.nn.Module):
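The new import guard keeps safetensors an optional dependency: the module is only imported when it is installed, and loading falls back to the `.bin` path otherwise. A minimal sketch of the same pattern, with `_safetensors_available` as an illustrative stand-in for `diffusers.utils.is_safetensors_available` (not the library's actual helper):

```python
import importlib.util

def _safetensors_available() -> bool:
    # Stand-in for diffusers.utils.is_safetensors_available():
    # True only if the optional safetensors package is importable.
    return importlib.util.find_spec("safetensors") is not None

if _safetensors_available():
    import safetensors  # imported lazily so environments without it still work
```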
@@ -136,28 +141,53 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         subfolder = kwargs.pop("subfolder", None)
-        weight_name = kwargs.pop("weight_name", LORA_WEIGHT_NAME)
+        weight_name = kwargs.pop("weight_name", None)
 
         user_agent = {
             "file_type": "attn_procs_weights",
             "framework": "pytorch",
         }
 
+        model_file = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
-            model_file = _get_model_file(
-                pretrained_model_name_or_path_or_dict,
-                weights_name=weight_name,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                resume_download=resume_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                revision=revision,
-                subfolder=subfolder,
-                user_agent=user_agent,
-            )
-            state_dict = torch.load(model_file, map_location="cpu")
+            if is_safetensors_available():
+                if weight_name is None:
+                    weight_name = LORA_WEIGHT_NAME_SAFE
+                try:
+                    model_file = _get_model_file(
+                        pretrained_model_name_or_path_or_dict,
+                        weights_name=weight_name,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        resume_download=resume_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        use_auth_token=use_auth_token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                    )
+                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                except EnvironmentError:
+                    if weight_name == LORA_WEIGHT_NAME_SAFE:
+                        weight_name = None
+            if model_file is None:
+                if weight_name is None:
+                    weight_name = LORA_WEIGHT_NAME
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path_or_dict,
+                    weights_name=weight_name,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+                state_dict = torch.load(model_file, map_location="cpu")
         else:
             state_dict = pretrained_model_name_or_path_or_dict
 
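Taken together, the new branch means that when `weight_name` is not passed, loading first looks for `pytorch_lora_weights.safetensors` and silently falls back to `pytorch_lora_weights.bin` if that file is missing (the `EnvironmentError` raised by `_get_model_file` is swallowed and the weight name reset). A usage sketch of the resulting behaviour; the LoRA repository id is a placeholder, not a real checkpoint:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# No weight_name: tries pytorch_lora_weights.safetensors first (when safetensors
# is installed), then falls back to pytorch_lora_weights.bin.
pipe.unet.load_attn_procs("some-user/some-lora-repo")

# Or pin the safetensors file explicitly (requires safetensors to be installed).
pipe.unet.load_attn_procs(
    "some-user/some-lora-repo", weight_name="pytorch_lora_weights.safetensors"
)
```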
@@ -195,8 +225,9 @@ def save_attn_procs(
         self,
         save_directory: Union[str, os.PathLike],
         is_main_process: bool = True,
-        weights_name: str = LORA_WEIGHT_NAME,
+        weights_name: str = None,
         save_function: Callable = None,
+        safe_serialization: bool = False,
     ):
         r"""
         Save an attention processor to a directory, so that it can be re-loaded using the
@@ -219,7 +250,13 @@ def save_attn_procs(
             return
 
         if save_function is None:
-            save_function = torch.save
+            if safe_serialization:
+
+                def save_function(weights, filename):
+                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+            else:
+                save_function = torch.save
 
         os.makedirs(save_directory, exist_ok=True)
 
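The `safe_serialization` branch wraps `safetensors.torch.save_file`, tagging the file with `{"format": "pt"}` metadata so downstream tooling knows the tensors came from PyTorch. Shown standalone with a dummy tensor (the key name is purely illustrative):

```python
import torch
import safetensors.torch

# What the safe_serialization=True default save_function boils down to:
weights = {"mid_block.attn.to_k_lora.down.weight": torch.zeros(4, 320)}  # dummy tensor
safetensors.torch.save_file(weights, "pytorch_lora_weights.safetensors", metadata={"format": "pt"})

# ... and the matching read path used by load_attn_procs:
restored = safetensors.torch.load_file("pytorch_lora_weights.safetensors", device="cpu")
```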
@@ -237,6 +274,12 @@ def save_attn_procs(
             if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
                 os.remove(full_filename)
 
+        if weights_name is None:
+            if safe_serialization:
+                weights_name = LORA_WEIGHT_NAME_SAFE
+            else:
+                weights_name = LORA_WEIGHT_NAME
+
         # Save the model
         save_function(state_dict, os.path.join(save_directory, weights_name))
 
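With the filename now resolved from `safe_serialization` when `weights_name` is omitted, saving and reloading round-trips without ever naming a file. A sketch, assuming `unet` is a model exposing these mixin methods (for example a `UNet2DConditionModel` whose attention processors were set to `LoRACrossAttnProcessor`):

```python
# Default filename depends on the serialization format:
unet.save_attn_procs("lora-out")                           # writes lora-out/pytorch_lora_weights.bin
unet.save_attn_procs("lora-out", safe_serialization=True)  # writes lora-out/pytorch_lora_weights.safetensors

# Reloading prefers the .safetensors file when the safetensors package is available.
unet.load_attn_procs("lora-out")
```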