Skip to content

Refactoring of multi-head attention and support for KV caching #2061

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

mseeger
Copy link
Contributor

@mseeger mseeger commented May 30, 2025

This continues from #1934 . I created a new branch, because the history of the previous one was messed up with a merge operation.

Adds abstraction for key-value caches, implements batched inference.

I am also adding two baseline KV caches, the default one from before (all KV are stored) and a last-recent one.

OK, this PR contains the following parts:

  • Small things: Start of layer hook in GPT.forward, skip_lm_head in GPT.forward. I need these for gradient computation, but also to put proper head models on top of the transformer. This is generally useful.
  • Refactoring of multi-head attention: This is needed in order to implement the KV cache abstraction in the way @t-vi suggested (in a phone call). But it also really simplifies things. It also removes a major issue: mask_cache requires lots of memory, it is now computed on demand, with particular attention to inference (where query is much smaller than key)
  • Proper KV cache abstraction, which modifies slightly how GPT.forward is called (namely, input_pos as int). This simplifies things, though. I also provide a few default implementations. DenseKVCache replicates what is currently in place.

In the library I am writing, there are a number of additional more powerful KV caches, such as H2O and quantization-aware H2O. I am also working on fine-tuning in the presence of KV caches. The abstraction I propose here, enables all of that.

If these changes are not done, I'd have to copy and change quite a bit of your code. This would be hard to maintain, and would run the risk that KV caches are implemented differently at a later point, and then things really diverge.

As I said in the comments above, I found KV caching to be super-important to make large context inference work on a moderate GPU budget, which should be of interest to your customers as well.

Edit: Since I opened this, I am working a lot on gradient computation in the presence of long context models. This is stress-testing the abstraction here quite a bit.

@mseeger
Copy link
Contributor Author

mseeger commented May 30, 2025

Started work to make sure all tests pass.

@mseeger
Copy link
Contributor Author

mseeger commented May 30, 2025

@t-vi , @Borda , just a heads-up, I continue work in this PR, from #1934

@mseeger
Copy link
Contributor Author

mseeger commented May 30, 2025

Tests fail for me that should fail in mainline as well. For example, test_against_multimodal_gemma_3 in test_models.py fails in copy_weights_gemma_3, because the skip logic there checks for prefix "vision_tower" or "language_model", but the keys really start with "model.vision_tower" or "model.language_model".

??

@mseeger
Copy link
Contributor Author

mseeger commented May 30, 2025

I'll submit a PR with a fix.

@mseeger mseeger force-pushed the kvcache4 branch 10 times, most recently from e652799 to 0d6360b Compare June 6, 2025 10:16
@Borda
Copy link
Member

Borda commented Jun 10, 2025

I'll submit a PR with a fix.

Could you also link the PR here?

@Borda Borda added enhancement New feature or request waiting on author labels Jun 18, 2025
@mseeger
Copy link
Contributor Author

mseeger commented Jun 18, 2025

I need to spend some more work on this one. Sorry was busy with other things.

@mseeger mseeger force-pushed the kvcache4 branch 2 times, most recently from 474797d to b8efa8c Compare June 18, 2025 19:06
@mseeger
Copy link
Contributor Author

mseeger commented Jun 18, 2025

In my local installation, there is only one test failing which I think I still need to attend to:
FAILED tests/test_adapter.py::test_against_original_gemma_3[device0-dtype0-gemma-3-27b-it] - AssertionError: Tensor-likes are not close!

There are also fails in these tests:

tests/generate/test_adapter.py: Something with mock??
tests/generate/test_main.py: Something with mock??
tests/test_tokenizer.py: Probably only works on CI/CD??

But for them, I don't really understand what is going on.

@mseeger
Copy link
Contributor Author

mseeger commented Jun 18, 2025

OK, same thing in your CI/CD. I'd need help with these two tests, which use mocking in a way I do not understand. The one where "tensors are not close", I can deal with.

@mseeger mseeger force-pushed the kvcache4 branch 2 times, most recently from 06255ac to c8b8895 Compare June 20, 2025 10:12
@Borda
Copy link
Member

Borda commented Jun 23, 2025

I'd need help with these two tests, which use mocking in a way I do not understand.

sure, which two?

@mseeger
Copy link
Contributor Author

mseeger commented Jun 30, 2025

The failure in tests/test_attention.py I can fix.

@mseeger
Copy link
Contributor Author

mseeger commented Jul 1, 2025

OK, I fixed test_attention.py. Let's see if this goes through now.

@mseeger
Copy link
Contributor Author

mseeger commented Jul 1, 2025

OK, generate/test_main.py and generate/test_adapter.py still fail.

@Borda
Copy link
Member

Borda commented Jul 1, 2025

OK, generate/test_main.py and generate/test_adapter.py still fail.

shall be resolved now :)
cc: @t-vi @k223kim could you pls review?

@Borda
Copy link
Member

Borda commented Jul 1, 2025

but there is a lot of

E       RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`, but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8. For more information, go to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility

on GPU testing

@mseeger
Copy link
Contributor Author

mseeger commented Jul 1, 2025

Cool, let me know what I can do. As I mentioned above, if this big change is too hard to review in one go, I could split it into two, along these lines:

  • Refactoring of attention and clean-ups
  • Introduction of KV cache abstraction and everything related to that

@Borda
Copy link
Member

Borda commented Jul 1, 2025

I could split it into two, along these lines

yes, that sounds feasible :)

@Borda
Copy link
Member

Borda commented Jul 1, 2025

but there is a lot of

E       RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`, but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8. For more information, go to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility

