@@ -71,14 +71,22 @@ def __call__(
 
         if rotary_emb is not None:
 
-            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
-                dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
-                x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
-                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
-                return x_out.type_as(hidden_states)
-
-            query = apply_rotary_emb(query, rotary_emb)
-            key = apply_rotary_emb(key, rotary_emb)
+            def apply_rotary_emb(
+                hidden_states: torch.Tensor,
+                freqs_cos: torch.Tensor,
+                freqs_sin: torch.Tensor,
+            ):
+                x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
+                x1, x2 = x[..., 0], x[..., 1]
+                cos = freqs_cos[..., 0::2]
+                sin = freqs_sin[..., 1::2]
+                out = torch.empty_like(hidden_states)
+                out[..., 0::2] = x1 * cos - x2 * sin
+                out[..., 1::2] = x1 * sin + x2 * cos
+                return out.type_as(hidden_states)
+
+            query = apply_rotary_emb(query, *rotary_emb)
+            key = apply_rotary_emb(key, *rotary_emb)
 
         # I2V task
         hidden_states_img = None
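The hunk above replaces the complex-valued rotation with an equivalent purely real one: for each adjacent pair `(x1, x2)`, the complex product `(x1 + i*x2) * (cos t + i*sin t)` expands to `(x1*cos t - x2*sin t, x1*sin t + x2*cos t)`, which is exactly what the new even/odd writes compute. Because the cos/sin tables are repeat-interleaved to the full head dimension, `freqs_cos[..., 0::2]` and `freqs_sin[..., 1::2]` each recover one value per pair. A minimal standalone sketch checking that the two paths agree (all shapes and names here are illustrative, not taken from the PR):

```python
import torch

batch, heads, seq, dim = 2, 4, 8, 16

x = torch.randn(batch, heads, seq, dim)
# One rotation angle per (position, pair), in float64 as in the old path.
theta = torch.randn(seq, dim // 2, dtype=torch.float64)

# Old formulation: rotate adjacent pairs via complex multiplication.
freqs_complex = torch.polar(torch.ones_like(theta), theta)  # e^{i*theta}
x_complex = torch.view_as_complex(x.double().unflatten(3, (-1, 2)))
out_old = torch.view_as_real(x_complex * freqs_complex).flatten(3, 4).type_as(x)

# New formulation: the same rotation with real cos/sin tables that were
# repeat-interleaved to the full head dimension.
freqs_cos = theta.cos().repeat_interleave(2, dim=-1)
freqs_sin = theta.sin().repeat_interleave(2, dim=-1)

x1, x2 = x[..., 0::2], x[..., 1::2]
cos, sin = freqs_cos[..., 0::2], freqs_sin[..., 1::2]
out_new = torch.empty_like(x)
out_new[..., 0::2] = x1 * cos - x2 * sin
out_new[..., 1::2] = x1 * sin + x2 * cos

torch.testing.assert_close(out_old, out_new)
```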
@@ -179,7 +187,11 @@ def forward(
 
 class WanRotaryPosEmbed(nn.Module):
     def __init__(
-        self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
+        self,
+        attention_head_dim: int,
+        patch_size: Tuple[int, int, int],
+        max_seq_len: int,
+        theta: float = 10000.0,
     ):
         super().__init__()
 
@@ -189,36 +201,52 @@ def __init__(
 
         h_dim = w_dim = 2 * (attention_head_dim // 6)
         t_dim = attention_head_dim - h_dim - w_dim
-
-        freqs = []
         freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+        freqs_cos = []
+        freqs_sin = []
+
         for dim in [t_dim, h_dim, w_dim]:
-            freq = get_1d_rotary_pos_embed(
-                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=freqs_dtype
+            freq_cos, freq_sin = get_1d_rotary_pos_embed(
+                dim,
+                max_seq_len,
+                theta,
+                use_real=True,
+                repeat_interleave_real=True,
+                freqs_dtype=freqs_dtype,
             )
-            freqs.append(freq)
-        self.freqs = torch.cat(freqs, dim=1)
+            freqs_cos.append(freq_cos)
+            freqs_sin.append(freq_sin)
+
+        self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+        self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
         p_t, p_h, p_w = self.patch_size
         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
 
-        freqs = self.freqs.to(hidden_states.device)
-        freqs = freqs.split_with_sizes(
-            [
-                self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
-                self.attention_head_dim // 6,
-                self.attention_head_dim // 6,
-            ],
-            dim=1,
-        )
+        split_sizes = [
+            self.attention_head_dim - 2 * (self.attention_head_dim // 3),
+            self.attention_head_dim // 3,
+            self.attention_head_dim // 3,
+        ]
+
+        freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+        freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+        freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+        freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
 
-        freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
-        freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
-        freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
-        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
-        return freqs
+        return freqs_cos, freqs_sin
 
 
 class WanTransformerBlock(nn.Module):
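Since the cos/sin tables are now registered as non-persistent buffers, they follow the module on `.to(device)` without being written to the state dict. A minimal standalone sketch of the shape bookkeeping in the new `forward` (head dim and grid sizes below are illustrative, not taken from the PR; the sin table is assembled identically):

```python
import torch

attention_head_dim = 128
h_dim = w_dim = 2 * (attention_head_dim // 6)  # 42
t_dim = attention_head_dim - h_dim - w_dim     # 44
max_seq_len = 32

# Stand-ins for the per-axis cos tables built in __init__,
# each of shape (max_seq_len, axis_dim).
cos_t = torch.randn(max_seq_len, t_dim)
cos_h = torch.randn(max_seq_len, h_dim)
cos_w = torch.randn(max_seq_len, w_dim)

# Hypothetical post-patchify grid: frames x height x width.
ppf, pph, ppw = 4, 6, 8

# Broadcast each axis table over the other two axes, then concatenate
# along the head dimension and flatten the grid into a token axis.
f = cos_t[:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
h = cos_h[:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
w = cos_w[:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

freqs_cos = torch.cat([f, h, w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
print(freqs_cos.shape)  # torch.Size([1, 1, 192, 128])
```

Each token's head dimension is thus partitioned into one temporal and two spatial slices, so the single broadcast multiply in `apply_rotary_emb` rotates all three axes at once.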