Skip to content

Commit 3ceaa28

Browse files
pcuenca and anton-l authored
Do not use torch.long in mps (huggingface#1488)
* Do not use torch.long in mps. Addresses huggingface#1056.
* Use torch.int instead of float.
* Propagate changes.
* Do not silently change float -> int.
* Propagate changes.
* Apply suggestions from code review

Co-authored-by: Anton Lozhkov <[email protected]>
Co-authored-by: Anton Lozhkov <[email protected]>
1 parent a816a87 commit 3ceaa28

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

src/diffusers/models/unet_2d_condition.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,14 @@ def forward(
299299
timesteps = timestep
300300
if not torch.is_tensor(timesteps):
301301
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
302-
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
303-
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
302+
# This would be a good case for the `match` statement (Python 3.10+)
303+
is_mps = sample.device.type == "mps"
304+
if torch.is_floating_point(timesteps):
305+
dtype = torch.float32 if is_mps else torch.float64
306+
else:
307+
dtype = torch.int32 if is_mps else torch.int64
308+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
309+
elif len(timesteps.shape) == 0:
304310
timesteps = timesteps[None].to(sample.device)
305311

306312
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -377,8 +377,14 @@ def forward(
377377
timesteps = timestep
378378
if not torch.is_tensor(timesteps):
379379
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
380-
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
381-
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
380+
# This would be a good case for the `match` statement (Python 3.10+)
381+
is_mps = sample.device.type == "mps"
382+
if torch.is_floating_point(timesteps):
383+
dtype = torch.float32 if is_mps else torch.float64
384+
else:
385+
dtype = torch.int32 if is_mps else torch.int64
386+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
387+
elif len(timesteps.shape) == 0:
382388
timesteps = timesteps[None].to(sample.device)
383389

384390
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML

0 commit comments

Comments (0)