on GPU testing

resolved now we have just 12 failing tests

@mseeger
Copy link
Contributor Author

mseeger commented Jul 1, 2025

There is some mocking of litgpt.attention.scaled_dot_product_attention in test_model, that I can fix.

@mseeger
Copy link
Contributor Author

mseeger commented Jul 1, 2025

FAILED tests/test_model.py::test_sdpa_choice_kv_cache[Gemma-2-2b] - assert False
 +  where False = mem_efficient_sdp_enabled()
FAILED tests/test_model.py::test_sdpa_choice_kv_cache[Gemma-2-9b] - assert False
 +  where False = mem_efficient_sdp_enabled()
FAILED tests/test_model.py::test_sdpa_choice_kv_cache[Gemma-2-27b] - assert False
 +  where False = mem_efficient_sdp_enabled()
FAILED tests/test_model.py::test_sdpa_choice_kv_cache[Gemma-2-2b-it] - assert False
 +  where False = mem_efficient_sdp_enabled()
FAILED tests/test_model.py::test_sdpa_choice_kv_cache[Gemma-2-9b-it] - assert False
 +  where False = mem_efficient_sdp_enabled()
FAILED tests/test_model.py::test_sdpa_choice_kv_cache[Gemma-2-27b-it] - assert False
 +  where False = mem_efficient_sdp_enabled()

This is this code:

    def assert_sdpa_backend(original_fn, query, k_and_v, scale, mask, attention_logit_softcapping):
        # SDPAParams gained an additional argument in PyTorch 2.5
        args = []
        assert k_and_v.both_in_parallel()
        if hasattr(SDPAParams, "enable_gqa"):
            args.append(False)
        params = SDPAParams(query, k_and_v.keys(), k_and_v.values(), mask, 0.0, True, *args)
        if expected is SDPBackend.FLASH_ATTENTION:
            assert flash_sdp_enabled()
            assert can_use_flash_attention(params, True)
        elif expected is SDPBackend.EFFICIENT_ATTENTION:
>           assert mem_efficient_sdp_enabled()
E           assert False
E            +  where False = mem_efficient_sdp_enabled()

From my experience, the EFFICIENT_ATTENTION kernel is often not available.

@mseeger
Copy link
Contributor Author

mseeger commented Jul 1, 2025

OK, 6 of the 12 errors are due to expectation that EFFICIENT_ATTENTION kernel is present (maybe I am wrong). To me, this is not always the case (I rather saw flash attention to be available). Maybe this needs to be specialized as well.

I'll look into the remaining 6 errors.

@mseeger
Copy link
Contributor Author

mseeger commented Jul 2, 2025

OK, I am working on tests which run on multiple GPUs. I missed these locally, but now run on a multi-GPU instance.

But the 6 tests which seem to fail due to the EFFICIENT_ATTENTION kernel not available, there needs to be a decision. Did these always work in the past?

@mseeger
Copy link
Contributor Author

mseeger commented Jul 2, 2025

@Borda , I had to extend generate/sequentially.py a bit in order to make some tests pass. I think this may be interesting on its own, it is making sure that layer_to_device always works.

@mseeger
Copy link
Contributor Author

mseeger commented Jul 2, 2025

OK, we are down to 7 errors. 6 of which are about EFFICIENT_ATTENTION, and one more using mocking:

FAILED tests/test_generate_speculatively.py::test_main - AssertionError: assert [call(_Fabric...culative_k=3)] == [call(, ...culative_k=3)]

The one mocking error I'd need help resolving.

For the remaining 6, you need to decide. Maybe I see this wrong, but this EFFICIENT_ATTENTION kernel is not available on the GPU instances I tried on. Maybe I am doing something wrong, but it seems to be the same in your CI/CD system. Just curious why this did not happen before?

@mseeger mseeger force-pushed the kvcache4 branch 2 times, most recently from 2ed4013 to 98dfe11 Compare July 4, 2025 08:42
@mseeger
Copy link
Contributor Author

mseeger commented Jul 4, 2025

Working on test_model.py. There is something else going on here, seems like enable_gqa=True does not work with flash attention. Need to investigate.

I know now why this passes in main, because you do not use enable_gqa=True there and instead expand the key and value matrices. That is not great, bit wasteful of memory.

@mseeger
Copy link
Contributor Author

mseeger commented Jul 5, 2025

OK, now we are down to a single test failing, something to do with mocking.

My solution here is to expand keys and values in the special case when query and key have the same length. This is what happens during training, or for inference with input_pos=0. In this case, it is more important to make use of the efficient fused kernels than to save some memory not expanding the inputs.

However, in subsequent cases, where query is much smaller than key, expanding is avoided, to save GPU memory. This turns out to be important for long context inference and fine-tuning.

@mseeger
Copy link
Contributor Author

mseeger commented Jul 7, 2025

@t-vi , this is ready for review. The last test needs to be fixed on your side, I don't understand this mocking.

Let me know whether I should split the PR into two. The first would generalize attention, the second introduce the KV cache abstraction.

Also let me know if something more fundamental is not OK.

@mseeger mseeger force-pushed the kvcache4 branch 2 times, most recently from 78be1bf to 9e80561 Compare July 9, 2025 12:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants