Skip to content

Commit 30a512e

Browse files
patrickvonplaten, yiyixuxu, pcuenca
authored
[Core] Improve .to(...) method, fix offloads multi-gpu, add docstring, add dtype (huggingface#5132)
* fix cpu offload * fix * fix * Update src/diffusers/pipelines/pipeline_utils.py * make style * Apply suggestions from code review Co-authored-by: YiYi Xu <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]> * fix more * fix more --------- Co-authored-by: YiYi Xu <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]>
1 parent 92f15f5 commit 30a512e

File tree

3 files changed

+249
-21
lines changed

3 files changed

+249
-21
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 162 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -670,14 +670,98 @@ def is_saveable_module(name, value):
670670
create_pr=create_pr,
671671
)
672672

673-
def to(
674-
self,
675-
torch_device: Optional[Union[str, torch.device]] = None,
676-
torch_dtype: Optional[torch.dtype] = None,
677-
silence_dtype_warnings: bool = False,
678-
):
679-
if torch_device is None and torch_dtype is None:
680-
return self
673+
def to(self, *args, **kwargs):
674+
r"""
675+
Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
676+
arguments of `self.to(*args, **kwargs)`.
677+
678+
<Tip>
679+
680+
If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise,
681+
the returned pipeline is a copy of self with the desired torch.dtype and torch.device.
682+
683+
</Tip>
684+
685+
686+
Here are the ways to call `to`:
687+
688+
- `to(dtype, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified
689+
[`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
690+
- `to(device, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified
691+
[`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)
692+
- `to(device=None, dtype=None, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the
693+
specified [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) and
694+
[`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
695+
696+
Arguments:
697+
dtype (`torch.dtype`, *optional*):
698+
Returns a pipeline with the specified
699+
[`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
700+
device (`torch.device`, *optional*):
701+
Returns a pipeline with the specified
702+
[`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)
703+
silence_dtype_warnings (`bool`, *optional*, defaults to `False`):
704+
Whether to omit warnings if the target `dtype` is not compatible with the target `device`.
705+
706+
Returns:
707+
[`DiffusionPipeline`]: The pipeline converted to the specified `device` and/or `dtype`.
708+
"""
709+
710+
torch_dtype = kwargs.pop("torch_dtype", None)
711+
if torch_dtype is not None:
712+
deprecate("torch_dtype", "0.25.0", "")
713+
torch_device = kwargs.pop("torch_device", None)
714+
if torch_device is not None:
715+
deprecate("torch_device", "0.25.0", "")
716+
717+
dtype_kwarg = kwargs.pop("dtype", None)
718+
device_kwarg = kwargs.pop("device", None)
719+
silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False)
720+
721+
if torch_dtype is not None and dtype_kwarg is not None:
722+
raise ValueError(
723+
"You have passed both `torch_dtype` and `dtype` as a keyword argument. Please make sure to only pass `dtype`."
724+
)
725+
726+
dtype = torch_dtype or dtype_kwarg
727+
728+
if torch_device is not None and device_kwarg is not None:
729+
raise ValueError(
730+
"You have passed both `torch_device` and `device` as a keyword argument. Please make sure to only pass `device`."
731+
)
732+
733+
device = torch_device or device_kwarg
734+
735+
dtype_arg = None
736+
device_arg = None
737+
if len(args) == 1:
738+
if isinstance(args[0], torch.dtype):
739+
dtype_arg = args[0]
740+
else:
741+
device_arg = torch.device(args[0]) if args[0] is not None else None
742+
elif len(args) == 2:
743+
if isinstance(args[0], torch.dtype):
744+
raise ValueError(
745+
"When passing two arguments, make sure the first corresponds to `device` and the second to `dtype`."
746+
)
747+
device_arg = torch.device(args[0]) if args[0] is not None else None
748+
dtype_arg = args[1]
749+
elif len(args) > 2:
750+
raise ValueError("Please make sure to pass at most two arguments (`device` and `dtype`) `.to(...)`")
751+
752+
if dtype is not None and dtype_arg is not None:
753+
raise ValueError(
754+
"You have passed `dtype` both as an argument and as a keyword argument. Please only pass one of the two."
755+
)
756+
757+
dtype = dtype or dtype_arg
758+
759+
if device is not None and device_arg is not None:
760+
raise ValueError(
761+
"You have passed `device` both as an argument and as a keyword argument. Please only pass one of the two."
762+
)
763+
764+
device = device or device_arg
681765

