@@ -310,26 +310,31 @@ def from_pretrained(
310310 )
311311
312312 # Load model
313- if os .path .isdir (pretrained_model_name_or_path ):
313+ pretrained_path_with_subfolder = (
314+ pretrained_model_name_or_path
315+ if subfolder is None
316+ else os .path .join (pretrained_model_name_or_path , subfolder )
317+ )
318+ if os .path .isdir (pretrained_path_with_subfolder ):
314319 if from_pt :
315- if not os .path .isfile (os .path .join (pretrained_model_name_or_path , WEIGHTS_NAME )):
320+ if not os .path .isfile (os .path .join (pretrained_path_with_subfolder , WEIGHTS_NAME )):
316321 raise EnvironmentError (
317- f"Error no file named { WEIGHTS_NAME } found in directory { pretrained_model_name_or_path } "
322+ f"Error no file named { WEIGHTS_NAME } found in directory { pretrained_path_with_subfolder } "
318323 )
319- model_file = os .path .join (pretrained_model_name_or_path , WEIGHTS_NAME )
320- elif os .path .isfile (os .path .join (pretrained_model_name_or_path , FLAX_WEIGHTS_NAME )):
324+ model_file = os .path .join (pretrained_path_with_subfolder , WEIGHTS_NAME )
325+ elif os .path .isfile (os .path .join (pretrained_path_with_subfolder , FLAX_WEIGHTS_NAME )):
321326 # Load from a Flax checkpoint
322- model_file = os .path .join (pretrained_model_name_or_path , FLAX_WEIGHTS_NAME )
327+ model_file = os .path .join (pretrained_path_with_subfolder , FLAX_WEIGHTS_NAME )
323328 # Check if pytorch weights exist instead
324- elif os .path .isfile (os .path .join (pretrained_model_name_or_path , WEIGHTS_NAME )):
329+ elif os .path .isfile (os .path .join (pretrained_path_with_subfolder , WEIGHTS_NAME )):
325330 raise EnvironmentError (
326- f"{ WEIGHTS_NAME } file found in directory { pretrained_model_name_or_path } . Please load the model"
331+ f"{ WEIGHTS_NAME } file found in directory { pretrained_path_with_subfolder } . Please load the model"
327332 " using `from_pt=True`."
328333 )
329334 else :
330335 raise EnvironmentError (
331336 f"Error no file named { FLAX_WEIGHTS_NAME } or { WEIGHTS_NAME } found in directory "
332- f"{ pretrained_model_name_or_path } ."
337+ f"{ pretrained_path_with_subfolder } ."
333338 )
334339 else :
335340 try :
0 commit comments