Refactoring of multi-head attention and support for KV caching #2061
Conversation
Started work to make sure all tests pass.
Tests fail for me that should fail in …
I'll submit a PR with a fix.
force-pushed from e652799 to 0d6360b
Could you also link the PR here?
I need to spend some more time on this one. Sorry, I was busy with other things.
force-pushed from 474797d to b8efa8c
In my local installation, there is only one test failing which I think I still need to attend to: … There are also failures in these tests: … But for those, I don't really understand what is going on.
OK, same thing in your CI/CD. I'd need help with these two tests, which use mocking in a way I do not understand. The one where "tensors are not close" I can deal with.
force-pushed from 06255ac to c8b8895
Sure, which two?
The failure in …
OK, I fixed …
OK, …, but there is a lot of … on GPU testing.
Cool, let me know what I can do. As I mentioned above, if this big change is too hard to review in one go, I could split it into two, along these lines: …
Yes, that sounds feasible :)
Resolved; now we have just 12 failing tests.
There is some mocking of …
This is the code: …
From my experience, the …
OK, 6 of the 12 errors are due to the expectation that … I'll look into the remaining 6 errors.
OK, I am working on tests which run on multiple GPUs. I missed these locally, but now run them on a multi-GPU instance. But the 6 tests which seem to fail due to the …
@Borda, I had to extend …
OK, we are down to 7 errors, 6 of which are about …
The one mocking error I'd need help resolving. For the remaining 6, you need to decide. Maybe I see this wrong, but this …
force-pushed from 2ed4013 to 98dfe11
Working on … I know now why this passes in …
OK, now we are down to a single failing test, something to do with mocking. My solution here is to expand keys and values in the special case when query and key have the same length. This is what happens during training, or for inference with … However, in subsequent cases, where the query is much shorter than the key, expanding is avoided to save GPU memory. This turns out to be important for long-context inference and fine-tuning.
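For context, a minimal sketch of that special-casing (hypothetical function, not the PR's actual code; it assumes PyTorch's `scaled_dot_product_attention`, whose `enable_gqa` flag is available in recent PyTorch versions):

```python
import torch
import torch.nn.functional as F

def gqa_attention(q, k, v):
    # q: (B, n_head, T_q, head_size); k, v: (B, n_kv_head, T_k, head_size)
    n_head, n_kv_head = q.shape[1], k.shape[1]
    T_q, T_k = q.shape[2], k.shape[2]
    if T_q == T_k:
        # Training or prefill: expand k, v to the query head count so the
        # standard fused SDPA kernels apply.
        k = k.repeat_interleave(n_head // n_kv_head, dim=1)
        v = v.repeat_interleave(n_head // n_kv_head, dim=1)
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # Decoding: T_q is much smaller than T_k. Skip the expansion to save GPU
    # memory and let the kernel broadcast over the grouped heads; the query
    # attends to all cached positions, so no causal mask is needed.
    return F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
```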
@t-vi, this is ready for review. The last test needs to be fixed on your side; I don't understand this mocking. Let me know whether I should split the PR into two: the first would generalize attention, the second would introduce the KV cache abstraction. Also let me know if something more fundamental is not OK.
force-pushed from 78be1bf to 9e80561
This continues from #1934. I created a new branch because the history of the previous one was messed up by a merge operation.
Adds an abstraction for key-value caches and implements batched inference.
I am also adding two baseline KV caches: the default one from before (all KVs are stored) and a last-recent one.
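To make the idea concrete, here is a rough sketch of what such an abstraction and the two baselines could look like (names and signatures are illustrative only, not the PR's actual API):

```python
import torch
from abc import ABC, abstractmethod

class KVCache(ABC):
    """Per-layer cache: forward() ingests the new key/value slice and
    returns the tensors the attention layer should attend over."""

    @abstractmethod
    def forward(self, key: torch.Tensor, value: torch.Tensor):
        """key, value: (batch, n_kv_head, T_new, head_size)."""

class DenseKVCache(KVCache):
    """Baseline matching the previous default: all KVs are stored."""

    def __init__(self):
        self.k = self.v = None

    def forward(self, key, value):
        if self.k is not None:
            key = torch.cat([self.k, key], dim=2)
            value = torch.cat([self.v, value], dim=2)
        self.k, self.v = key, value
        return key, value

class LastRecentKVCache(DenseKVCache):
    """Baseline that keeps only the most recent `window` positions."""

    def __init__(self, window: int):
        super().__init__()
        self.window = window

    def forward(self, key, value):
        key, value = super().forward(key, value)
        self.k = key = key[:, :, -self.window:]
        self.v = value = value[:, :, -self.window:]
        return key, value
```

Under this kind of interface, an eviction policy such as H2O would just be another subclass, and batched inference amounts to the cache tensors carrying a batch dimension.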
OK, this PR contains the following parts: …
In the library I am writing, there are a number of additional, more powerful KV caches, such as H2O and quantization-aware H2O. I am also working on fine-tuning in the presence of KV caches. The abstraction I propose here enables all of that.
If these changes are not done, I'd have to copy and modify quite a bit of your code. That would be hard to maintain, and would run the risk that KV caches are implemented differently at a later point, and then things really diverge.
As I said in the comments above, I found KV caching to be super-important for making long-context inference work on a moderate GPU budget, which should be of interest to your customers as well.
Edit: Since I opened this, I have been working a lot on gradient computation in the presence of long-context models. This is stress-testing the abstraction here quite a bit.