Commit c990396

Merge pull request #1 from jsmidt/jsmidt-patch-1

Update transformer_flux.py. Change float64 to float32

2 parents cee7c1b + ed0c49b commit c990396

File tree

1 file changed: +1 −1 lines changed


src/diffusers/models/transformers/transformer_flux.py (1 addition, 1 deletion)

@@ -38,7 +38,7 @@
 def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
     assert dim % 2 == 0, "The dimension must be even."

-    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+    scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
     omega = 1.0 / (theta**scale)

     batch_size, seq_length = pos.shape
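The changed line computes the per-dimension frequency scale for rotary position embeddings (RoPE). A minimal pure-Python sketch of that arithmetic, with a hypothetical helper name `rope_frequencies` (the real code builds torch tensors, and after this commit does so in float32; plain Python floats below are used only to illustrate the formula):

```python
# Sketch of the RoPE frequency computation touched by this diff.
# Hypothetical standalone helper -- the actual diffusers code operates on
# torch tensors with dtype=torch.float32 after this change.
def rope_frequencies(dim: int, theta: float) -> list[float]:
    assert dim % 2 == 0, "The dimension must be even."
    # Mirrors: scale = torch.arange(0, dim, 2) / dim
    scale = [i / dim for i in range(0, dim, 2)]
    # Mirrors: omega = 1.0 / (theta**scale)
    return [1.0 / (theta ** s) for s in scale]

freqs = rope_frequencies(8, 10000.0)
# freqs[0] is 1.0; later entries decay geometrically toward 1/theta
```

Each pair of embedding dimensions gets one frequency, so a `dim` of 8 yields 4 frequencies spanning 1.0 down to roughly `theta ** -0.75`.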

0 commit comments
