diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 7169ffdceebf9..a9b99d437e2fd 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -2951,7 +2951,8 @@ struct server_context {
             llama_kv_self_seq_rm (ctx, slot.id, n_keep            , n_keep + n_discard);
             llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past,        -n_discard);
 
-            if (slot.params.cache_prompt) {
+            // add generated tokens to cache
+            {
                 llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy
                 for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
                     new_tokens[i - n_discard] = new_tokens[i];
@@ -2996,10 +2997,7 @@ struct server_context {
             common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
 
             slot.n_past += 1;
-
-            if (slot.params.cache_prompt) {
-                slot.cache_tokens.push_back(slot.sampled);
-            }
+            slot.cache_tokens.push_back(slot.sampled);
 
             SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
                     slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
@@ -3171,6 +3169,11 @@ struct server_context {
 
                         SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
                     }
+                } else {
+                    // if we don't cache the prompt, we have to remove the entire KV cache
+                    llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
+                    slot.n_past = 0;
+                    slot.cache_tokens.clear();
                 }
             }
 
@@ -3204,7 +3207,7 @@ struct server_context {
                 SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
 
                 // remove the non-common part from the cache
-                slot.cache_tokens.resize(slot.n_past);
+                slot.cache_tokens.keep_first(slot.n_past);
 
                 // check if we should process the image
                 if (slot.n_past < slot.n_prompt_tokens
@@ -3221,7 +3224,8 @@ struct server_context {
                         continue;
                     }
 
-                    if (slot.params.cache_prompt) {
+                    // add the image chunk to cache
+                    {
                         const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past);
                         slot.cache_tokens.push_back(chunk.get()); // copy
                     }
@@ -3242,9 +3246,7 @@ struct server_context {
                     const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
                     common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
-                    if (slot.params.cache_prompt) {
-                        slot.cache_tokens.push_back(cur_tok);
-                    }
+                    slot.cache_tokens.push_back(cur_tok);
 
                     slot.n_prompt_tokens_processed++;
                     slot.n_past++;
diff --git a/tools/server/tests/unit/test_completion.py b/tools/server/tests/unit/test_completion.py
index 0ed5b99bef4e4..4099c4e25cd6e 100644
--- a/tools/server/tests/unit/test_completion.py
+++ b/tools/server/tests/unit/test_completion.py
@@ -196,6 +196,18 @@ def test_cache_vs_nocache_prompt():
     assert res_cache.body["content"] == res_no_cache.body["content"]
 
 
+def test_nocache_long_input_prompt():
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": "I believe the meaning of life is"*32,
+        "seed": 42,
+        "temperature": 1.0,
+        "cache_prompt": False,
+    })
+    assert res.status_code == 200
+
+
 def test_completion_with_tokens_input():
     global server
     server.temperature = 0.0
diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp
index b8d140e3f051c..45193c17cfd98 100644
--- a/tools/server/utils.hpp
+++ b/tools/server/utils.hpp
@@ -1153,7 +1153,7 @@ struct server_tokens {
         tokens.clear();
     }
 
-    void resize(size_t n) {
+    void keep_first(size_t n) {
         GGML_ASSERT(n <= tokens.size());
         if (has_mtmd) {
             // we throw an error if we try to remove a token in the middle of an image