682766
# throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
683767
def module_is_sequentially_offloaded(module):
@@ -698,14 +782,14 @@ def module_is_offloaded(module):
698782
pipeline_is_sequentially_offloaded = any(
699783
module_is_sequentially_offloaded(module) for _, module in self.components.items()
700784
)
701-
if pipeline_is_sequentially_offloaded and torch_device and torch.device(torch_device).type == "cuda":
785+
if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":
702786
raise ValueError(
703787
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
704788
)
705789

706790
# Display a warning in this case (the operation succeeds but the benefits are lost)
707791
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
708-
if pipeline_is_offloaded and torch_device and torch.device(torch_device).type == "cuda":
792+
if pipeline_is_offloaded and device and torch.device(device).type == "cuda":
709793
logger.warning(
710794
f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
711795
)
@@ -718,26 +802,26 @@ def module_is_offloaded(module):
718802
for module in modules:
719803
is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit
720804

721-
if is_loaded_in_8bit and torch_dtype is not None:
805+
if is_loaded_in_8bit and dtype is not None:
722806
logger.warning(
723807
f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {torch_dtype} is not yet supported. Module is still in 8bit precision."
724808
)
725809

726-
if is_loaded_in_8bit and torch_device is not None:
810+
if is_loaded_in_8bit and device is not None:
727811
logger.warning(
728812
f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}."
729813
)
730814
else:
731-
module.to(torch_device, torch_dtype)
815+
module.to(device, dtype)
732816

