diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 3dcad65bb6a85..265db2527c7ca 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -441,6 +441,13 @@ void llama_kv_cache_unified::defrag_sched(float thold) { void llama_kv_cache_unified::set_full() { n = size; + + // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not + // affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views. + // we should only guarantee that the head position won't cause out-of-bounds view of the K, V tensors, so + // setting it to 0 is the simplest way to achieve that + // ref: https://github.com/ggml-org/llama.cpp/issues/13359 + head = 0; } llama_sbatch llama_kv_cache_unified::sbatch_init( @@ -1712,6 +1719,7 @@ void llama_kv_cache_recurrent::defrag_sched(float thold) { void llama_kv_cache_recurrent::set_full() { n = size; + head = 0; } llama_sbatch llama_kv_cache_recurrent::sbatch_init( diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index bf3b4b6a4430f..e83e12c09f2b1 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -171,11 +171,8 @@ class llama_kv_cache_unified : public llama_kv_cache { void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; - // Note: The value of head isn't only used to optimize searching - // for a free KV slot. llama_decode_impl also uses it, so it - // cannot be freely changed after a slot has been allocated. - uint32_t head = 0; - uint32_t size = 0; + uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) + uint32_t size = 0; // total number of cells, shared across all sequences uint32_t used = 0; // used cells (i.e. at least one seq_id) // computed before each graph build @@ -343,11 +340,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache { void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; - // Note: The value of head isn't only used to optimize searching - // for a free KV slot. llama_decode_impl also uses it, so it - // cannot be freely changed after a slot has been allocated. - uint32_t head = 0; - uint32_t size = 0; + uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) + uint32_t size = 0; // total number of cells, shared across all sequences uint32_t used = 0; // used cells (i.e. at least one seq_id) // computed before each graph build