@@ -377,7 +377,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
377
377
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
378
378
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
379
379
setting this argument to `True` will raise an error.
380
-
380
+ return_cached_folder (`bool`, *optional*, defaults to `False`):
381
+ If set to `True`, path to downloaded cached folder will be returned in addition to loaded pipeline.
381
382
kwargs (remaining dictionary of keyword arguments, *optional*):
382
383
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
383
384
specific pipeline class. The overwritten components are then directly passed to the pipelines
@@ -430,33 +431,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
430
431
sess_options = kwargs .pop ("sess_options" , None )
431
432
device_map = kwargs .pop ("device_map" , None )
432
433
low_cpu_mem_usage = kwargs .pop ("low_cpu_mem_usage" , _LOW_CPU_MEM_USAGE_DEFAULT )
433
-
434
- if low_cpu_mem_usage and not is_accelerate_available ():
435
- low_cpu_mem_usage = False
436
- logger .warning (
437
- "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
438
- " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
439
- " `accelerate` for faster and less memory-intense model loading. You can do so with: \n ```\n pip"
440
- " install accelerate\n ```\n ."
441
- )
442
-
443
- if device_map is not None and not is_torch_version (">=" , "1.9.0" ):
444
- raise NotImplementedError (
445
- "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
446
- " `device_map=None`."
447
- )
448
-
449
- if low_cpu_mem_usage is True and not is_torch_version (">=" , "1.9.0" ):
450
- raise NotImplementedError (
451
- "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
452
- " `low_cpu_mem_usage=False`."
453
- )
454
-
455
- if low_cpu_mem_usage is False and device_map is not None :
456
- raise ValueError (
457
- f"You cannot set `low_cpu_mem_usage` to False while using device_map={ device_map } for loading and"
458
- " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
459
- )
434
+ return_cached_folder = kwargs .pop ("return_cached_folder" , False )
460
435
461
436
# 1. Download the checkpoints and configs
462
437
# use snapshot download here to get it working from from_pretrained
@@ -585,6 +560,33 @@ def load_module(name, value):
585
560
f"Keyword arguments { unused_kwargs } are not expected by { pipeline_class .__name__ } and will be ignored."
586
561
)
587
562
563
+ if low_cpu_mem_usage and not is_accelerate_available ():
564
+ low_cpu_mem_usage = False
565
+ logger .warning (
566
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
567
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
568
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n ```\n pip"
569
+ " install accelerate\n ```\n ."
570
+ )
571
+
572
+ if device_map is not None and not is_torch_version (">=" , "1.9.0" ):
573
+ raise NotImplementedError (
574
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
575
+ " `device_map=None`."
576
+ )
577
+
578
+ if low_cpu_mem_usage is True and not is_torch_version (">=" , "1.9.0" ):
579
+ raise NotImplementedError (
580
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
581
+ " `low_cpu_mem_usage=False`."
582
+ )
583
+
584
+ if low_cpu_mem_usage is False and device_map is not None :
585
+ raise ValueError (
586
+ f"You cannot set `low_cpu_mem_usage` to False while using device_map={ device_map } for loading and"
587
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
588
+ )
589
+
588
590
# import it here to avoid circular import
589
591
from diffusers import pipelines
590
592
@@ -704,6 +706,9 @@ def load_module(name, value):
704
706
705
707
# 5. Instantiate the pipeline
706
708
model = pipeline_class (** init_kwargs )
709
+
710
+ if return_cached_folder :
711
+ return model , cached_folder
707
712
return model
708
713
709
714
@staticmethod
0 commit comments