@@ -670,14 +670,98 @@ def is_saveable_module(name, value):
670670 create_pr = create_pr ,
671671 )
672672
673- def to (
674- self ,
675- torch_device : Optional [Union [str , torch .device ]] = None ,
676- torch_dtype : Optional [torch .dtype ] = None ,
677- silence_dtype_warnings : bool = False ,
678- ):
679- if torch_device is None and torch_dtype is None :
680- return self
673+ def to (self , * args , ** kwargs ):
674+ r"""
675+ Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
676+ arguments of `self.to(*args, **kwargs).`
677+
678+ <Tip>
679+
680+ If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise,
681+ the returned pipeline is a copy of self with the desired torch.dtype and torch.device.
682+
683+ </Tip>
684+
685+
686+ Here are the ways to call `to`:
687+
688+ - `to(dtype, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified
689+ [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
690+ - `to(device, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified
691+ [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)
692+ - `to(device=None, dtype=None, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the
693+ specified [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) and
694+ [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
695+
696+ Arguments:
697+ dtype (`torch.dtype`, *optional*):
698+ Returns a pipeline with the specified
699+ [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
700+ device (`torch.Device`, *optional*):
701+ Returns a pipeline with the specified
702+ [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)
703+ silence_dtype_warnings (`str`, *optional*, defaults to `False`):
704+ Whether to omit warnings if the target `dtype` is not compatible with the target `device`.
705+
706+ Returns:
707+ [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`.
708+ """
709+
710+ torch_dtype = kwargs .pop ("torch_dtype" , None )
711+ if torch_dtype is not None :
712+ deprecate ("torch_dtype" , "0.25.0" , "" )
713+ torch_device = kwargs .pop ("torch_device" , None )
714+ if torch_device is not None :
715+ deprecate ("torch_device" , "0.25.0" , "" )
716+
717+ dtype_kwarg = kwargs .pop ("dtype" , None )
718+ device_kwarg = kwargs .pop ("device" , None )
719+ silence_dtype_warnings = kwargs .pop ("silence_dtype_warnings" , False )
720+
721+ if torch_dtype is not None and dtype_kwarg is not None :
722+ raise ValueError (
723+ "You have passed both `torch_dtype` and `dtype` as a keyword argument. Please make sure to only pass `dtype`."
724+ )
725+
726+ dtype = torch_dtype or dtype_kwarg
727+
728+ if torch_device is not None and device_kwarg is not None :
729+ raise ValueError (
730+ "You have passed both `torch_device` and `device` as a keyword argument. Please make sure to only pass `device`."
731+ )
732+
733+ device = torch_device or device_kwarg
734+
735+ dtype_arg = None
736+ device_arg = None
737+ if len (args ) == 1 :
738+ if isinstance (args [0 ], torch .dtype ):
739+ dtype_arg = args [0 ]
740+ else :
741+ device_arg = torch .device (args [0 ]) if args [0 ] is not None else None
742+ elif len (args ) == 2 :
743+ if isinstance (args [0 ], torch .dtype ):
744+ raise ValueError (
745+ "When passing two arguments, make sure the first corresponds to `device` and the second to `dtype`."
746+ )
747+ device_arg = torch .device (args [0 ]) if args [0 ] is not None else None
748+ dtype_arg = args [1 ]
749+ elif len (args ) > 2 :
750+ raise ValueError ("Please make sure to pass at most two arguments (`device` and `dtype`) `.to(...)`" )
751+
752+ if dtype is not None and dtype_arg is not None :
753+ raise ValueError (
754+ "You have passed `dtype` both as an argument and as a keyword argument. Please only pass one of the two."
755+ )
756+
757+ dtype = dtype or dtype_arg
758+
759+ if device is not None and device_arg is not None :
760+ raise ValueError (
761+ "You have passed `device` both as an argument and as a keyword argument. Please only pass one of the two."
762+ )
763+
764+ device = device or device_arg
681765
682766 # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
683767 def module_is_sequentially_offloaded (module ):
@@ -698,14 +782,14 @@ def module_is_offloaded(module):
698782 pipeline_is_sequentially_offloaded = any (
699783 module_is_sequentially_offloaded (module ) for _ , module in self .components .items ()
700784 )
701- if pipeline_is_sequentially_offloaded and torch_device and torch .device (torch_device ).type == "cuda" :
785+ if pipeline_is_sequentially_offloaded and device and torch .device (device ).type == "cuda" :
702786 raise ValueError (
703787 "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
704788 )
705789
706790 # Display a warning in this case (the operation succeeds but the benefits are lost)
707791 pipeline_is_offloaded = any (module_is_offloaded (module ) for _ , module in self .components .items ())
708- if pipeline_is_offloaded and torch_device and torch .device (torch_device ).type == "cuda" :
792+ if pipeline_is_offloaded and device and torch .device (device ).type == "cuda" :
709793 logger .warning (
710794 f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components { ', ' .join (self .components .keys ())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
711795 )
@@ -718,26 +802,26 @@ def module_is_offloaded(module):
718802 for module in modules :
719803 is_loaded_in_8bit = hasattr (module , "is_loaded_in_8bit" ) and module .is_loaded_in_8bit
720804
721- if is_loaded_in_8bit and torch_dtype is not None :
805+ if is_loaded_in_8bit and dtype is not None :
722806 logger .warning (
723807 f"The module '{ module .__class__ .__name__ } ' has been loaded in 8bit and conversion to { torch_dtype } is not yet supported. Module is still in 8bit precision."
724808 )
725809
726- if is_loaded_in_8bit and torch_device is not None :
810+ if is_loaded_in_8bit and device is not None :
727811 logger .warning (
728812 f"The module '{ module .__class__ .__name__ } ' has been loaded in 8bit and moving it to { torch_dtype } via `.to()` is not yet supported. Module is still on { module .device } ."
729813 )
730814 else :
731- module .to (torch_device , torch_dtype )
815+ module .to (device , dtype )
732816
733817 if (
734818 module .dtype == torch .float16
735- and str (torch_device ) in ["cpu" ]
819+ and str (device ) in ["cpu" ]
736820 and not silence_dtype_warnings
737821 and not is_offloaded
738822 ):
739823 logger .warning (
740- "Pipelines loaded with `torch_dtype =torch.float16` cannot run with `cpu` device. It"
824+ "Pipelines loaded with `dtype =torch.float16` cannot run with `cpu` device. It"
741825 " is not recommended to move them to `cpu` as running them will fail. Please make"
742826 " sure to use an accelerator to run the pipeline in inference, due to the lack of"
743827 " support for`float16` operations on this device in PyTorch. Please, remove the"
@@ -760,6 +844,21 @@ def device(self) -> torch.device:
760844
761845 return torch .device ("cpu" )
762846
847+ @property
848+ def dtype (self ) -> torch .dtype :
849+ r"""
850+ Returns:
851+ `torch.dtype`: The torch dtype on which the pipeline is located.
852+ """
853+ module_names , _ = self ._get_signature_keys (self )
854+ modules = [getattr (self , n , None ) for n in module_names ]
855+ modules = [m for m in modules if isinstance (m , torch .nn .Module )]
856+
857+ for module in modules :
858+ return module .dtype
859+
860+ return torch .float32
861+
763862 @classmethod
764863 def from_pretrained (cls , pretrained_model_name_or_path : Optional [Union [str , os .PathLike ]], ** kwargs ):
765864 r"""
@@ -1222,12 +1321,19 @@ def _execution_device(self):
12221321 return torch .device (module ._hf_hook .execution_device )
12231322 return self .device
12241323
1225- def enable_model_cpu_offload (self , gpu_id : int = 0 , device : Union [torch .device , str ] = "cuda" ):
1324+ def enable_model_cpu_offload (self , gpu_id : Optional [ int ] = None , device : Union [torch .device , str ] = "cuda" ):
12261325 r"""
12271326 Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
12281327 to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
12291328 method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
12301329 `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
1330+
1331+ Arguments:
1332+ gpu_id (`int`, *optional*):
1333+ The ID of the accelerator that shall be used in inference. If not specified, it will default to 0.
1334+ device (`torch.Device` or `str`, *optional*, defaults to "cuda"):
1335+ The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
1336+ default to "cuda".
12311337 """
12321338 if self .model_cpu_offload_seq is None :
12331339 raise ValueError (
@@ -1239,7 +1345,20 @@ def enable_model_cpu_offload(self, gpu_id: int = 0, device: Union[torch.device,
12391345 else :
12401346 raise ImportError ("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher." )
12411347
1242- device = torch .device (f"cuda:{ gpu_id } " )
1348+ torch_device = torch .device (device )
1349+ device_index = torch_device .index
1350+
1351+ if gpu_id is not None and device_index is not None :
1352+ raise ValueError (
1353+ f"You have passed both `gpu_id`={ gpu_id } and an index as part of the passed device `device`={ device } "
1354+ f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={ torch_device .type } "
1355+ )
1356+
1357+ # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or default to 0
1358+ self ._offload_gpu_id = gpu_id or torch_device .index or self ._offload_gpu_id or 0
1359+
1360+ device_type = torch_device .type
1361+ device = torch .device (f"{ device_type } :{ self ._offload_gpu_id } " )
12431362
12441363 if self .device .type != "cpu" :
12451364 self .to ("cpu" , silence_dtype_warnings = True )
@@ -1274,7 +1393,10 @@ def enable_model_cpu_offload(self, gpu_id: int = 0, device: Union[torch.device,
12741393
12751394 def maybe_free_model_hooks (self ):
12761395 r"""
1277- TODO: Better doc string
1396+ Function that offloads all components, removes all model hooks that were added when using
1397+ `enable_model_cpu_offload` and then applies them again. In case the model has not been offloaded this function
1398+ is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it
1399+ functions correctly when applying enable_model_cpu_offload.
12781400 """
12791401 if not hasattr (self , "_all_hooks" ) or len (self ._all_hooks ) == 0 :
12801402 # `enable_model_cpu_offload` has not be called, so silently do nothing
@@ -1288,21 +1410,40 @@ def maybe_free_model_hooks(self):
12881410 # make sure the model is in the same state as before calling it
12891411 self .enable_model_cpu_offload ()
12901412
1291- def enable_sequential_cpu_offload (self , gpu_id : int = 0 , device : Union [torch .device , str ] = "cuda" ):
1413+ def enable_sequential_cpu_offload (self , gpu_id : Optional [ int ] = None , device : Union [torch .device , str ] = "cuda" ):
12921414 r"""
12931415 Offloads all models to CPU using 🤗 Accelerate, significantly reducing memory usage. When called, the state
12941416 dicts of all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are saved to CPU
12951417 and then moved to `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward`
12961418 method called. Offloading happens on a submodule basis. Memory savings are higher than with
12971419 `enable_model_cpu_offload`, but performance is lower.
1420+
1421+ Arguments:
1422+ gpu_id (`int`, *optional*):
1423+ The ID of the accelerator that shall be used in inference. If not specified, it will default to 0.
1424+ device (`torch.Device` or `str`, *optional*, defaults to "cuda"):
1425+ The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
1426+ default to "cuda".
12981427 """
12991428 if is_accelerate_available () and is_accelerate_version (">=" , "0.14.0" ):
13001429 from accelerate import cpu_offload
13011430 else :
13021431 raise ImportError ("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher" )
13031432
1304- if device == "cuda" :
1305- device = torch .device (f"{ device } :{ gpu_id } " )
1433+ torch_device = torch .device (device )
1434+ device_index = torch_device .index
1435+
1436+ if gpu_id is not None and device_index is not None :
1437+ raise ValueError (
1438+ f"You have passed both `gpu_id`={ gpu_id } and an index as part of the passed device `device`={ device } "
1439+ f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={ torch_device .type } "
1440+ )
1441+
1442+ # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or default to 0
1443+ self ._offload_gpu_id = gpu_id or torch_device .index or self ._offload_gpu_id or 0
1444+
1445+ device_type = torch_device .type
1446+ device = torch .device (f"{ device_type } :{ self ._offload_gpu_id } " )
13061447
13071448 if self .device .type != "cpu" :
13081449 self .to ("cpu" , silence_dtype_warnings = True )
0 commit comments