support flash-attn at torch backend #2257
base: master
Conversation
This bug doesn't seem to be related to my changes; I haven't touched the relevant code.
@divyashreepathihalli can you take a look at this one?
self._num_key_value_heads = num_key_value_heads
self._sliding_window = sliding_window
self._dropout = dropout
self.num_query_heads = num_query_heads
what is the reason behind the renaming?
https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/models/mixtral/mixtral_attention.py
I'm just synchronizing this with the current state of the repository.
Oh okay, can you rebase your branch on master so that these don't show up as new changes?
def _use_fused_attention_op(self):
    if not fused_attention_op_available():
        return False
    if self.dropout > 0.0:
        return False
    if running_on_gpu():
        # GPU never supports softcap in the fused op.
        if self.logit_soft_cap is not None:
This needs to return False on the JAX backend.
Mixtral never uses self.logit_soft_cap, so I'm not sure what you mean.
I see! okay
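For context, the `_use_fused_attention_op` snippet quoted above is cut off mid-condition. A rough sketch of how this pattern usually completes in keras-hub attention layers; treat the `running_on_tpu` helper and the exact branches as assumptions rather than this PR's code:

```python
# Rough sketch only: the running_on_tpu helper and the exact branch
# structure are assumptions, not necessarily what this PR implements.
def _use_fused_attention_op(self):
    if not fused_attention_op_available():
        return False
    if self.dropout > 0.0:
        return False
    if running_on_gpu():
        # GPU fused kernels do not support logit softcapping.
        if self.logit_soft_cap is not None:
            return False
        return True
    if running_on_tpu():
        # TPU flash attention can handle softcapping.
        return True
    # Other devices (e.g. CPU) fall back to the unfused path.
    return False
```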
@@ -71,6 +71,23 @@ def fused_attention_op_available():
            )
            return False
        return True
    elif (
This looks good! Can you please enable this test
https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/models/gemma/gemma_causal_lm_test.py#L101
on the PyTorch backend and make sure it passes on a supported GPU? (It may not be supported on the T4s our CI uses, so a demo Colab showing the tests passing on a supported GPU would be great.)
These are the models that reference the fused_attention_op_available() function.
Here are the test results on an A100.
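The `elif (` line in the quoted diff is where it gets truncated. Purely as an illustration, a torch-backend branch of `fused_attention_op_available()` might look something like the following; the device-capability threshold and the use of `keras.config.backend()` are assumptions, not the PR's actual code:

```python
# Hypothetical sketch of a torch-backend availability check; the PR's
# real implementation may use different criteria.
import keras
import torch


def fused_attention_op_available():
    if keras.config.backend() == "torch":
        # Flash attention via torch.nn.functional.scaled_dot_product_attention
        # needs a CUDA device; the flash kernels generally want Ampere (sm80+).
        if not torch.cuda.is_available():
            return False
        major, _ = torch.cuda.get_device_capability()
        if major < 8:
            return False
        return hasattr(torch.nn.functional, "scaled_dot_product_attention")
    # Existing JAX / other-backend branches elided.
    return False
```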
@pass-lin the test has not been enabled on the PyTorch backend. Can you please refer to the above comment on enabling it?
I don't know whether you have tested it on an A100. At present, the gemma and gemma3 flash-attn test code fails, on both the JAX and torch backends.
May I suggest designing the tests around models like Qwen and Llama, which are better suited to flash-attn?
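Whichever model the test ends up targeting, gating it on backend and GPU support would keep CI on T4s green while still exercising the fused path on capable hardware. A hypothetical sketch only; the import path and test name are assumptions, not the actual gemma_causal_lm_test code:

```python
# Hypothetical gating pattern; import path and test name are assumed.
import keras
import pytest

from keras_hub.src.utils.keras_utils import fused_attention_op_available


@pytest.mark.skipif(
    keras.config.backend() not in ("jax", "torch"),
    reason="Flash attention is only wired up for the JAX and torch backends.",
)
@pytest.mark.skipif(
    not fused_attention_op_available(),
    reason="Needs a GPU that supports the fused attention op (e.g. A100).",
)
def test_flash_attention_used():
    # The real test asserts that the fused attention path is actually taken
    # during a forward pass; this stub only shows the gating pattern.
    ...
```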
@pctablet505 - have you tested this? can you please take a look?
I'm not sure about it; I'll have to look into it.
@pctablet505 @divyashreepathihalli
I can confirm this test is wrong: it is testing Gemma 2, and Gemma 2 does not support flash-attn.
Thanks for the PR; I left some comments.
Restarted from #2189.
Let's work together to get flash-attn running on the torch backend.