Skip to content

Commit 56c0037

Browse files
authored
Use expand instead of ones to broadcast tensor (huggingface#373)
Use `expand` instead of ones to broadcast tensor. As suggested by @bes-dev. According the documentation this shouldn't take any memory - it just plays with the strides.
1 parent 7a1229f commit 56c0037

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/diffusers/models/unet_2d_condition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def forward(
152152
timesteps = timesteps[None].to(sample.device)
153153

154154
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
155-
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
155+
timesteps = timesteps.expand(sample.shape[0])
156156

157157
t_emb = self.time_proj(timesteps)
158158
emb = self.time_embedding(t_emb)

0 commit comments

Comments
 (0)