
Commit 62e847d

mjkvaak-amd, a-r-r-o-w, and github-actions[bot] authored
Use real-valued instead of complex tensors in Wan2.1 RoPE (huggingface#11649)
* use real instead of complex tensors in Wan2.1 RoPE
* remove the redundant type conversion
* unpack rotary_emb
* register rotary embedding frequencies as non-persistent buffers
* Apply style fixes

Co-authored-by: Aryan <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 4704586 commit 62e847d
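
The change replaces the complex-tensor RoPE path (torch.view_as_complex / torch.view_as_real) with an equivalent real-valued formulation on interleaved even/odd channels: for each pair (x1, x2) and angle theta, the rotated pair is (x1*cos(theta) - x2*sin(theta), x1*sin(theta) + x2*cos(theta)), which is exactly the complex product (x1 + i*x2) * e^(i*theta). Below is a minimal standalone sketch of that equivalence; the shapes, random angles, and variable names are illustrative assumptions, not code from the repository.

```python
import torch

# Toy sizes only: [batch, heads, seq, head_dim]; nothing Wan2.1-specific.
B, H, S, D = 2, 4, 8, 16
x = torch.randn(B, H, S, D, dtype=torch.float64)
angles = torch.randn(S, D // 2, dtype=torch.float64)  # one angle per (position, channel pair)

# Old-style path: pack (even, odd) channel pairs into complex numbers, multiply by e^(i*angle).
freqs_complex = torch.polar(torch.ones_like(angles), angles)
ref = torch.view_as_real(torch.view_as_complex(x.unflatten(-1, (-1, 2))) * freqs_complex).flatten(-2)

# New-style path: repeat-interleaved cos/sin tables and purely real arithmetic.
freqs_cos = angles.cos().repeat_interleave(2, dim=-1)  # [S, D] laid out as [c0, c0, c1, c1, ...]
freqs_sin = angles.sin().repeat_interleave(2, dim=-1)  # [S, D] laid out as [s0, s0, s1, s1, ...]
x1, x2 = x[..., 0::2], x[..., 1::2]
cos, sin = freqs_cos[..., 0::2], freqs_sin[..., 1::2]
out = torch.empty_like(x)
out[..., 0::2] = x1 * cos - x2 * sin
out[..., 1::2] = x1 * sin + x2 * cos

torch.testing.assert_close(out, ref)  # both formulations rotate the channel pairs identically
```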

File tree: 1 file changed, +57 −29 lines changed


src/diffusers/models/transformers/transformer_wan.py

Lines changed: 57 additions & 29 deletions
@@ -71,14 +71,22 @@ def __call__(
 
         if rotary_emb is not None:
 
-            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
-                dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
-                x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
-                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
-                return x_out.type_as(hidden_states)
-
-            query = apply_rotary_emb(query, rotary_emb)
-            key = apply_rotary_emb(key, rotary_emb)
+            def apply_rotary_emb(
+                hidden_states: torch.Tensor,
+                freqs_cos: torch.Tensor,
+                freqs_sin: torch.Tensor,
+            ):
+                x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
+                x1, x2 = x[..., 0], x[..., 1]
+                cos = freqs_cos[..., 0::2]
+                sin = freqs_sin[..., 1::2]
+                out = torch.empty_like(hidden_states)
+                out[..., 0::2] = x1 * cos - x2 * sin
+                out[..., 1::2] = x1 * sin + x2 * cos
+                return out.type_as(hidden_states)
+
+            query = apply_rotary_emb(query, *rotary_emb)
+            key = apply_rotary_emb(key, *rotary_emb)
 
         # I2V task
         hidden_states_img = None
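
For context on the `*rotary_emb` unpacking above: the rotary embedding module now returns a `(freqs_cos, freqs_sin)` pair that broadcasts over batch and heads, instead of a single complex tensor. A rough, self-contained toy call with placeholder shapes (not an actual Wan2.1 configuration) could look like this:

```python
import torch

def apply_rotary_emb(hidden_states, freqs_cos, freqs_sin):
    # Same arithmetic as the helper added in the hunk above.
    x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
    x1, x2 = x[..., 0], x[..., 1]
    cos, sin = freqs_cos[..., 0::2], freqs_sin[..., 1::2]
    out = torch.empty_like(hidden_states)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out.type_as(hidden_states)

# Placeholder shapes: [batch, heads, seq, head_dim] for q/k, [1, 1, seq, head_dim] for the tables.
query = torch.randn(1, 12, 6, 64)
key = torch.randn(1, 12, 6, 64)
rotary_emb = (torch.randn(1, 1, 6, 64), torch.randn(1, 1, 6, 64))  # stand-in (cos, sin) pair

# The processor unpacks the tuple rather than passing one complex tensor.
query = apply_rotary_emb(query, *rotary_emb)
key = apply_rotary_emb(key, *rotary_emb)
```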
@@ -179,7 +187,11 @@ def forward(
 
 class WanRotaryPosEmbed(nn.Module):
     def __init__(
-        self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
+        self,
+        attention_head_dim: int,
+        patch_size: Tuple[int, int, int],
+        max_seq_len: int,
+        theta: float = 10000.0,
     ):
         super().__init__()
 
@@ -189,36 +201,52 @@ def __init__(
 
         h_dim = w_dim = 2 * (attention_head_dim // 6)
         t_dim = attention_head_dim - h_dim - w_dim
-
-        freqs = []
         freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+        freqs_cos = []
+        freqs_sin = []
+
         for dim in [t_dim, h_dim, w_dim]:
-            freq = get_1d_rotary_pos_embed(
-                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=freqs_dtype
+            freq_cos, freq_sin = get_1d_rotary_pos_embed(
+                dim,
+                max_seq_len,
+                theta,
+                use_real=True,
+                repeat_interleave_real=True,
+                freqs_dtype=freqs_dtype,
             )
-            freqs.append(freq)
-        self.freqs = torch.cat(freqs, dim=1)
+            freqs_cos.append(freq_cos)
+            freqs_sin.append(freq_sin)
+
+        self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+        self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
         p_t, p_h, p_w = self.patch_size
         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
 
-        freqs = self.freqs.to(hidden_states.device)
-        freqs = freqs.split_with_sizes(
-            [
-                self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
-                self.attention_head_dim // 6,
-                self.attention_head_dim // 6,
-            ],
-            dim=1,
-        )
+        split_sizes = [
+            self.attention_head_dim - 2 * (self.attention_head_dim // 3),
+            self.attention_head_dim // 3,
+            self.attention_head_dim // 3,
+        ]
+
+        freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+        freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+        freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+        freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
 
-        freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
-        freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
-        freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
-        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
-        return freqs
+        return freqs_cos, freqs_sin
 
 
 class WanTransformerBlock(nn.Module):
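
One side note on the `register_buffer(..., persistent=False)` calls introduced in `__init__`: non-persistent buffers are excluded from `state_dict()` (so the precomputed tables are never serialized and existing checkpoints stay compatible), while still following the module through `.to(...)` device and dtype moves. A tiny, hypothetical module (`TinyRoPE` is not part of diffusers) illustrating that behavior:

```python
import torch
from torch import nn

class TinyRoPE(nn.Module):
    def __init__(self, max_seq_len: int = 16, dim: int = 8, theta: float = 10000.0):
        super().__init__()
        # Precompute interleaved cos/sin tables, mirroring the pattern in the diff.
        positions = torch.arange(max_seq_len, dtype=torch.float32)
        inv_freq = 1.0 / theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
        angles = torch.outer(positions, inv_freq)
        self.register_buffer("freqs_cos", angles.cos().repeat_interleave(2, dim=-1), persistent=False)
        self.register_buffer("freqs_sin", angles.sin().repeat_interleave(2, dim=-1), persistent=False)

rope = TinyRoPE()
print(list(rope.state_dict().keys()))  # [] -> non-persistent buffers are not written to checkpoints
rope = rope.to(torch.float16)          # ...yet they still track the module's dtype/device
print(rope.freqs_cos.dtype)            # torch.float16
```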
