server : fix cache_tokens bug with no cache_prompt #13533

Merged 1 commit on May 14, 2025
22 changes: 12 additions & 10 deletions tools/server/server.cpp
@@ -2951,7 +2951,8 @@ struct server_context {
     llama_kv_self_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
     llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
 
-    if (slot.params.cache_prompt) {
+    // add generated tokens to cache
+    {
         llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy
         for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
             new_tokens[i - n_discard] = new_tokens[i];
@@ -2996,10 +2997,7 @@ struct server_context {
     common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
 
     slot.n_past += 1;
-
-    if (slot.params.cache_prompt) {
-        slot.cache_tokens.push_back(slot.sampled);
-    }
+    slot.cache_tokens.push_back(slot.sampled);
 
     SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
         slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
@@ -3171,6 +3169,11 @@ struct server_context {
 
             SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
         }
+    } else {
+        // if we don't cache the prompt, we have to remove the entire KV cache
+        llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
+        slot.n_past = 0;
+        slot.cache_tokens.clear();
     }
 }
 
@@ -3204,7 +3207,7 @@
     SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
 
     // remove the non-common part from the cache
-    slot.cache_tokens.resize(slot.n_past);
+    slot.cache_tokens.keep_first(slot.n_past);
 
     // check if we should process the image
     if (slot.n_past < slot.n_prompt_tokens
@@ -3221,7 +3224,8 @@
         continue;
     }
 
-    if (slot.params.cache_prompt) {
+    // add the image chunk to cache
+    {
         const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past);
         slot.cache_tokens.push_back(chunk.get()); // copy
     }
@@ -3242,9 +3246,7 @@
     const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
     common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
-    if (slot.params.cache_prompt) {
-        slot.cache_tokens.push_back(cur_tok);
-    }
+    slot.cache_tokens.push_back(cur_tok);
 
     slot.n_prompt_tokens_processed++;
     slot.n_past++;
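Taken together, the server.cpp changes enforce a single invariant: slot.cache_tokens always mirrors the tokens currently in the slot's KV cache, whether or not cache_prompt is enabled, and when prompt caching is off both the KV cache and its mirror are wiped before a new prompt is processed. Below is a minimal standalone sketch of that invariant; slot_sketch, begin_prompt and accept_token are illustrative stand-ins, not the server's actual types.

// Minimal sketch of the invariant enforced by this change; the types and
// function names here are hypothetical, not the server's real API.
#include <cstdio>
#include <initializer_list>
#include <vector>

struct slot_sketch {
    std::vector<int> cache_tokens; // mirror of the tokens in this slot's KV cache
    int  n_past       = 0;         // number of tokens already in the KV cache
    bool cache_prompt = false;

    void begin_prompt() {
        if (!cache_prompt) {
            // analogous to llama_kv_self_seq_rm(ctx, slot.id, 0, -1) in the PR:
            // without prompt caching, drop the whole KV cache and its mirror
            n_past = 0;
            cache_tokens.clear();
        }
    }

    void accept_token(int tok) {
        // tokens are now appended unconditionally, so the mirror cannot drift
        cache_tokens.push_back(tok);
        n_past += 1;
    }
};

int main() {
    slot_sketch slot;
    slot.begin_prompt();
    for (int tok : {101, 102, 103}) {
        slot.accept_token(tok);
    }
    std::printf("mirror in sync: %s\n",
                (int) slot.cache_tokens.size() == slot.n_past ? "yes" : "no");
    return 0;
}

With the pre-PR behavior (appending to cache_tokens only when cache_prompt was true), the mirror stayed empty while n_past advanced, which appears to be the desync that the context-shift and prompt-processing paths then tripped over.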
12 changes: 12 additions & 0 deletions tools/server/tests/unit/test_completion.py
@@ -196,6 +196,18 @@ def test_cache_vs_nocache_prompt():
     assert res_cache.body["content"] == res_no_cache.body["content"]
 
 
+def test_nocache_long_input_prompt():
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": "I believe the meaning of life is"*32,
+        "seed": 42,
+        "temperature": 1.0,
+        "cache_prompt": False,
+    })
+    assert res.status_code == 200
+
+
 def test_completion_with_tokens_input():
     global server
     server.temperature = 0.0
2 changes: 1 addition & 1 deletion tools/server/utils.hpp
@@ -1153,7 +1153,7 @@ struct server_tokens {
         tokens.clear();
     }
 
-    void resize(size_t n) {
+    void keep_first(size_t n) {
         GGML_ASSERT(n <= tokens.size());
         if (has_mtmd) {
             // we throw an error if we try to remove a token in the middle of an image
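For reference on the utils.hpp rename: keep_first() only ever shrinks the token list, which the GGML_ASSERT makes explicit, so the new name describes the call site in server.cpp better than resize() did. A simplified sketch of the assumed semantics follows; tokens_sketch is a hypothetical stand-in and omits the multimodal image-chunk guard the real server_tokens has.

// Rough sketch of the assumed keep_first() semantics (simplified; the real
// server_tokens also refuses to cut an image chunk in half).
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

struct tokens_sketch {
    std::vector<int> tokens;

    // keep only the first n tokens; unlike std::vector::resize, growing is not allowed
    void keep_first(std::size_t n) {
        assert(n <= tokens.size());
        tokens.resize(n);
    }
};

int main() {
    tokens_sketch cache{{10, 20, 30, 40}};
    cache.keep_first(2); // keep the common prefix, drop the non-common tail
    std::printf("%zu tokens kept\n", cache.tokens.size());
    return 0;
}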