llama : greatly reduce output buffer memory usage #6122

Merged: 26 commits, Mar 26, 2024
Changes from 1 commit
Commits (26):
1fd1918
llama : greatly reduce logits memory usage
compilade Mar 15, 2024
98914c0
llama : more compact state saving and reloading
compilade Mar 15, 2024
705d393
llama : fix lctx.n_outputs not being set before building graph
compilade Mar 16, 2024
25981fc
perplexity : adapt to the logits API changes
compilade Mar 17, 2024
17b45c9
perplexity : fix Winogrande, use correct logits for second choice start
compilade Mar 17, 2024
d0129e8
perplexity : normalize spaces and punctuation in Winogrande sentences
compilade Mar 17, 2024
487f89e
llama : fix embedding conditions
compilade Mar 17, 2024
408fcb0
llama : fix llama_get_embeddings_ith when the resulting id is 0
compilade Mar 17, 2024
e19cb3a
llama : fix wrong n_outputs in llama_set_inputs
compilade Mar 17, 2024
a57fa7f
llama : fix not-skipping outputs of non-causal models
compilade Mar 18, 2024
711b0bc
llama : fix running a batch with n_outputs == 0
compilade Mar 18, 2024
d100502
llama : keep same graph topology even when n_outputs == 0
compilade Mar 18, 2024
99c37cc
ggml : saner ggml_can_repeat with empty tensors
compilade Mar 18, 2024
6bf7f3f
ggml : do not multi-thread ops returning empty tensors
compilade Mar 18, 2024
09bb15a
ggml : make ggml_is_empty public and work with views
compilade Mar 19, 2024
4551e7e
llama : use a vector for ctx->output_ids
compilade Mar 19, 2024
8b826c5
ggml : skip empty tensors in all backends
compilade Mar 19, 2024
d04cfaf
llama : fix llama_output_reserve nullptr deref when new_size is 0
compilade Mar 19, 2024
8f70dcb
perplexity : make Winogrande work as it does on master
compilade Mar 19, 2024
615a3a4
llama : clearer error messages for invalid logits or embeddings ids
compilade Mar 19, 2024
7d8d6b5
llama : handle errors from llama_output_reserve at call sites
compilade Mar 21, 2024
5f33a67
perplexity : make hellaswag and multiple-choice outputs identical to …
compilade Mar 21, 2024
ffa9abd
Merge branch 'master' into compilade/smaller-output-buffer
compilade Mar 25, 2024
e9095ac
llama : allow loading state saved with a different ctx size
compilade Mar 26, 2024
5027d81
llama : minor
ggerganov Mar 26, 2024
20248e8
readme : update recent API changes, and warn about Vulkan
compilade Mar 26, 2024

llama : more compact state saving and reloading
compilade committed Mar 17, 2024
commit 98914c0ed02f1503762712bbe58bfacfcbf48b60
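
The core of this commit is a round trip between the full `output_ids` map (batch position to output row, with -1 meaning no output) and a compact `output_pos` list (output row to batch position): only the compact list and the rows that were actually produced are serialized. Below is a minimal standalone sketch of that round trip; the helper names and the `main` driver are made up for illustration and do not exist in llama.cpp.

```cpp
// Standalone sketch of the compact output-id serialization used in this commit.
// `output_ids` maps a batch position to a row in the logits/embd buffers (-1 = no output);
// `output_pos` is the compact inverse: row -> batch position.
#include <cstdint>
#include <cstdio>
#include <vector>

// Build the compact representation (mirrors the writer side in llama_copy_state_data_internal).
static std::vector<int32_t> build_output_pos(const std::vector<int32_t> & output_ids, int32_t n_outputs) {
    std::vector<int32_t> output_pos(n_outputs, -1);
    for (int32_t i = 0; i < (int32_t) output_ids.size(); ++i) {
        const int32_t row = output_ids[i];
        if (row >= 0 && row < n_outputs) {
            output_pos[row] = i; // row `row` of the output buffer came from batch position i
        }
    }
    return output_pos;
}

// Restore the full mapping from the compact one (mirrors the reader side in llama_set_state_data).
// Assumes every stored position is within the batch.
static std::vector<int32_t> restore_output_ids(const std::vector<int32_t> & output_pos, size_t n_batch) {
    std::vector<int32_t> output_ids(n_batch, -1);
    for (int32_t row = 0; row < (int32_t) output_pos.size(); ++row) {
        output_ids[output_pos[row]] = row;
    }
    return output_ids;
}

int main() {
    // batch of 8 tokens; only positions 3 and 7 requested outputs (rows 0 and 1)
    const std::vector<int32_t> output_ids = {-1, -1, -1, 0, -1, -1, -1, 1};
    const auto output_pos   = build_output_pos(output_ids, /*n_outputs=*/2);   // {3, 7}
    const auto restored_ids = restore_output_ids(output_pos, output_ids.size());
    printf("round trip %s\n", restored_ids == output_ids ? "ok" : "mismatch");
    return 0;
}
```

Serializing positions rather than the full per-batch map keeps the session file small when only a few outputs were requested.
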
171 changes: 124 additions & 47 deletions llama.cpp
@@ -2102,8 +2102,8 @@ struct llama_context {
float * logits = nullptr;

int32_t * output_ids = nullptr; // map token positions to ids of the logits and embd buffers
size_t output_size = 0; // capacity (of tokens positions) for the output buffer
int32_t n_outputs = 0; // number of actually-used outputs in the previous batch
size_t output_size = 0; // capacity (of tokens positions) for the output buffers
int32_t n_outputs = 0; // number of actually-used outputs in the current or previous batch

bool logits_all = false;

@@ -9192,15 +9192,18 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) {
GGML_ASSERT(0 <= n_outputs);

const int32_t n_outputs_max = std::max((uint32_t) n_outputs, lctx.cparams.n_seq_max);
const auto & cparams = lctx.cparams;
const auto & hparams = lctx.model.hparams;

const int32_t n_outputs_max = std::max((uint32_t) n_outputs, cparams.n_seq_max);

const auto n_batch = lctx.cparams.n_batch;
const auto n_vocab = lctx.model.hparams.n_vocab;
const auto n_embd = lctx.model.hparams.n_embd;
const auto n_batch = cparams.n_batch;
const auto n_vocab = hparams.n_vocab;
const auto n_embd = hparams.n_embd;
const int64_t capacity = lctx.output_size;

const bool has_logits = lctx.cparams.causal_attn;
const bool has_embd = lctx.cparams.embeddings;
const bool has_logits = cparams.causal_attn;
const bool has_embd = cparams.embeddings && (!hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);

if (!lctx.output_ids) {
// never resized afterwards
@@ -9211,29 +9214,32 @@ static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) {
}
// alloc only when more than the current logits capacity is required
if (capacity < n_outputs_max) {
lctx.output_size = n_outputs_max;
lctx.logits_size = has_logits ? n_vocab*n_outputs_max : 0;
lctx.embd_size = has_embd ? n_embd*n_outputs_max : 0;

const size_t buf_output_size = (lctx.logits_size + lctx.embd_size)*sizeof(float);

if (lctx.buf_output) {
#ifndef NDEBUG
const size_t prev_size = ggml_backend_buffer_get_size(lctx.buf_output);
fprintf(stderr, "%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, buf_output_size/ 1024.0 / 1024.0);
#endif
ggml_backend_buffer_free(lctx.buf_output);
lctx.buf_output = nullptr;
lctx.logits = nullptr;
lctx.embd = nullptr;
}
{
lctx.output_size = n_outputs_max;
lctx.logits_size = has_logits ? n_vocab*n_outputs_max : 0;
lctx.embd_size = has_embd ? n_embd*n_outputs_max : 0;

const size_t buf_output_size = (lctx.logits_size + lctx.embd_size)*sizeof(float);

lctx.buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buf_output_size);
if (lctx.buf_output == nullptr) {
throw std::runtime_error(format("failed to allocate output buffer of size %.2f MiB", buf_output_size / (1024.0 * 1024.0)));
}
lctx.buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buf_output_size);
if (lctx.buf_output == nullptr) {
throw std::runtime_error(format("failed to allocate output buffer of size %.2f MiB", buf_output_size / (1024.0 * 1024.0)));
}

float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output);
float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output);

lctx.logits = has_logits ? output_base : nullptr;
lctx.embd = has_embd ? output_base + lctx.logits_size : nullptr;
}
lctx.logits = has_logits ? output_base : nullptr;
lctx.embd = has_embd ? output_base + lctx.logits_size : nullptr;
}
// set all ids as invalid (assume two's complement negative numbers)
memset(lctx.output_ids, -1, n_batch*sizeof(int32_t));
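
As a rough illustration of why sizing the output buffer by `n_outputs_max` rather than `n_batch` matters, the snippet below compares the logits allocation for a full default batch against a single output row. The values n_vocab = 32000 and n_batch = 512 are assumed typical settings, not taken from this diff.

```cpp
// Back-of-envelope sizing for the logits part of buf_output, under assumed parameters.
#include <cstdio>

int main() {
    const long long n_vocab = 32000, n_batch = 512;
    const long long full   = n_batch * n_vocab * (long long) sizeof(float); // one row per batch position
    const long long single = 1       * n_vocab * (long long) sizeof(float); // one row for a single output
    printf("all positions: %.2f MiB\n", full   / (1024.0 * 1024.0)); // ~62.50 MiB
    printf("one output:    %.2f MiB\n", single / (1024.0 * 1024.0)); // ~0.12 MiB
    return 0;
}
```
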
@@ -14038,27 +14044,32 @@ void llama_kv_cache_update(struct llama_context * ctx) {

// Returns the *maximum* size of the state
size_t llama_get_state_size(const struct llama_context * ctx) {
const auto & cparams = ctx->cparams;
const auto & hparams = ctx->model.hparams;
// we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
// for reference, std::mt19937(1337) serializes to 6701 bytes.
const size_t s_rng_size = sizeof(size_t);
const size_t s_rng = LLAMA_MAX_RNG_STATE;
const size_t s_n_outputs = sizeof(size_t);
// assume worst case for outputs although only currently set ones are serialized
const size_t s_output_pos = ctx->cparams.n_batch * sizeof(int32_t);
const size_t s_logits_size = sizeof(size_t);
// assume worst case for logits although only currently set ones are serialized
const size_t s_logits = ctx->logits_size * sizeof(float);
const size_t s_logits = ctx->logits_size ? cparams.n_batch * hparams.n_vocab * sizeof(float) : 0;
const size_t s_embedding_size = sizeof(size_t);
const size_t s_embedding = ctx->embd_size * sizeof(float);
const size_t s_embedding = ctx->embd_size ? cparams.n_batch * hparams.n_embd * sizeof(float) : 0;
const size_t s_kv_buf_size = sizeof(size_t);
const size_t s_kv_head = sizeof(uint32_t);
const size_t s_kv_size = sizeof(uint32_t);
const size_t s_kv_used = sizeof(uint32_t);
const size_t s_kv = ctx->kv_self.total_size();
// TODO: assume the max is more than 1 seq_id per KV cell
const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + sizeof(llama_seq_id);
const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id);
const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell;

const size_t s_total = (
+ s_rng_size
+ s_rng
+ s_n_outputs
+ s_output_pos
+ s_logits_size
+ s_logits
+ s_embedding_size
@@ -14142,25 +14153,60 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
data_ctx->write(rng_str.data(), rng_size);
}

// copy logits
// copy outputs
{
const size_t logits_size = ctx->logits_size;
size_t n_outputs = ctx->n_outputs;

data_ctx->write(&logits_size, sizeof(logits_size));
// copy output ids
{
std::vector<int32_t> output_pos;
const size_t n_batch = ctx->cparams.n_batch;
const int32_t * output_ids = ctx->output_ids;

output_pos.resize(n_outputs);

// build a more compact representation of the output ids
for (size_t i = 0; i < n_batch; ++i) {
// map an output id to a position in the batch
int32_t pos = output_ids[i];
if (pos >= 0) {
if ((size_t) pos >= output_pos.size()) {
// TODO: maybe fail here instead
LLAMA_LOG_WARN("%s: weird output buffer layout, possibly a bug\n", __func__);
n_outputs = pos + 1;
output_pos.resize(n_outputs);
}
output_pos[pos] = i;
}
}

if (logits_size) {
data_ctx->write(ctx->logits, logits_size * sizeof(float));
data_ctx->write(&n_outputs, sizeof(n_outputs));

if (n_outputs) {
data_ctx->write(output_pos.data(), n_outputs * sizeof(int32_t));
}
}
}

// copy embeddings
{
const size_t embeddings_size = ctx->embd_size;
// copy logits
{
const size_t logits_size = std::min(ctx->logits_size, n_outputs * ctx->model.hparams.n_vocab);

data_ctx->write(&logits_size, sizeof(logits_size));

data_ctx->write(&embeddings_size, sizeof(embeddings_size));
if (logits_size) {
data_ctx->write(ctx->logits, logits_size * sizeof(float));
}
}

if (embeddings_size) {
data_ctx->write(ctx->embd, embeddings_size * sizeof(float));
// copy embeddings
{
const size_t embeddings_size = std::min(ctx->embd_size, n_outputs * ctx->model.hparams.n_embd);

data_ctx->write(&embeddings_size, sizeof(embeddings_size));

if (embeddings_size) {
data_ctx->write(ctx->embd, embeddings_size * sizeof(float));
}
}
}
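
Taken together, the three write blocks above define the on-stream order of the output section (output count and positions, then logits, then embeddings), which is presumably what the LLAMA_SESSION_VERSION bump in llama.h below accounts for. A self-contained sketch of that order follows; the actual code streams through a llama_data_ctx, so a plain byte vector and a made-up function name stand in for it here.

```cpp
// Sketch of the output-section layout written above (field order only).
#include <cstdint>
#include <vector>

static void write_bytes(std::vector<uint8_t> & out, const void * src, size_t n) {
    if (n == 0) return;
    const uint8_t * p = static_cast<const uint8_t *>(src);
    out.insert(out.end(), p, p + n);
}

// Field order matches the diff: output count and positions, then logits, then embeddings.
static std::vector<uint8_t> write_output_section(
        const std::vector<int32_t> & output_pos, // batch position of each output row
        const std::vector<float>   & logits,     // at most n_outputs * n_vocab floats
        const std::vector<float>   & embd) {     // at most n_outputs * n_embd floats
    std::vector<uint8_t> out;

    const size_t n_outputs = output_pos.size();
    write_bytes(out, &n_outputs, sizeof(n_outputs));
    write_bytes(out, output_pos.data(), n_outputs * sizeof(int32_t));

    const size_t logits_size = logits.size();
    write_bytes(out, &logits_size, sizeof(logits_size));
    write_bytes(out, logits.data(), logits_size * sizeof(float));

    const size_t embd_size = embd.size();
    write_bytes(out, &embd_size, sizeof(embd_size));
    write_bytes(out, embd.data(), embd_size * sizeof(float));

    return out;
}
```
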

@@ -14257,6 +14303,28 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
GGML_ASSERT(!rng_ss.fail());
}

// set output ids
{
size_t n_outputs;
std::vector<int32_t> output_pos;

memcpy(&n_outputs, inp, sizeof(n_outputs)); inp += sizeof(n_outputs);

llama_output_reserve(*ctx, n_outputs);

if (n_outputs) {
output_pos.resize(n_outputs);
memcpy(output_pos.data(), inp, n_outputs * sizeof(int32_t));
inp += n_outputs * sizeof(int32_t);

for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
int32_t id = output_pos[i];
GGML_ASSERT((uint32_t) id < ctx->cparams.n_batch);
ctx->output_ids[id] = i;
}
}
}

// set logits
{
size_t logits_size;
@@ -14277,7 +14345,7 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {

memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size);

GGML_ASSERT(ctx->embd_size == embeddings_size);
GGML_ASSERT(ctx->embd_size >= embeddings_size);

if (embeddings_size) {
memcpy(ctx->embd, inp, embeddings_size * sizeof(float));
@@ -14562,20 +14630,24 @@ void llama_synchronize(struct llama_context * ctx) {
}

float * llama_get_logits(struct llama_context * ctx) {
// TODO: assert that really all logits are in the output
llama_synchronize(ctx);

return ctx->logits;
}

float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
const int32_t j = ctx->output_ids[i];
GGML_ASSERT(0 <= j);

llama_synchronize(ctx);

// FIXME: check for nullptr
return ctx->logits + j*ctx->model.hparams.n_vocab;
if (ctx->logits && 0 <= j && j < ctx->n_outputs) {
return ctx->logits + j*ctx->model.hparams.n_vocab;
}
LLAMA_LOG_ERROR("%s: invalid logits id %i\n", __func__, i);
#ifndef NDEBUG
GGML_ASSERT(false);
#endif
return nullptr;
}

float * llama_get_embeddings(struct llama_context * ctx) {
@@ -14586,12 +14658,17 @@ float * llama_get_embeddings(struct llama_context * ctx) {

float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
const int32_t j = ctx->output_ids[i];
GGML_ASSERT(0 <= j);

llama_synchronize(ctx);

// FIXME: check for nullptr
return ctx->embd + j*ctx->model.hparams.n_embd;
if (ctx->embd && 0 < j && j < ctx->n_outputs) {
return ctx->embd + j*ctx->model.hparams.n_embd;
}
LLAMA_LOG_ERROR("%s: invalid embeddings id %i\n", __func__, i);
#ifndef NDEBUG
GGML_ASSERT(false);
#endif
return nullptr;
}

float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
24 changes: 14 additions & 10 deletions llama.h
@@ -39,7 +39,7 @@
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'

#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
#define LLAMA_SESSION_VERSION 4
#define LLAMA_SESSION_VERSION 5

#ifdef __cplusplus
extern "C" {
@@ -674,25 +674,29 @@ extern "C" {
LLAMA_API void llama_synchronize(struct llama_context * ctx);

// Token logits obtained from the last call to llama_decode()
// WARNING: the following layout is only valid when the batch outputs logits for all tokens
// The logits for the last token are stored in the last row
// Logits for which llama_batch.logits[i] == 0 are undefined
// Rows: n_tokens provided with llama_batch
// The logits for which llama_batch.logits[i] != 0 are stored contiguously
// in the order they have in the batch.
// Rows: number of tokens for which llama_batch.logits[i] != 0
// Cols: n_vocab
LLAMA_API float * llama_get_logits(struct llama_context * ctx);

// Logits for the ith token. Equivalent to:
// llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab
// returns NULL for invalid ids.
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);

// Get all output token embeddings
// WARNING: only use when all outputs are requested
// shape: [n_tokens*n_embd] (1-dimensional)
// Get all output token embeddings.
// when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model,
// the embeddings for which llama_batch.logits[i] != 0 are stored contiguously
// in the order they have in the batch.
// shape: [n_outputs*n_embd]
// Otherwise, returns NULL.
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);

// Get the embeddings for the ith token
// llama_get_embeddings(ctx) + i*n_embd
// Get the embeddings for the ith token. Equivalent to:
// llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
// shape: [n_embd] (1-dimensional)
// returns NULL for invalid ids.
LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);

// Get the embeddings for a sequence id
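
A sketch of how a caller might adapt to the documented behaviour above: request logits only for the last prompt token and fetch them by batch position, checking for the NULL that invalid ids now return. The helper function is illustrative, not part of llama.cpp, and assumes a `llama_context` created elsewhere.

```cpp
// Illustrative caller-side helper; only uses the public llama.h API shown in this diff.
#include "llama.h"

#include <cstdio>
#include <vector>

static const float * decode_and_get_last_logits(llama_context * ctx, const std::vector<llama_token> & tokens) {
    if (tokens.empty()) {
        return nullptr;
    }

    const int32_t n_tokens = (int32_t) tokens.size();
    llama_batch batch = llama_batch_init(n_tokens, /*embd=*/0, /*n_seq_max=*/1);

    for (int32_t i = 0; i < n_tokens; ++i) {
        batch.token[i]     = tokens[i];
        batch.pos[i]       = i;
        batch.n_seq_id[i]  = 1;
        batch.seq_id[i][0] = 0;
        batch.logits[i]    = (i == n_tokens - 1) ? 1 : 0; // only the last token produces logits
    }
    batch.n_tokens = n_tokens;

    const float * logits = nullptr;
    if (llama_decode(ctx, batch) == 0) {
        // With this PR, rows exist only for positions where batch.logits[i] != 0,
        // and invalid ids return NULL instead of pointing into unrelated memory.
        logits = llama_get_logits_ith(ctx, n_tokens - 1);
        if (logits == nullptr) {
            fprintf(stderr, "no logits available for the requested position\n");
        }
    }

    llama_batch_free(batch);
    return logits;
}
```

Because only one position sets `batch.logits[i]`, the context allocates a single output row instead of one per batch position.
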