Skip to content

Commit ccc8321

Browse files
ZhengKai91Kai zhenghlkyyiyixuxu
authored
Fix aclnnRepeatInterleaveIntWithDim error on NPU for get_1d_rotary_pos_embed (#10820)
* get_1d_rotary_pos_embed support npu * Update src/diffusers/models/embeddings.py --------- Co-authored-by: Kai zheng <[email protected]> Co-authored-by: hlky <[email protected]> Co-authored-by: YiYi Xu <[email protected]>
1 parent 5e48cd2 commit ccc8321

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

src/diffusers/models/embeddings.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,6 +1154,9 @@ def get_1d_rotary_pos_embed(
11541154
/ linear_factor
11551155
) # [D/2]
11561156
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
1157+
is_npu = freqs.device.type == "npu"
1158+
if is_npu:
1159+
freqs = freqs.float()
11571160
if use_real and repeat_interleave_real:
11581161
# flux, hunyuan-dit, cogvideox
11591162
freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]

0 commit comments

Comments
 (0)