3030from ..utils import (
3131 CONFIG_NAME ,
3232 DIFFUSERS_CACHE ,
33+ FLAX_WEIGHTS_NAME ,
3334 HF_HUB_OFFLINE ,
3435 HUGGINGFACE_CO_RESOLVE_ENDPOINT ,
3536 SAFETENSORS_WEIGHTS_NAME ,
@@ -335,6 +336,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
335336 The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
336337 git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
337338 identifier allowed by git.
339+ from_flax (`bool`, *optional*, defaults to `False`):
340+ Load the model weights from a Flax checkpoint save file.
338341 subfolder (`str`, *optional*, defaults to `""`):
339342 In case the relevant files are located inside a subfolder of the model repo (either remote in
340343 huggingface.co or downloaded locally), you can specify the folder name here.
@@ -375,6 +378,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
375378 cache_dir = kwargs .pop ("cache_dir" , DIFFUSERS_CACHE )
376379 ignore_mismatched_sizes = kwargs .pop ("ignore_mismatched_sizes" , False )
377380 force_download = kwargs .pop ("force_download" , False )
381+ from_flax = kwargs .pop ("from_flax" , False )
378382 resume_download = kwargs .pop ("resume_download" , False )
379383 proxies = kwargs .pop ("proxies" , None )
380384 output_loading_info = kwargs .pop ("output_loading_info" , False )
@@ -433,27 +437,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
433437 # Load model
434438
435439 model_file = None
436- if is_safetensors_available ():
437- try :
438- model_file = cls ._get_model_file (
439- pretrained_model_name_or_path ,
440- weights_name = SAFETENSORS_WEIGHTS_NAME ,
441- cache_dir = cache_dir ,
442- force_download = force_download ,
443- resume_download = resume_download ,
444- proxies = proxies ,
445- local_files_only = local_files_only ,
446- use_auth_token = use_auth_token ,
447- revision = revision ,
448- subfolder = subfolder ,
449- user_agent = user_agent ,
450- )
451- except :
452- pass
453- if model_file is None :
440+ if from_flax :
454441 model_file = cls ._get_model_file (
455442 pretrained_model_name_or_path ,
456- weights_name = WEIGHTS_NAME ,
443+ weights_name = FLAX_WEIGHTS_NAME ,
457444 cache_dir = cache_dir ,
458445 force_download = force_download ,
459446 resume_download = resume_download ,
@@ -464,10 +451,105 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
464451 subfolder = subfolder ,
465452 user_agent = user_agent ,
466453 )
454+ config , unused_kwargs = cls .load_config (
455+ config_path ,
456+ cache_dir = cache_dir ,
457+ return_unused_kwargs = True ,
458+ force_download = force_download ,
459+ resume_download = resume_download ,
460+ proxies = proxies ,
461+ local_files_only = local_files_only ,
462+ use_auth_token = use_auth_token ,
463+ revision = revision ,
464+ subfolder = subfolder ,
465+ device_map = device_map ,
466+ ** kwargs ,
467+ )
468+ model = cls .from_config (config , ** unused_kwargs )
469+
470+ # Convert the weights
471+ from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
467472
468- if low_cpu_mem_usage :
469- # Instantiate model with empty weights
470- with accelerate .init_empty_weights ():
473+ model = load_flax_checkpoint_in_pytorch_model (model , model_file )
474+ else :
475+ if is_safetensors_available ():
476+ try :
477+ model_file = cls ._get_model_file (
478+ pretrained_model_name_or_path ,
479+ weights_name = SAFETENSORS_WEIGHTS_NAME ,
480+ cache_dir = cache_dir ,
481+ force_download = force_download ,
482+ resume_download = resume_download ,
483+ proxies = proxies ,
484+ local_files_only = local_files_only ,
485+ use_auth_token = use_auth_token ,
486+ revision = revision ,
487+ subfolder = subfolder ,
488+ user_agent = user_agent ,
489+ )
490+ except :
491+ pass
492+ if model_file is None :
493+ model_file = cls ._get_model_file (
494+ pretrained_model_name_or_path ,
495+ weights_name = WEIGHTS_NAME ,
496+ cache_dir = cache_dir ,
497+ force_download = force_download ,
498+ resume_download = resume_download ,
499+ proxies = proxies ,
500+ local_files_only = local_files_only ,
501+ use_auth_token = use_auth_token ,
502+ revision = revision ,
503+ subfolder = subfolder ,
504+ user_agent = user_agent ,
505+ )
506+
507+ if low_cpu_mem_usage :
508+ # Instantiate model with empty weights
509+ with accelerate .init_empty_weights ():
510+ config , unused_kwargs = cls .load_config (
511+ config_path ,
512+ cache_dir = cache_dir ,
513+ return_unused_kwargs = True ,
514+ force_download = force_download ,
515+ resume_download = resume_download ,
516+ proxies = proxies ,
517+ local_files_only = local_files_only ,
518+ use_auth_token = use_auth_token ,
519+ revision = revision ,
520+ subfolder = subfolder ,
521+ device_map = device_map ,
522+ ** kwargs ,
523+ )
524+ model = cls .from_config (config , ** unused_kwargs )
525+
526+ # if device_map is None, load the state dict and move the params from meta device to the cpu
527+ if device_map is None :
528+ param_device = "cpu"
529+ state_dict = load_state_dict (model_file )
530+ # move the params from meta device to cpu
531+ for param_name , param in state_dict .items ():
532+ accepts_dtype = "dtype" in set (
533+ inspect .signature (set_module_tensor_to_device ).parameters .keys ()
534+ )
535+ if accepts_dtype :
536+ set_module_tensor_to_device (
537+ model , param_name , param_device , value = param , dtype = torch_dtype
538+ )
539+ else :
540+ set_module_tensor_to_device (model , param_name , param_device , value = param )
541+ else : # else let accelerate handle loading and dispatching.
542+ # Load weights and dispatch according to the device_map
543+ # by deafult the device_map is None and the weights are loaded on the CPU
544+ accelerate .load_checkpoint_and_dispatch (model , model_file , device_map , dtype = torch_dtype )
545+
546+ loading_info = {
547+ "missing_keys" : [],
548+ "unexpected_keys" : [],
549+ "mismatched_keys" : [],
550+ "error_msgs" : [],
551+ }
552+ else :
471553 config , unused_kwargs = cls .load_config (
472554 config_path ,
473555 cache_dir = cache_dir ,
@@ -484,61 +566,22 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
484566 )
485567 model = cls .from_config (config , ** unused_kwargs )
486568
487- # if device_map is Non,e load the state dict on move the params from meta device to the cpu
488- if device_map is None :
489- param_device = "cpu"
490569 state_dict = load_state_dict (model_file )
491- # move the parms from meta device to cpu
492- for param_name , param in state_dict .items ():
493- accepts_dtype = "dtype" in set (inspect .signature (set_module_tensor_to_device ).parameters .keys ())
494- if accepts_dtype :
495- set_module_tensor_to_device (model , param_name , param_device , value = param , dtype = torch_dtype )
496- else :
497- set_module_tensor_to_device (model , param_name , param_device , value = param )
498- else : # else let accelerate handle loading and dispatching.
499- # Load weights and dispatch according to the device_map
500- # by deafult the device_map is None and the weights are loaded on the CPU
501- accelerate .load_checkpoint_and_dispatch (model , model_file , device_map , dtype = torch_dtype )
502-
503- loading_info = {
504- "missing_keys" : [],
505- "unexpected_keys" : [],
506- "mismatched_keys" : [],
507- "error_msgs" : [],
508- }
509- else :
510- config , unused_kwargs = cls .load_config (
511- config_path ,
512- cache_dir = cache_dir ,
513- return_unused_kwargs = True ,
514- force_download = force_download ,
515- resume_download = resume_download ,
516- proxies = proxies ,
517- local_files_only = local_files_only ,
518- use_auth_token = use_auth_token ,
519- revision = revision ,
520- subfolder = subfolder ,
521- device_map = device_map ,
522- ** kwargs ,
523- )
524- model = cls .from_config (config , ** unused_kwargs )
525570
526- state_dict = load_state_dict (model_file )
527-
528- model , missing_keys , unexpected_keys , mismatched_keys , error_msgs = cls ._load_pretrained_model (
529- model ,
530- state_dict ,
531- model_file ,
532- pretrained_model_name_or_path ,
533- ignore_mismatched_sizes = ignore_mismatched_sizes ,
534- )
571+ model , missing_keys , unexpected_keys , mismatched_keys , error_msgs = cls ._load_pretrained_model (
572+ model ,
573+ state_dict ,
574+ model_file ,
575+ pretrained_model_name_or_path ,
576+ ignore_mismatched_sizes = ignore_mismatched_sizes ,
577+ )
535578
536- loading_info = {
537- "missing_keys" : missing_keys ,
538- "unexpected_keys" : unexpected_keys ,
539- "mismatched_keys" : mismatched_keys ,
540- "error_msgs" : error_msgs ,
541- }
579+ loading_info = {
580+ "missing_keys" : missing_keys ,
581+ "unexpected_keys" : unexpected_keys ,
582+ "mismatched_keys" : mismatched_keys ,
583+ "error_msgs" : error_msgs ,
584+ }
542585
543586 if torch_dtype is not None and not isinstance (torch_dtype , torch .dtype ):
544587 raise ValueError (
0 commit comments