* split_head_dim flax attn
* Make split_head_dim non default
* make style and make quality
* add description for split_head_dim flag
* Update src/diffusers/models/attention_flax.py
Co-authored-by: Patrick von Platen <[email protected]>
---------
Co-authored-by: Juan Acevedo <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
+        split_head_dim (`bool`, *optional*, defaults to `False`):
+            Whether to split the head dimension into a new axis for the self-attention computation. In most cases, enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`

@@ -140,6 +142,7 @@ class FlaxAttention(nn.Module):
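For context, here is a minimal JAX sketch of what "splitting the head dimension into a new axis" means: instead of folding the head count into the batch dimension before the attention matmuls, the query, key, and value tensors keep a separate heads axis and the scores are computed with batched einsums. The function name, shapes, and argument names below are illustrative assumptions, not the actual diffusers implementation.

```python
import jax
import jax.numpy as jnp


def attention_split_head_dim(q, k, v, heads):
    # q, k, v: (batch, seq_len, heads * dim_head)
    batch, seq_len, inner_dim = q.shape
    dim_head = inner_dim // heads
    scale = dim_head ** -0.5

    # Keep `heads` as its own axis: (batch, seq_len, heads, dim_head)
    q = q.reshape(batch, seq_len, heads, dim_head)
    k = k.reshape(batch, seq_len, heads, dim_head)
    v = v.reshape(batch, seq_len, heads, dim_head)

    # Per-head attention scores: (batch, heads, q_len, k_len)
    scores = jnp.einsum("bthd,bshd->bhts", q, k) * scale
    weights = jax.nn.softmax(scores, axis=-1)

    # Weighted sum of values: (batch, q_len, heads, dim_head)
    out = jnp.einsum("bhts,bshd->bthd", weights, v)

    # Merge the heads back into the feature dimension
    return out.reshape(batch, seq_len, heads * dim_head)


if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (2, 16, 8 * 64))  # batch=2, seq=16, 8 heads of 64
    y = attention_split_head_dim(x, x, x, heads=8)
    print(y.shape)  # (2, 16, 512)
```

In diffusers itself the behavior is toggled per attention module through the `split_head_dim` flag documented in the diff above (e.g. passing `split_head_dim=True` when constructing the Flax attention modules); the full constructor signature is not shown in this excerpt.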