Skip to content

kv-cache : fix out-of-bounds view during reserve graph #13547

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/llama-kv-cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,13 @@ void llama_kv_cache_unified::defrag_sched(float thold) {

void llama_kv_cache_unified::set_full() {
n = size;

// when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
// affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
// we should only guarantee that the head position won't cause out-of-bounds view of the K, V tensors, so
// setting it to 0 is the simplest way to achieve that
// ref: https://github.com/ggml-org/llama.cpp/issues/13359
head = 0;
}

llama_sbatch llama_kv_cache_unified::sbatch_init(
Expand Down Expand Up @@ -1712,6 +1719,7 @@ void llama_kv_cache_recurrent::defrag_sched(float thold) {

void llama_kv_cache_recurrent::set_full() {
n = size;
head = 0;
}

llama_sbatch llama_kv_cache_recurrent::sbatch_init(
Expand Down
14 changes: 4 additions & 10 deletions src/llama-kv-cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;

// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_impl also uses it, so it
// cannot be freely changed after a slot has been allocated.
uint32_t head = 0;
uint32_t size = 0;
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
uint32_t size = 0; // total number of cells, shared across all sequences
uint32_t used = 0; // used cells (i.e. at least one seq_id)

// computed before each graph build
Expand Down Expand Up @@ -343,11 +340,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;

// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_impl also uses it, so it
// cannot be freely changed after a slot has been allocated.
uint32_t head = 0;
uint32_t size = 0;
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
uint32_t size = 0; // total number of cells, shared across all sequences
uint32_t used = 0; // used cells (i.e. at least one seq_id)

// computed before each graph build
Expand Down