
support flash-attn at torch backend #2257


Open
wants to merge 12 commits into master

Conversation

pass-lin
Contributor

Restarted from #2189.
Let's try to make the torch backend run flash-attn.

@sachinprasadhs sachinprasadhs added the kokoro:force-run Runs Tests on GPU label May 19, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label May 19, 2025
@pass-lin
Contributor Author

This bug doesn't seem related to my PR, since I haven't made any relevant changes.
@sachinprasadhs

@mattdangerw
Member

@divyashreepathihalli can you take a look at this one?

self._num_key_value_heads = num_key_value_heads
self._sliding_window = sliding_window
self._dropout = dropout
self.num_query_heads = num_query_heads
Collaborator

what is the reason behind the renaming?

Contributor Author

@pass-lin pass-lin May 25, 2025

> what is the reason behind the renaming?

https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/models/mixtral/mixtral_attention.py
I'm just synchronizing this with the current code in the repository.

Collaborator

Oh okay, can you rebase your branch on master so that these don't show up as new changes?


def _use_fused_attention_op(self):
    if not fused_attention_op_available():
        return False
    if self.dropout > 0.0:
        return False
    if running_on_gpu():
        # GPU never supports softcap in the fused op.
        if self.logit_soft_cap is not None:
Collaborator

This needs to return False on the JAX backend.

Contributor Author

> This needs to return False on the JAX backend.

Mixtral never uses self.logit_soft_cap, so I don't understand what you mean.

Collaborator

I see! okay
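For context on the hunk above, here is a hedged sketch of how this kind of guard typically completes; everything past the truncated softcap check, including the final fallback, is an assumption for illustration, not this PR's actual code.

```python
def _use_fused_attention_op(self):
    # Fused (flash) attention only applies when the backend exposes it.
    if not fused_attention_op_available():
        return False
    # The fused op does not handle attention dropout in this path.
    if self.dropout > 0.0:
        return False
    if running_on_gpu():
        # GPU never supports softcap in the fused op.
        if self.logit_soft_cap is not None:
            return False
        return True
    # Assumed fallback: use the unfused attention path elsewhere.
    return False
```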

@@ -71,6 +71,23 @@ def fused_attention_op_available():
            )
            return False
        return True
    elif (
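The new `elif (` branch is truncated above. Purely as a hedged illustration of what a torch-backend availability check of this shape can look like (the PR's actual conditions are not visible here, and the helper name below is hypothetical):

```python
import keras


def torch_flash_attention_available():
    # Hypothetical helper, not keras-hub code: report whether the torch
    # backend can dispatch to a flash-attention kernel through
    # torch.nn.functional.scaled_dot_product_attention.
    if keras.config.backend() != "torch":
        return False
    import torch

    if not torch.cuda.is_available():
        return False
    # flash_sdp_enabled() reports whether PyTorch's flash SDP backend is
    # enabled for scaled_dot_product_attention.
    return torch.backends.cuda.flash_sdp_enabled()
```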
Collaborator

This looks good! Can you please enable this test
https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/models/gemma/gemma_causal_lm_test.py#L101
on the PyTorch backend and make sure it passes on a supported GPU? (This may not be supported on the T4 our CI uses, so a demo Colab showing the tests passing on a supported GPU would be great.)
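As an illustration of what "enable this on the PyTorch backend" usually amounts to, a backend-gated test skeleton is sketched below; the real test is the one linked above, and both the import path and the skip condition here are assumptions.

```python
import pytest
import keras

from keras_hub.src.utils.keras_utils import fused_attention_op_available


@pytest.mark.skipif(
    keras.config.backend() not in ("jax", "torch")
    or not fused_attention_op_available(),
    reason="Flash attention is unavailable on this backend/device.",
)
def test_flash_attention_call():
    # Build a small model and run a forward pass that should take the
    # fused attention path (details elided).
    ...
```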

Contributor Author

@pass-lin pass-lin May 25, 2025

> This looks good! Can you please enable this test https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/models/gemma/gemma_causal_lm_test.py#L101 on the PyTorch backend and make sure it passes on a supported GPU? (This may not be supported on the T4 our CI uses, so a demo Colab showing the tests passing on a supported GPU would be great.)

[screenshot]
These are the models that reference the fused_attention_op_available() function.
Here are the test results on an A100:
[screenshot]

Collaborator

@pass-lin the test has not been enabled on the PyTorch backend. Please refer to the comment above about enabling it.

Contributor Author

> @pass-lin the test has not been enabled on the PyTorch backend. Please refer to the comment above about enabling it.

I don't know whether you have tested this on an A100. At present the Gemma and Gemma3 flash-attn tests fail, on both the JAX and torch backends.
Could you instead design the tests around models like Qwen and Llama, which are better suited to flash-attn?

Collaborator

@pctablet505 - have you tested this? Can you please take a look?

Collaborator

I'm not sure about it; I'll have to look into it.

Contributor Author

@pass-lin pass-lin May 28, 2025

> @pctablet505 - have you tested this? Can you please take a look?

@pctablet505 @divyashreepathihalli
I can confirm this test is wrong: it is testing Gemma2, and Gemma2 does not support flash-attn.

Collaborator

@pctablet505 pctablet505 May 30, 2025

@pass-lin
I just verified that Gemma2 and Gemma3 can't use flash attention on an A100 GPU.
Gemma3 can use flash attention on TPU, or on GPUs with CUDA compute capability >= 9.0, i.e. the H series or later (for example, the H100).

#21333
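For anyone reproducing this locally, a quick way to check the compute-capability requirement with PyTorch (the >= 9.0 threshold comes from the comment above, not from this snippet):

```python
import torch

if torch.cuda.is_available():
    # get_device_capability() returns (major, minor), e.g. (8, 0) on A100
    # and (9, 0) on H100.
    major, minor = torch.cuda.get_device_capability()
    print(f"CUDA compute capability: {major}.{minor}")
    print("Meets the >= 9.0 requirement:", (major, minor) >= (9, 0))
else:
    print("No CUDA device available.")
```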

@divyashreepathihalli
Collaborator

Thanks for the PR, I left some comments.
