
Commit 16d56c4

entrpn, jfacevedo-google, and patrickvonplaten authored
F/flax split head dim (huggingface#5181)
* split_head_dim flax attn

* Make split_head_dim non default

* make style and make quality

* add description for split_head_dim flag

* Update src/diffusers/models/attention_flax.py

Co-authored-by: Patrick von Platen <[email protected]>

---------

Co-authored-by: Juan Acevedo <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
1 parent c82f7ba commit 16d56c4

1 file changed

src/diffusers/models/attention_flax.py

Lines changed: 25 additions & 7 deletions
@@ -131,6 +131,8 @@ class FlaxAttention(nn.Module):
             Dropout rate
         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
             enable memory efficient attention https://arxiv.org/abs/2112.05682
+        split_head_dim (`bool`, *optional*, defaults to `False`):
+            Whether to split the head dimension into a new axis for the self-attention computation. In most cases, enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`

@@ -140,6 +142,7 @@ class FlaxAttention(nn.Module):
     dim_head: int = 64
     dropout: float = 0.0
     use_memory_efficient_attention: bool = False
+    split_head_dim: bool = False
     dtype: jnp.dtype = jnp.float32

     def setup(self):

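For context, here is a minimal usage sketch of what enabling the new flag could look like. It is not part of the commit: the `query_dim` value and the dummy shapes are assumptions, while `heads`, `dim_head`, and `split_head_dim` are the module attributes visible in the diff.

# Hypothetical usage sketch, not from this commit.
import jax
import jax.numpy as jnp
from diffusers.models.attention_flax import FlaxAttention

attn = FlaxAttention(
    query_dim=320,        # assumed inner dimension (= heads * dim_head here)
    heads=5,
    dim_head=64,
    split_head_dim=True,  # new flag added by this commit; defaults to False
)

# Dummy self-attention input of shape (batch, sequence_length, query_dim).
hidden_states = jnp.ones((1, 64, 320), dtype=jnp.float32)
variables = attn.init(jax.random.PRNGKey(0), hidden_states)
out = attn.apply(variables, hidden_states)
print(out.shape)  # (1, 64, 320)

With `split_head_dim=False` (the default, per "Make split_head_dim non default"), the module takes the unchanged code path and behaves exactly as before.
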
@@ -177,9 +180,15 @@ def __call__(self, hidden_states, context=None, deterministic=True):
         key_proj = self.key(context)
         value_proj = self.value(context)

-        query_states = self.reshape_heads_to_batch_dim(query_proj)
-        key_states = self.reshape_heads_to_batch_dim(key_proj)
-        value_states = self.reshape_heads_to_batch_dim(value_proj)
+        if self.split_head_dim:
+            b = hidden_states.shape[0]
+            query_states = jnp.reshape(query_proj, (b, -1, self.heads, self.dim_head))
+            key_states = jnp.reshape(key_proj, (b, -1, self.heads, self.dim_head))
+            value_states = jnp.reshape(value_proj, (b, -1, self.heads, self.dim_head))
+        else:
+            query_states = self.reshape_heads_to_batch_dim(query_proj)
+            key_states = self.reshape_heads_to_batch_dim(key_proj)
+            value_states = self.reshape_heads_to_batch_dim(value_proj)

         if self.use_memory_efficient_attention:
             query_states = query_states.transpose(1, 0, 2)

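To make the two branches above concrete, the following sketch (not from the commit) shows the array layout each path produces. The body of `reshape_heads_to_batch_dim` is not part of this diff, so the merged-batch variant below is an assumption based on the usual reshape-transpose-reshape pattern.

# Illustrative sketch of the two q/k/v layouts; reshape_heads_to_batch_dim is assumed.
import jax.numpy as jnp

batch, seq_len, heads, dim_head = 2, 16, 8, 64
proj = jnp.ones((batch, seq_len, heads * dim_head))  # output of a q/k/v Dense projection

# split_head_dim=True: keep the batch axis and expose heads as a new axis.
split = jnp.reshape(proj, (batch, -1, heads, dim_head))           # (2, 16, 8, 64)

# split_head_dim=False (assumed helper behaviour): fold heads into the batch axis.
merged = jnp.reshape(proj, (batch, seq_len, heads, dim_head))
merged = jnp.transpose(merged, (0, 2, 1, 3))                      # (2, 8, 16, 64)
merged = jnp.reshape(merged, (batch * heads, seq_len, dim_head))  # (16, 16, 64)

print(split.shape, merged.shape)
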
@@ -206,14 +215,23 @@ def __call__(self, hidden_states, context=None, deterministic=True):
             hidden_states = hidden_states.transpose(1, 0, 2)
         else:
             # compute attentions
-            attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
+            if self.split_head_dim:
+                attention_scores = jnp.einsum("b t n h, b f n h -> b n f t", key_states, query_states)
+            else:
+                attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
+
             attention_scores = attention_scores * self.scale
-            attention_probs = nn.softmax(attention_scores, axis=2)
+            attention_probs = nn.softmax(attention_scores, axis=-1 if self.split_head_dim else 2)

             # attend to values
-            hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
+            if self.split_head_dim:
+                hidden_states = jnp.einsum("b n f t, b t n h -> b f n h", attention_probs, value_states)
+                b = hidden_states.shape[0]
+                hidden_states = jnp.reshape(hidden_states, (b, -1, self.heads * self.dim_head))
+            else:
+                hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
+                hidden_states = self.reshape_batch_dim_to_heads(hidden_states)

-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         hidden_states = self.proj_attn(hidden_states)
         return self.dropout_layer(hidden_states, deterministic=deterministic)

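As a sanity check on the new einsum path (again, not part of the commit), the sketch below recomputes attention both ways on random tensors and verifies that the results match. The merged-batch layout relies on the same assumption about `reshape_heads_to_batch_dim` as the previous sketch.

# Equivalence sketch: split-head einsums vs. the original merged-batch einsums.
import jax
import jax.numpy as jnp
import flax.linen as nn

batch, seq_len, heads, dim_head = 2, 16, 8, 64
scale = dim_head**-0.5
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (batch, seq_len, heads, dim_head))
k = jax.random.normal(kk, (batch, seq_len, heads, dim_head))
v = jax.random.normal(kv, (batch, seq_len, heads, dim_head))

# split_head_dim path: scores laid out as (batch, heads, query_pos, key_pos).
scores = jnp.einsum("b t n h, b f n h -> b n f t", k, q) * scale
probs = nn.softmax(scores, axis=-1)
out_split = jnp.einsum("b n f t, b t n h -> b f n h", probs, v)
out_split = jnp.reshape(out_split, (batch, -1, heads * dim_head))

# merged-batch path: fold heads into the batch axis first (assumed helper behaviour).
def merge(x):
    return jnp.transpose(x, (0, 2, 1, 3)).reshape(batch * heads, seq_len, dim_head)

qm, km, vm = merge(q), merge(k), merge(v)
scores_m = jnp.einsum("b i d, b j d->b i j", qm, km) * scale
probs_m = nn.softmax(scores_m, axis=2)
out_m = jnp.einsum("b i j, b j d -> b i d", probs_m, vm)
out_m = jnp.transpose(out_m.reshape(batch, heads, seq_len, dim_head), (0, 2, 1, 3))
out_m = out_m.reshape(batch, seq_len, heads * dim_head)

print(jnp.allclose(out_split, out_m, atol=1e-5))  # True

The only real differences between the two paths are the layout of the intermediate score tensor ((batch, heads, query, key) versus (batch*heads, query, key)) and the index of the key axis, which is why the diff switches the softmax axis to -1 when `split_head_dim` is enabled.
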
