26
26
from requests import HTTPError
27
27
28
28
from . import __version__
29
+ from .hub_utils import send_telemetry
29
30
from .utils import (
30
31
CONFIG_NAME ,
31
32
DIFFUSERS_CACHE ,
@@ -400,7 +401,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
400
401
model_file = None
401
402
if is_safetensors_available ():
402
403
try :
403
- model_file = _get_model_file (
404
+ model_file = cls . _get_model_file (
404
405
pretrained_model_name_or_path ,
405
406
weights_name = SAFETENSORS_WEIGHTS_NAME ,
406
407
cache_dir = cache_dir ,
@@ -416,7 +417,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
416
417
except :
417
418
pass
418
419
if model_file is None :
419
- model_file = _get_model_file (
420
+ model_file = cls . _get_model_file (
420
421
pretrained_model_name_or_path ,
421
422
weights_name = WEIGHTS_NAME ,
422
423
cache_dir = cache_dir ,
@@ -531,6 +532,100 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
531
532
532
533
return model
533
534
535
+ @classmethod
536
+ def _get_model_file (
537
+ cls ,
538
+ pretrained_model_name_or_path ,
539
+ * ,
540
+ weights_name ,
541
+ subfolder ,
542
+ cache_dir ,
543
+ force_download ,
544
+ proxies ,
545
+ resume_download ,
546
+ local_files_only ,
547
+ use_auth_token ,
548
+ user_agent ,
549
+ revision ,
550
+ ):
551
+ pretrained_model_name_or_path = str (pretrained_model_name_or_path )
552
+ if os .path .isdir (pretrained_model_name_or_path ):
553
+ if os .path .isfile (os .path .join (pretrained_model_name_or_path , weights_name )):
554
+ # Load from a PyTorch checkpoint
555
+ model_file = os .path .join (pretrained_model_name_or_path , weights_name )
556
+ elif subfolder is not None and os .path .isfile (
557
+ os .path .join (pretrained_model_name_or_path , subfolder , weights_name )
558
+ ):
559
+ model_file = os .path .join (pretrained_model_name_or_path , subfolder , weights_name )
560
+ else :
561
+ raise EnvironmentError (
562
+ f"Error no file named { weights_name } found in directory { pretrained_model_name_or_path } ."
563
+ )
564
+ send_telemetry (
565
+ {"model_class" : cls .__name__ , "model_path" : "local" , "framework" : "pytorch" },
566
+ name = "diffusers_from_pretrained" ,
567
+ )
568
+ return model_file
569
+ else :
570
+ try :
571
+ # Load from URL or cache if already cached
572
+ model_file = hf_hub_download (
573
+ pretrained_model_name_or_path ,
574
+ filename = weights_name ,
575
+ cache_dir = cache_dir ,
576
+ force_download = force_download ,
577
+ proxies = proxies ,
578
+ resume_download = resume_download ,
579
+ local_files_only = local_files_only ,
580
+ use_auth_token = use_auth_token ,
581
+ user_agent = user_agent ,
582
+ subfolder = subfolder ,
583
+ revision = revision ,
584
+ )
585
+ send_telemetry (
586
+ {"model_class" : cls .__name__ , "model_path" : "hub" , "framework" : "pytorch" },
587
+ name = "diffusers_from_pretrained" ,
588
+ )
589
+ return model_file
590
+
591
+ except RepositoryNotFoundError :
592
+ raise EnvironmentError (
593
+ f"{ pretrained_model_name_or_path } is not a local folder and is not a valid model identifier "
594
+ "listed on 'https://huggingface.co/models'\n If this is a private repository, make sure to pass a "
595
+ "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
596
+ "login`."
597
+ )
598
+ except RevisionNotFoundError :
599
+ raise EnvironmentError (
600
+ f"{ revision } is not a valid git identifier (branch name, tag name or commit id) that exists for "
601
+ "this model name. Check the model page at "
602
+ f"'https://huggingface.co/{ pretrained_model_name_or_path } ' for available revisions."
603
+ )
604
+ except EntryNotFoundError :
605
+ raise EnvironmentError (
606
+ f"{ pretrained_model_name_or_path } does not appear to have a file named { weights_name } ."
607
+ )
608
+ except HTTPError as err :
609
+ raise EnvironmentError (
610
+ "There was a specific connection error when trying to load"
611
+ f" { pretrained_model_name_or_path } :\n { err } "
612
+ )
613
+ except ValueError :
614
+ raise EnvironmentError (
615
+ f"We couldn't connect to '{ HUGGINGFACE_CO_RESOLVE_ENDPOINT } ' to load this model, couldn't find it"
616
+ f" in the cached files and it looks like { pretrained_model_name_or_path } is not the path to a"
617
+ f" directory containing a file named { weights_name } or"
618
+ " \n Checkout your internet connection or see how to run the library in"
619
+ " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
620
+ )
621
+ except EnvironmentError :
622
+ raise EnvironmentError (
623
+ f"Can't load the model for '{ pretrained_model_name_or_path } '. If you were trying to load it from "
624
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
625
+ f"Otherwise, make sure '{ pretrained_model_name_or_path } ' is the correct path to a directory "
626
+ f"containing a file named { weights_name } "
627
+ )
628
+
534
629
@classmethod
535
630
def _load_pretrained_model (
536
631
cls ,
0 commit comments