Skip to content

Commit e86a280

Browse files
authored
Remove warning about half precision on MPS (huggingface#1163)
Remove warning about half precision on MPS.
1 parent b4a1ed8 commit e86a280

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/diffusers/pipeline_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -209,13 +209,13 @@ def to(self, torch_device: Optional[Union[str, torch.device]] = None):
209209
for name in module_names.keys():
210210
module = getattr(self, name)
211211
if isinstance(module, torch.nn.Module):
212-
if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]:
212+
if module.dtype == torch.float16 and str(torch_device) in ["cpu"]:
213213
logger.warning(
214-
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` or `mps` device. It"
215-
" is not recommended to move them to `cpu` or `mps` as running them will fail. Please make"
216-
" sure to use a `cuda` device to run the pipeline in inference. due to the lack of support for"
217-
" `float16` operations on those devices in PyTorch. Please remove the"
218-
" `torch_dtype=torch.float16` argument, or use a `cuda` device to run inference."
214+
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
215+
" is not recommended to move them to `cpu` as running them will fail. Please make"
216+
" sure to use an accelerator to run the pipeline in inference, due to the lack of"
217+
" support for`float16` operations on this device in PyTorch. Please, remove the"
218+
" `torch_dtype=torch.float16` argument, or use another device for inference."
219219
)
220220
module.to(torch_device)
221221
return self

0 commit comments

Comments (0)