kv-cache : add SWA support #13194

Merged: 16 commits, May 20, 2025
Changes from 1 commit

kv-cache : apply defrag when we fail to find slots for the batch
ggml-ci
ggerganov committed May 17, 2025
commit 63901253e8132d45cc9a6394043664b31f50d156
9 changes: 8 additions & 1 deletion src/llama-context.cpp
@@ -93,6 +93,7 @@ llama_context::llama_context(
}

cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);

cparams.op_offload = params.op_offload;

const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
@@ -2637,7 +2638,13 @@ int32_t llama_encode(
int32_t llama_decode(
llama_context * ctx,
llama_batch batch) {
const int ret = ctx->decode(batch);
int ret = ctx->decode(batch);

if (ret == 1) {
llama_kv_self_defrag(ctx);
ret = ctx->decode(batch);
}

if (ret != 0) {
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
}
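With this commit, a failed slot search inside llama_decode() triggers a KV-cache defrag followed by a single retry, so a caller only sees ret == 1 when even the defragmented cache cannot host the batch. Below is a minimal caller-side sketch of how the return value might be handled after this change; the helper name is made up, and the return-code meanings beyond 0 = success are my reading of the diff rather than part of it.

// sketch: interpreting llama_decode() return values after the internal defrag + retry
#include "llama.h"

#include <cstdio>

static bool decode_or_report(llama_context * ctx, llama_batch batch) {
    const int32_t ret = llama_decode(ctx, batch);

    if (ret == 0) {
        return true; // batch decoded, KV cache updated
    }

    if (ret == 1) {
        // even after the internal llama_kv_self_defrag() + retry there was no slot:
        // the caller has to free cache space (smaller batch, drop a sequence, larger context)
        std::fprintf(stderr, "decode: KV cache full, no slot for %d tokens\n", batch.n_tokens);
        return false;
    }

    std::fprintf(stderr, "decode: hard error, ret = %d\n", ret);
    return false;
}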
136 changes: 88 additions & 48 deletions src/llama-kv-cache.cpp
@@ -333,44 +333,31 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
}

void llama_kv_cache_unified::restore() {
if (pending.ubatches.empty()) {
return;
}

uint32_t new_head = size;

for (const auto & ubatch : pending.ubatches) {
for (uint32_t i = 0; i < ubatch.data.n_tokens; ++i) {
for (int s = 0; s < ubatch.data.n_seq_id[i]; ++s) {
const llama_seq_id seq_id = ubatch.data.seq_id[i][s];

cells[ubatch.head + i].seq_id.erase(seq_id);
if (cells[ubatch.head + i].seq_id.empty()) {
used--;

new_head = std::min(new_head, ubatch.head + i);
}
for (const auto & [id, cell] : recovery.cells) {
// TODO: move to new `struct kv_cells`
const bool is_empty0 = cells[id].is_empty();
const bool is_empty1 = cell.is_empty();

cells[ubatch.head + i].pos = -1;
}
if (!is_empty0 && is_empty1) {
used--;
} else if (is_empty0 && !is_empty1) {
used++;
}
}

if (new_head != size && new_head < head) {
head = new_head;
cells[id] = cell;
}

pending.clear();
recovery.clear();
}

void llama_kv_cache_unified::commit() {
if (pending.ubatches.empty()) {
LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
__func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
if (recovery.cells.empty()) {
LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n",
__func__, "https://github.com/ggml-org/llama.cpp/pull/13194");
return;
}

pending.clear();
recovery.clear();
}

bool llama_kv_cache_unified::update(llama_context & lctx) {
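The restore()/commit() pair above now works off a first-touch snapshot: find_slot() records a cell's original state the first time it modifies it, restore() writes the snapshots back while fixing up the used counter, and commit() simply drops them. The following self-contained sketch illustrates that pattern; cell_t and snapshot_cache are illustrative stand-ins for the real kv_cell and cache members, not llama.cpp types.

// sketch of the first-touch snapshot pattern used by the new `recovery` member
#include <cstdint>
#include <set>
#include <unordered_map>
#include <vector>

struct cell_t {
    int32_t           pos = -1;
    std::set<int32_t> seq_id;

    bool is_empty() const { return seq_id.empty(); }
};

struct snapshot_cache {
    std::vector<cell_t> cells;
    uint32_t            used = 0; // cells with at least one seq_id

    // original state of every cell touched since the last commit, keyed by cell index
    std::unordered_map<uint32_t, cell_t> recovery;

    // call before mutating cells[i]; only the first touch is recorded
    void remember(uint32_t i) {
        if (recovery.find(i) == recovery.end()) {
            recovery[i] = cells[i];
        }
    }

    // roll back everything since the last commit
    void restore() {
        for (const auto & [i, cell] : recovery) {
            const bool cur_empty = cells[i].is_empty();
            const bool old_empty = cell.is_empty();

            if (!cur_empty &&  old_empty) { used--; } // cell becomes empty again
            if ( cur_empty && !old_empty) { used++; } // cell becomes occupied again

            cells[i] = cell;
        }

        recovery.clear();
    }

    // accept everything since the last commit
    void commit() { recovery.clear(); }
};

Keeping one snapshot per cell, taken on first touch, means that several ubatches of the same batch can reuse a cell without clobbering the state that restore() needs.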
@@ -460,16 +447,11 @@ void llama_kv_cache_unified::set_full() {
head = 0;
}

llama_sbatch llama_kv_cache_unified::sbatch_init(
const llama_batch & batch,
bool logits_all) {
llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) {
return llama_sbatch(batch, hparams.n_embd, true, logits_all);
}

llama_ubatch llama_kv_cache_unified::ubatch_next(
llama_sbatch & sbatch,
uint32_t n_ubatch,
bool embd_pooled) const {
llama_ubatch llama_kv_cache_unified::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
GGML_UNUSED(embd_pooled);
return sbatch.split_simple(n_ubatch);
}
@@ -490,6 +472,29 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
return false;
}

//#define FIND_SLOT_DEBUG 1
#if FIND_SLOT_DEBUG
LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);

// for debugging
{
std::string ss;
if (n_swa > 0) {
for (uint32_t i = 0; i < size; ++i) {
if (cells[i].pos == -1) {
ss += '.';
} else {
ss += std::to_string(*cells[i].seq_id.begin());
}
if (i%256 == 255) {
ss += '\n';
}
}
}
LLAMA_LOG_WARN("\n%s\n", ss.c_str());
}
#endif

uint32_t n_tested = 0;

while (true) {
@@ -520,6 +525,11 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
}

for (uint32_t i = 0; i < n_tokens; ++i) {
// remember the original state
if (recovery.cells.find(head + i) == recovery.cells.end()) {
recovery.cells[head + i] = cells[head + i];
}

cells[head + i].pos = ubatch.pos[i];

for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
@@ -529,14 +539,14 @@

used += n_tokens;

pending.ubatches.push_back({ head, ubatch });

// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding)));

//printf("n = %5d, used = %5d, head = %5d\n", n, used, head);
#ifdef FIND_SLOT_DEBUG
LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
#endif

return true;
}
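The "n" heuristic near the end of find_slot() limits attention to the padded prefix of the cache that is actually occupied. A worked example with illustrative numbers follows; pad_up mirrors what I understand GGML_PAD to do, i.e. round a value up to a multiple of the padding.

// worked example of n = min(size, max(padding, GGML_PAD(cell_max(), padding)))
#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint32_t pad_up(uint32_t x, uint32_t n) {
    return ((x + n - 1) / n) * n; // round x up to the next multiple of n
}

int main() {
    const uint32_t size     = 4096; // total number of KV cells (illustrative)
    const uint32_t padding  = 32;   // padding required by the attention kernels (illustrative)
    const uint32_t cell_max = 70;   // highest occupied cell index + 1 (illustrative)

    // attend only to the padded, occupied prefix of the cache
    const uint32_t n = std::min(size, std::max(padding, pad_up(cell_max, padding)));

    std::printf("n = %u\n", n); // 96: 70 rounded up to the next multiple of 32
    return 0;
}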
@@ -642,6 +652,34 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
return ggml_cpy(ctx, v_cur, v_view);
}

void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos p1) {
// no pruning is needed when the cache does not use SWA
GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE && "do not prune non-SWA cache");

for (uint32_t i = 0; i < size; ++i) {
const llama_pos p0 = cells[i].pos;

if (is_masked_swa(p0, p1)) {
if (seq_id < 0) {
cells[i].seq_id.clear();
} else if (cells[i].has_seq_id(seq_id)) {
cells[i].seq_id.erase(seq_id);
} else {
continue;
}

if (cells[i].is_empty()) {
// keep count of the number of used cells
if (cells[i].pos >= 0) {
used--;
}

cells[i].pos = -1;
}
}
}
}

void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
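The new prune_swa() above removes a sequence from every cell whose position has fallen outside the sliding window relative to the newest position p1, using the same is_masked_swa() predicate that builds the KQ mask (which this commit also teaches to treat empty cells, p0 < 0, as masked). The sketch below shows the window condition I am assuming for the standard SWA type, namely that p0 is pruned when p1 - p0 >= n_swa; treat that inequality as my reading of the code, not a definitive statement.

// sketch of the standard sliding-window predicate assumed here
#include <cstdint>
#include <cstdio>

static bool is_masked_swa_standard(int32_t n_swa, int32_t p0, int32_t p1) {
    if (p0 < 0) {
        return true; // empty cells are always treated as masked
    }
    return p1 - p0 >= n_swa;
}

int main() {
    const int32_t n_swa = 4;
    const int32_t p1    = 10; // newest position in the sequence

    for (int32_t p0 = 5; p0 <= 10; ++p0) {
        // positions 5 and 6 fall out of the window and would be pruned; 7..10 are kept
        std::printf("p0 = %2d -> %s\n", p0, is_masked_swa_standard(n_swa, p0, p1) ? "prune" : "keep");
    }
    return 0;
}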
@@ -1181,6 +1219,10 @@ uint32_t llama_kv_cache_unified::cell_max() const {
}

bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
if (p0 < 0) {
return true;
}

switch (swa_type) {
case LLAMA_SWA_TYPE_NONE:
{
@@ -1659,20 +1701,12 @@ void llama_kv_cache_unified_iswa::commit() {
kv_base->commit();
kv_swa ->commit();

if (pending.pos_max.empty()) {
return;
}

// slide the attention window, forgetting/pruning old tokens that are outside the window
for (const auto & [seq_id, pos_max] : pending.pos_max) {
if (pos_max <= (llama_pos) hparams.n_swa) {
continue;
}

kv_swa->seq_rm(seq_id, -1, pos_max - hparams.n_swa + 1);
kv_swa->prune_swa(seq_id, pos_max);
}

pending.pos_max.clear();
pending.clear();
}

bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
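Between sbatch_init() and commit(), the iSWA wrapper only needs one piece of state: the newest position seen for each sequence in the current batch, which commit() then hands to the SWA cache for pruning. A minimal sketch of that bookkeeping is below; track_pos_max, commit_prune, and the flattened (seq_id, pos) token list are illustrative stand-ins, while the real code iterates the llama_batch fields and calls kv_swa->prune_swa().

// sketch: track per-sequence max positions for a batch, then prune the SWA cache at commit
#include <algorithm>
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <utility>
#include <vector>

using seq_id_t = int32_t;
using pos_t    = int32_t;

struct swa_pending {
    std::unordered_map<seq_id_t, pos_t> pos_max; // newest position per sequence

    void clear() { pos_max.clear(); }
};

// on sbatch_init(): remember the newest position of each sequence in the batch
static void track_pos_max(swa_pending & pending, const std::vector<std::pair<seq_id_t, pos_t>> & tokens) {
    pending.clear();

    for (const auto & [seq_id, pos] : tokens) {
        const auto it = pending.pos_max.find(seq_id);
        if (it == pending.pos_max.end()) {
            pending.pos_max[seq_id] = pos;
        } else {
            it->second = std::max(it->second, pos);
        }
    }
}

// on commit(): everything outside the window of pos_max can be forgotten by the SWA cache
static void commit_prune(swa_pending & pending, const std::function<void(seq_id_t, pos_t)> & prune_swa) {
    for (const auto & [seq_id, pos_max] : pending.pos_max) {
        prune_swa(seq_id, pos_max);
    }

    pending.clear();
}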
@@ -1695,12 +1729,18 @@ void llama_kv_cache_unified_iswa::set_full() {
}

llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
pending.pos_max.clear();

for (int i = 0; i < batch.n_tokens; ++i) {
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
const llama_seq_id seq_id = batch.seq_id[i][s];
const llama_pos pos = batch.pos[i];

pending.pos_max[seq_id] = std::max(pending.pos_max[seq_id], pos);
if (pending.pos_max.find(seq_id) == pending.pos_max.end()) {
pending.pos_max[seq_id] = pos;
} else {
pending.pos_max[seq_id] = std::max(pending.pos_max[seq_id], pos);
}
}
}

40 changes: 22 additions & 18 deletions src/llama-kv-cache.h
@@ -2,18 +2,19 @@

#include "llama.h"
#include "llama-io.h"
#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-memory.h"

#include "ggml-cpp.h"

#include <map>
#include <set>
#include <unordered_map>
#include <vector>

struct llama_cparams;
struct llama_hparams;
struct llama_ubatch;
struct llama_sbatch;
struct llama_model;
struct llama_context;

@@ -40,6 +41,9 @@ struct llama_kv_cache : public llama_memory_i {
// batch processing
//

// =============================================================================================================
// TODO: refactor and simplify this

virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;

// different KV caches require different batch splitting strategies
@@ -48,6 +52,8 @@
// find an empty slot of size "n_tokens" in the cache
virtual bool find_slot(const llama_ubatch & batch) = 0;

// =============================================================================================================

// getters
virtual int32_t get_n_tokens() const = 0;
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
@@ -171,6 +177,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;

void prune_swa(llama_seq_id seq_id, llama_pos p1);

void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_k_shift (ggml_tensor * dst) const;
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
@@ -214,7 +222,7 @@ class llama_kv_cache_unified : public llama_kv_cache {

uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
uint32_t size = 0; // total number of cells, shared across all sequences
uint32_t used = 0; // used cells (i.e. at least one seq_id)
uint32_t used = 0; // used cells (i.e. at least one seq_id) (TODO: add `struct kv_cells` and keep track automatically)

// computed before each graph build
uint32_t n = 0;
@@ -233,27 +241,20 @@ class llama_kv_cache_unified : public llama_kv_cache {
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;

std::vector<kv_cell> cells;
std::vector<kv_cell> cells; // TODO: replace with `struct kv_cells`
std::vector<kv_layer> layers;

// model layer id -> KV cache layer id
std::map<int32_t, int32_t> map_layer_ids;

struct ubatch_info {
uint32_t head;

llama_ubatch data;
};
std::unordered_map<int32_t, int32_t> map_layer_ids;

// pending cell updates that are not yet committed
// recovery information used to restore the KV cells to their original state in case of a failure
struct {
void clear() {
ubatches.clear();
cells.clear();
}

// upon batch processing failure, we revert these ubatches from the KV cells
std::vector<ubatch_info> ubatches;
} pending;
std::unordered_map<uint32_t, kv_cell> cells;
} recovery;

// defrag
struct {
@@ -377,9 +378,12 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
private:
const llama_hparams & hparams;

// pending cell updates that are not yet committed
struct {
std::map<llama_seq_id, llama_pos> pos_max;
void clear() {
pos_max.clear();
}

std::unordered_map<llama_seq_id, llama_pos> pos_max;
} pending;

std::unique_ptr<llama_kv_cache_unified> kv_base;