Merged
22 changes: 12 additions & 10 deletions tools/server/server.cpp
@@ -2951,7 +2951,8 @@ struct server_context {
     llama_kv_self_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
     llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);

-    if (slot.params.cache_prompt) {
+    // add generated tokens to cache
+    {
         llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy
         for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
             new_tokens[i - n_discard] = new_tokens[i];
@@ -2996,10 +2997,7 @@ struct server_context {
     common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);

     slot.n_past += 1;
-
-    if (slot.params.cache_prompt) {
-        slot.cache_tokens.push_back(slot.sampled);
-    }
+    slot.cache_tokens.push_back(slot.sampled);

     SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
         slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
@@ -3171,6 +3169,11 @@ struct server_context {

             SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
         }
+    } else {
+        // if we don't cache the prompt, we have to remove the entire KV cache
+        llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
+        slot.n_past = 0;
+        slot.cache_tokens.clear();
     }
 }

@@ -3204,7 +3207,7 @@ struct server_context {
     SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);

     // remove the non-common part from the cache
-    slot.cache_tokens.resize(slot.n_past);
+    slot.cache_tokens.keep_first(slot.n_past);

     // check if we should process the image
     if (slot.n_past < slot.n_prompt_tokens
@@ -3221,7 +3224,8 @@ struct server_context {
         continue;
     }

-    if (slot.params.cache_prompt) {
+    // add the image chunk to cache
+    {
         const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past);
         slot.cache_tokens.push_back(chunk.get()); // copy
     }
@@ -3242,9 +3246,7 @@ struct server_context {
     const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;

     common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
-    if (slot.params.cache_prompt) {
-        slot.cache_tokens.push_back(cur_tok);
-    }
+    slot.cache_tokens.push_back(cur_tok);

     slot.n_prompt_tokens_processed++;
     slot.n_past++;
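Reviewer note: below is a condensed sketch of the resulting flow, using only identifiers that appear in the diff above; the surrounding batching loop and error handling are omitted, so treat it as an outline rather than the actual server code. The point of the change is that slot.cache_tokens now always mirrors the KV cache, and when cache_prompt is disabled the slot's sequence is wiped before the new prompt is processed instead of skipping the bookkeeping entirely.

// Condensed outline, not the real server loop -- names taken from the diff above.

// decode path: the sampled token is always recorded
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
slot.n_past += 1;
slot.cache_tokens.push_back(slot.sampled); // no longer gated on params.cache_prompt

// prompt path: either reuse the cached prefix or start from a clean slate
if (slot.params.cache_prompt) {
    // ... common-prefix and context reuse as before ...
} else {
    // without prompt caching, drop the whole sequence from the KV cache
    llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
    slot.n_past = 0;
    slot.cache_tokens.clear();
}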
12 changes: 12 additions & 0 deletions tools/server/tests/unit/test_completion.py
@@ -196,6 +196,18 @@ def test_cache_vs_nocache_prompt():
     assert res_cache.body["content"] == res_no_cache.body["content"]


+def test_nocache_long_input_prompt():
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": "I believe the meaning of life is"*32,
+        "seed": 42,
+        "temperature": 1.0,
+        "cache_prompt": False,
+    })
+    assert res.status_code == 200
+
+
 def test_completion_with_tokens_input():
     global server
     server.temperature = 0.0
2 changes: 1 addition & 1 deletion tools/server/utils.hpp
@@ -1153,7 +1153,7 @@ struct server_tokens {
         tokens.clear();
     }

-    void resize(size_t n) {
+    void keep_first(size_t n) {
         GGML_ASSERT(n <= tokens.size());
         if (has_mtmd) {
             // we throw an error if we try to remove a token in the middle of an image
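The utils.hpp change is a rename only: the method asserts n <= tokens.size() and never grows the buffer, so keep_first describes the behavior better than resize. Below is a minimal standalone illustration of that contract; simple_tokens is a hypothetical stand-in, not the real server_tokens, which additionally has to refuse truncation in the middle of an image chunk (the has_mtmd branch above).

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for server_tokens, text tokens only.
struct simple_tokens {
    std::vector<int32_t> tokens;

    // keep only the first n tokens; growing is not allowed
    void keep_first(size_t n) {
        assert(n <= tokens.size());
        tokens.resize(n);
    }
};

int main() {
    simple_tokens cache{{1, 2, 3, 4, 5}};
    cache.keep_first(2); // cache.tokens is now {1, 2}
    assert(cache.tokens.size() == 2);
    return 0;
}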