733817
if (
734818
module.dtype == torch.float16
735-
and str(torch_device) in ["cpu"]
819+
and str(device) in ["cpu"]
736820
and not silence_dtype_warnings
737821
and not is_offloaded
738822
):
739823
logger.warning(
740-
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
824+
"Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It"
741825
" is not recommended to move them to `cpu` as running them will fail. Please make"
742826
" sure to use an accelerator to run the pipeline in inference, due to the lack of"
743827
" support for`float16` operations on this device in PyTorch. Please, remove the"
@@ -760,6 +844,21 @@ def device(self) -> torch.device:
760844

761845
return torch.device("cpu")
762846

847+
@property
848+
def dtype(self) -> torch.dtype:
849+
r"""
850+
Returns:
851+
`torch.dtype`: The torch dtype of the pipeline's first `torch.nn.Module` component, or `torch.float32` if the pipeline has none.
852+
"""
853+
module_names, _ = self._get_signature_keys(self)
854+
modules = [getattr(self, n, None) for n in module_names]
855+
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
856+
857+
for module in modules:
858+
return module.dtype
859+
860+
return torch.float32
861+
763862
@classmethod
764863
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
765864
r"""
@@ -1222,12 +1321,19 @@ def _execution_device(self):
12221321
return torch.device(module._hf_hook.execution_device)
12231322
return self.device
12241323

1225-
def enable_model_cpu_offload(self, gpu_id: int = 0, device: Union[torch.device, str] = "cuda"):
1324+
def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
12261325
r"""
12271326
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
12281327
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
12291328
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
12301329
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
1330+
1331+
Arguments:
1332+
gpu_id (`int`, *optional*):
1333+
The ID of the accelerator that shall be used in inference. If not specified, it falls back to the index given in `device`, then to the previously set ID, and finally to 0.
1334+
device (`torch.device` or `str`, *optional*, defaults to "cuda"):
1335+
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
1336+
default to "cuda".
12311337
"""
12321338
if self.model_cpu_offload_seq is None:
12331339
raise ValueError(
@@ -1239,7 +1345,20 @@ def enable_model_cpu_offload(self, gpu_id: int = 0, device: Union[torch.device,
12391345
else:
12401346
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
12411347

1242-
device = torch.device(f"cuda:{gpu_id}")
1348+
torch_device = torch.device(device)
1349+
device_index = torch_device.index
1350+
1351+
if gpu_id is not None and device_index is not None:
1352+
raise ValueError(
1353+
f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}"
1354+
f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}"
1355+
)
1356+
1357+
# _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or default to 0
1358+
self._offload_gpu_id = gpu_id or torch_device.index or self._offload_gpu_id or 0
1359+
1360+
device_type = torch_device.type
1361+
device = torch.device(f"{device_type}:{self._offload_gpu_id}")
12431362

12441363
if self.device.type != "cpu":
12451364
self.to("cpu", silence_dtype_warnings=True)
@@ -1274,7 +1393,10 @@ def enable_model_cpu_offload(self, gpu_id: int = 0, device: Union[torch.device,
12741393

12751394
def maybe_free_model_hooks(self):
12761395
r"""
1277-
TODO: Better doc string
1396+
Function that offloads all components, removes all model hooks that were added when using
1397+
`enable_model_cpu_offload` and then applies them again. In case the model has not been offloaded this function
1398+
is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it
1399+
functions correctly when applying enable_model_cpu_offload.
12781400
"""
12791401
if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
12801402
# `enable_model_cpu_offload` has not been called, so silently do nothing
@@ -1288,21 +1410,40 @@ def maybe_free_model_hooks(self):
12881410
# make sure the model is in the same state as before calling it
12891411
self.enable_model_cpu_offload()
12901412

1291-
def enable_sequential_cpu_offload(self, gpu_id: int = 0, device: Union[torch.device, str] = "cuda"):
1413+
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
12921414
r"""
12931415
Offloads all models to CPU using 🤗 Accelerate, significantly reducing memory usage. When called, the state
12941416
dicts of all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are saved to CPU
12951417
and then moved to `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward`
12961418
method called. Offloading happens on a submodule basis. Memory savings are higher than with
12971419
`enable_model_cpu_offload`, but performance is lower.
1420+
1421+
Arguments:
1422+
gpu_id (`int`, *optional*):
1423+
The ID of the accelerator that shall be used in inference. If not specified, it falls back to the index given in `device`, then to the previously set ID, and finally to 0.
1424+
device (`torch.device` or `str`, *optional*, defaults to "cuda"):
1425+
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
1426+
default to "cuda".
12981427
"""
12991428
if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
13001429
from accelerate import cpu_offload
13011430
else:
13021431
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
13031432

1304-
if device == "cuda":
1305-
device = torch.device(f"{device}:{gpu_id}")
1433+
torch_device = torch.device(device)
1434+
device_index = torch_device.index
1435+
1436+
if gpu_id is not None and device_index is not None:
1437+
raise ValueError(
1438+
f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}"
1439+
f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}"
1440+
)
1441+
1442+
# _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or default to 0
1443+
self._offload_gpu_id = gpu_id or torch_device.index or self._offload_gpu_id or 0
1444+
1445+
device_type = torch_device.type
1446+
device = torch.device(f"{device_type}:{self._offload_gpu_id}")
13061447

13071448
if self.device.type != "cpu":
13081449
self.to("cpu", silence_dtype_warnings=True)

tests/pipelines/test_pipelines.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1475,6 +1475,89 @@ def get_all_filenames(directory):
14751475
assert len(variant_model_files) == 0
14761476
assert len(all_model_files) > 0
14771477

1478+
def test_pipe_to(self):
1479+
unet = self.dummy_cond_unet()
1480+
scheduler = PNDMScheduler(skip_prk_steps=True)
1481+
vae = self.dummy_vae
1482+
bert = self.dummy_text_encoder
1483+
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
1484+
1485+
sd = StableDiffusionPipeline(
1486+
unet=unet,
1487+
scheduler=scheduler,
1488+
vae=vae,
1489+
text_encoder=bert,
1490+
tokenizer=tokenizer,
1491+
safety_checker=None,
1492+
feature_extractor=self.dummy_extractor,
1493+
)
1494+
1495+
device_type = torch.device(torch_device).type
1496+
1497+
sd1 = sd.to(device_type)
1498+
sd2 = sd.to(torch.device(device_type))
1499+
sd3 = sd.to(device_type, torch.float32)
1500+
sd4 = sd.to(device=device_type)
1501+
sd5 = sd.to(torch_device=device_type)
1502+
sd6 = sd.to(device_type, dtype=torch.float32)
1503+
sd7 = sd.to(device_type, torch_dtype=torch.float32)
1504+
1505+
assert sd1.device.type == device_type
1506+
assert sd2.device.type == device_type
1507+
assert sd3.device.type == device_type
1508+
assert sd4.device.type == device_type
1509+
assert sd5.device.type == device_type
1510+
assert sd6.device.type == device_type
1511+
assert sd7.device.type == device_type
1512+
1513+
sd1 = sd.to(torch.float16)
1514+
sd2 = sd.to(None, torch.float16)
1515+
sd3 = sd.to(dtype=torch.float16)
1516+
sd4 = sd.to(torch_dtype=torch.float16)
1517+
sd5 = sd.to(None, dtype=torch.float16)
1518+
sd6 = sd.to(None, torch_dtype=torch.float16)
1519+
1520+
assert sd1.dtype == torch.float16
1521+
assert sd2.dtype == torch.float16
1522+
assert sd3.dtype == torch.float16
1523+
assert sd4.dtype == torch.float16
1524+
assert sd5.dtype == torch.float16
1525+
assert sd6.dtype == torch.float16
1526+
1527+
sd1 = sd.to(device=device_type, dtype=torch.float16)
1528+
sd2 = sd.to(torch_device=device_type, torch_dtype=torch.float16)
1529+
sd3 = sd.to(device_type, torch.float16)
1530+
1531+
assert sd1.dtype == torch.float16
1532+
assert sd2.dtype == torch.float16
1533+
assert sd3.dtype == torch.float16
1534+
1535+
assert sd1.device.type == device_type
1536+
assert sd2.device.type == device_type
1537+
assert sd3.device.type == device_type
1538+
1539+
def test_pipe_same_device_id_offload(self):
1540+
unet = self.dummy_cond_unet()
1541+
scheduler = PNDMScheduler(skip_prk_steps=True)
1542+
vae = self.dummy_vae
1543+
bert = self.dummy_text_encoder
1544+
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
1545+
1546+
sd = StableDiffusionPipeline(
1547+
unet=unet,
1548+
scheduler=scheduler,
1549+
vae=vae,
1550+
text_encoder=bert,
1551+
tokenizer=tokenizer,
1552+
safety_checker=None,
1553+
feature_extractor=self.dummy_extractor,
1554+
)
1555+
1556+
sd.enable_model_cpu_offload(gpu_id=5)
1557+
assert sd._offload_gpu_id == 5
1558+
sd.maybe_free_model_hooks()
1559+
assert sd._offload_gpu_id == 5
1560+
14781561

14791562
@slow
14801563
@require_torch_gpu

tests/pipelines/text_to_video/test_video_to_video.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,10 @@ def test_save_load_optional_components(self):
165165
def test_dict_tuple_outputs_equivalent(self):
166166
super().test_dict_tuple_outputs_equivalent()
167167

168+
@is_flaky()
169+
def test_save_load_local(self):
170+
super().test_save_load_local()
171+
168172
@unittest.skipIf(
169173
torch_device != "cuda" or not is_xformers_available(),
170174
reason="XFormers attention is only available with CUDA and `xformers` installed",

0 commit comments

Comments
 (0)