@@ -243,7 +243,9 @@ def __init__(
         self.norm_msa_a = AdaLayerNormZero(dim, device=device, dtype=dtype)
         self.norm_mlp_a = AdaLayerNormZero(dim, device=device, dtype=dtype)
         self.ff_a = nn.Sequential(
-            nn.Linear(dim, dim * 4), nn.GELU(approximate="tanh"), nn.Linear(dim * 4, dim, device=device, dtype=dtype)
+            nn.Linear(dim, dim * 4, device=device, dtype=dtype),
+            nn.GELU(approximate="tanh"),
+            nn.Linear(dim * 4, dim, device=device, dtype=dtype)
         )
         # Text
         self.norm_msa_b = AdaLayerNormZero(dim, device=device, dtype=dtype)
@@ -313,10 +315,10 @@ def __init__(
         self.norm = AdaLayerNormZero(dim, device=device, dtype=dtype)
         self.attn = FluxSingleAttention(dim, num_heads, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
         self.mlp = nn.Sequential(
-            nn.Linear(dim, dim * 4),
+            nn.Linear(dim, dim * 4, device=device, dtype=dtype),
             nn.GELU(approximate="tanh"),
         )
-        self.proj_out = nn.Linear(dim * 5, dim)
+        self.proj_out = nn.Linear(dim * 5, dim, device=device, dtype=dtype)

     def forward(self, x, t_emb, rope_emb, image_emb=None):
         h, gate = self.norm(x, emb=t_emb)
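
For context, a minimal sketch (not part of this commit) of the failure mode the change addresses: an nn.Linear constructed without explicit device/dtype falls back to default float32 CPU parameters, so a block built with, say, dtype=torch.bfloat16 ends up with mixed-precision weights. The dim value and dtype below are illustrative only.

import torch
import torch.nn as nn

dim, dtype = 64, torch.bfloat16

# Pre-fix pattern: only the second Linear receives dtype.
mixed = nn.Sequential(
    nn.Linear(dim, dim * 4),               # defaults to float32
    nn.GELU(approximate="tanh"),
    nn.Linear(dim * 4, dim, dtype=dtype),  # bfloat16
)
print({p.dtype for p in mixed.parameters()})  # {torch.float32, torch.bfloat16}

# Post-fix pattern: every Linear receives dtype, so parameters are uniform.
fixed = nn.Sequential(
    nn.Linear(dim, dim * 4, dtype=dtype),
    nn.GELU(approximate="tanh"),
    nn.Linear(dim * 4, dim, dtype=dtype),
)
print({p.dtype for p in fixed.parameters()})  # {torch.bfloat16}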