@@ -138,14 +138,14 @@ def __init__(self, dim, num_attention_heads, attention_head_dim):
138 138         self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
139 139         self.ff = AuraFlowFeedForward(dim, dim * 4)
140 140
141     -    def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=9999):
    141 +    def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor):
142 142         residual = hidden_states
143 143
144 144         # Norm + Projection.
145 145         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
146 146
147 147         # Attention.
148     -        attn_output = self.attn(hidden_states=norm_hidden_states, i=i)
    148 +        attn_output = self.attn(hidden_states=norm_hidden_states)
149 149
150 150         # Process attention outputs for the `hidden_states`.
151 151         hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
@@ -201,7 +201,7 @@ def __init__(self, dim, num_attention_heads, attention_head_dim):
201 201         self.ff_context = AuraFlowFeedForward(dim, dim * 4)
202 202
203 203     def forward(
204     -        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=0
    204 +        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
205 205     ):
206 206         residual = hidden_states
207 207         residual_context = encoder_hidden_states
@@ -214,7 +214,7 @@ def forward(
214 214
215 215         # Attention.
216 216         attn_output, context_attn_output = self.attn(
217     -            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, i=i
    217 +            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
218 218         )
219 219
220 220         # Process attention outputs for the `hidden_states`.
@@ -366,7 +366,7 @@ def custom_forward(*inputs):
366 366
367 367             else:
368 368                 encoder_hidden_states, hidden_states = block(
369     -                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, i=index_block
    369 +                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
370 370                 )
371 371
372 372         # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
0 commit comments