Skip to content

feat: Hybrid unified/recurrent cache #13276

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: master
Choose a base branch
from

Conversation

gabe-l-hart
Copy link
Contributor

Description

This implementation covers both llama_memory_i and llama_kv_cache interfaces, but they could very well not be correct.

Discussion

I'm putting this up for discussion even though it doesn't have much value as standalone. My ultimate goal is support for the just-released granite 4 which is a combination of mamba2 and granitemoeshared layers. I opened #13275 to track the full scope of model architecture changes.

Copy link
Collaborator

@compilade compilade left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome to see this progress!

Comment on lines 2459 to 2460
// TODO: Will it cause problems if some caches are able to remove the seq
// but others aren't?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it will cause problems if this breaks the coherency between caches. (e.g. part of a sequence is removed in one cache but not the other).

This is what I was referring to in #12799 (comment) when I wrote:

The hardest part will be handling errors and properly keeping coherency between the different types of caches (because they don't necessarily roll-back states in the same way).

I think the seq_rm API might fundamentally be too specific to self-attention KV cache. Recurrent models can't rollback their state, because intermediate states are not kept since keeping them for all tokens would take too much space. (when seq_rm returns false, it means the states have to be re-calculated from scratch for the affected sequence (at least that was the intention in #5328))

Ideally, if there was some API to create snapshots and rollback to them, the implementation would be simpler for recurrent models (and for hybrid models by extension). (technically, sequences (with seq_id) already kind of do this (and are copy-on-write), but snapshots within sequences might be more convenient to manage in user code, since managing which state is the latest per sequence could be done transparently)

But that would also mean having to manage the lifetime of explicit state snapshots (in examples/server/server.cpp among others) instead of directly dealing with ranges of token positions (and might make things like largest-common-prefix context caching harder to handle). I've previously shared some ideas about state snapshots/checkpoints in #7531 (comment) (although the first half of the comment is about session restore as in state_read).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, interesting. I'm definitely still learning on-the-fly here, but based on this description and the logic here in server.cpp, it seems like the most correct implementation would be to leak implementation details of the child caches or introduce a new member API for can_seq_rm that is const but returns the same logic. I think I'll give that a shot and see how far I can get.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I've pushed an attempt at doing this safely. One thing I noticed is that these mutating methods don't seem to have any sort of locking mechanism, so the way I have it implemented could certainly be prone to thread safety problems if concurrent threads tried to call seq_rm. I don't think this is any different than for the current cache implementations since those would also be sensitive to the same races where the validated condition changes after validation but before the members get mutated, but I wanted to double check if this kind of thread safety is guarded against elsewhere (or just assumed to be handled in the client layer).

Comment on lines 2533 to 2534
// If any of the caches are recurrent, require simple split
return llama_sbatch(batch, m_hparams.n_embd, m_has_recurrent, logits_all);
Copy link
Collaborator

@compilade compilade May 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simple split should not be used with recurrent models, they expect equal split.

See #7531 (comment) which illustrates the splits

Suggested change
// If any of the caches are recurrent, require simple split
return llama_sbatch(batch, m_hparams.n_embd, m_has_recurrent, logits_all);
// If any of the caches are recurrent, require non-simple split
return llama_sbatch(batch, m_hparams.n_embd, !m_has_recurrent, logits_all);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the comment pointer, this is super helpful for understanding what the consequences of these actually are!

Comment on lines 2538 to 2540
if (m_has_recurrent) {
return sbatch.split_simple(n_ubatch);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will not work, recurrent models expect split_equal to be used.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I'm following now. I had them backwards in my head

Comment on lines +2586 to +2651
// TODO: Is this correct?
// If any children can shift, return true
for (const auto & cache : m_children) {
if (cache->get_can_shift()) {
return true;
}
}
Copy link
Collaborator

@compilade compilade May 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this should be if all children can shift, then return true.

But as you've noticed elsewhere, can_shift should technically always be true for all currently-implemented cache types, so I don't know if that part of the API will stay anyway.

This was referenced May 13, 2025
@gabe-l-hart gabe-l-hart marked this pull request as ready for review May 14, 2025 20:13
@gabe-l-hart
Copy link
Contributor Author

@compilade I now have a proof-point (#13550) that this works to some extend, though I haven't tested it robustly for edge cases. There are a few additional changes I needed to make on that branch that should maybe come over to this branch, but it gets a little hairy because they interact with adding the actual model architectures. Some possible paths forward:

  1. Review this one as-is and take the updates from the Granite 4 branch when they come
  2. Try to move the changes to hparams to determine per-layer recurrence over to this branch, and then just have unreachable code when initializing the hybrid cache in llama-model.cpp
  3. Scrap this PR and just review all hybrid cache changes with the incoming models in the Granite 4 branch

Thoughts / Preferences?

This implementation covers both `llama_memory_i` and `llama_kv_cache`
interfaces, but they could very well not be correct.

Branch: HybridCache

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
… seq_rm

This allows the hybrid cache to check first before mutating any of the
children.

Branch: HybridCache

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: HybridCache

Signed-off-by: Gabe Goodhart <[email protected]>
The parent should fully own the lifecycle of the children which is managed
by the m_children member holding unique_ptrs. These need to be initialized
correctly, so the constructor now takes the input vector of child_cache by
value instead of reference so that the child pointers can be transferred to
the parent cache. The expectation is that the vector of child_cache
instances will be instantiated in-place with move semantics.

Branch: HybridCache

Signed-off-by: Gabe Goodhart <[email protected]>
Also, split llama_model_is_recurrent into llm_arch_is_recurrent in
llama-arch with llama_model_is_recurrent delegating to
llm_arch_is_recurrent. The same split is done for hybird. This is needed
because there are places where the llama_model has not yet been initialized
but we need to check if the model is recurrent (specifically for the
per-layer recurrent check array in hparams).

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
…l is recurrent

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
…ches

This is a bit of an inversion of concerns, so we could conceivably make the
interface to this more opaque to the other cache types by providing
something like a layer mask, but since these cache implementations already
have access to the hparams, it seems minimally invasive to just check the
new recurrent_layer function.

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
…s in hparams

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
There is a small breaking change here that extends the create_memory
method signature to include the hparams. Currently, this member is only
used inside llama_context and is not part of an interface that's expected
to be extended by classes derived from llama_model, so I don't think this
should actually break any downstream use cases.

Branch: HybridCache

Signed-off-by: Gabe Goodhart <[email protected]>
@gabe-l-hart gabe-l-hart changed the title feat: First pass at llama_kv_cache_hybrid feat: Hybrid unified/recurrent cache May 16, 2025
@gabe-l-hart
Copy link
Contributor Author

I thought about it further and decided that the cleanest separation between this and the Granite 4 branch is to pull over the key parts of hparam parsing and kv cache instantiation and isolate the mismatch to the llm_arch_is_hybrid function which will always return false for now until there are real hybrid models.

@@ -402,7 +402,10 @@ struct llama_model {

// note: can mutate `cparams`
// TODO: move this to new llm_arch_model_i interface
llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;
llama_memory_i * create_memory(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just realized that this is entirely unnecessary since llama_model already has hparams as a member variable 🤦

It was already available as a member! 🤦

Branch: HybridCache

Signed-off-by: Gabe Goodhart <[email protected]>
This will be the public interface used by functions that need to access one
specific type of child. It's a bit brittle since the rest of the hybrid
class intentionally avoids expecting there to be exactly one unified child
and one recurrent child, but the idea is that this should only be used from
a context where that's known to be true.

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
…he constructor

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
So far this only tests constructor logic (and barely that)

Branch: HybridCache

Signed-off-by: Gabe Goodhart <[email protected]>
@github-actions github-actions bot added the testing Everything test related label May 16, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
testing Everything test related
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants