bench : handle decode errors #13548

Merged (1 commit) on May 15, 2025
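In short, per the diff below: test_prompt() and test_gen() now return bool instead of void and check the result of every llama_decode() call, logging the failing result code; main() treats a false return from the warmup, depth, prompt, and generation runs as fatal and exits with status 1, rather than recording timings for a run whose decodes failed.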
48 changes: 39 additions & 9 deletions tools/llama-bench/llama-bench.cpp
@@ -1736,7 +1736,7 @@ struct sql_printer : public printer {
     }
 };
 
-static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
+static bool test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
     llama_set_n_threads(ctx, n_threads, n_threads);
 
     const llama_model * model = llama_get_model(ctx);
@@ -1753,14 +1753,19 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
         for (int i = 1; i < n_tokens; i++) {
             tokens[i] = std::rand() % n_vocab;
         }
-        llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
+        int res = llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
+        if (res != 0) {
+            fprintf(stderr, "%s: failed to decode prompt batch, res = %d\n", __func__, res);
+            return false;
+        }
         n_processed += n_tokens;
     }
 
     llama_synchronize(ctx);
+    return true;
 }
 
-static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
+static bool test_gen(llama_context * ctx, int n_gen, int n_threads) {
     llama_set_n_threads(ctx, n_threads, n_threads);
 
     const llama_model * model = llama_get_model(ctx);
@@ -1770,10 +1775,15 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
     llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
 
     for (int i = 0; i < n_gen; i++) {
-        llama_decode(ctx, llama_batch_get_one(&token, 1));
+        int res = llama_decode(ctx, llama_batch_get_one(&token, 1));
+        if (res != 0) {
+            fprintf(stderr, "%s: failed to decode generation batch, res = %d\n", __func__, res);
+            return false;
+        }
         llama_synchronize(ctx);
         token = std::rand() % n_vocab;
     }
+    return true;
 }
 
 static void llama_null_log_callback(enum ggml_log_level level, const char * text, void * user_data) {
@@ -1917,13 +1927,21 @@ int main(int argc, char ** argv) {
                 fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup prompt run\n", params_idx, params_count);
             }
             //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
-            test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+            bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+            if (!res) {
+                fprintf(stderr, "%s: error: failed to run prompt warmup\n", __func__);
+                exit(1);
+            }
         }
         if (t.n_gen > 0) {
             if (params.progress) {
                 fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup generation run\n", params_idx, params_count);
             }
-            test_gen(ctx, 1, t.n_threads);
+            bool res = test_gen(ctx, 1, t.n_threads);
+            if (!res) {
+                fprintf(stderr, "%s: error: failed to run gen warmup\n", __func__);
+                exit(1);
+            }
         }
 
         for (int i = 0; i < params.reps; i++) {
@@ -1934,7 +1952,11 @@ int main(int argc, char ** argv) {
                    fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
                            i + 1, params.reps);
                }
-                test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
+                bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
+                if (!res) {
+                    fprintf(stderr, "%s: error: failed to run depth\n", __func__);
+                    exit(1);
+                }
             }
 
             uint64_t t_start = get_time_ns();
@@ -1944,14 +1966,22 @@ int main(int argc, char ** argv) {
                    fprintf(stderr, "llama-bench: benchmark %d/%zu: prompt run %d/%d\n", params_idx, params_count,
                            i + 1, params.reps);
                }
-                test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+                bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+                if (!res) {
+                    fprintf(stderr, "%s: error: failed to run prompt\n", __func__);
+                    exit(1);
+                }
             }
             if (t.n_gen > 0) {
                 if (params.progress) {
                     fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count,
                             i + 1, params.reps);
                 }
-                test_gen(ctx, t.n_gen, t.n_threads);
+                bool res = test_gen(ctx, t.n_gen, t.n_threads);
+                if (!res) {
+                    fprintf(stderr, "%s: error: failed to run gen\n", __func__);
+                    exit(1);
+                }
             }
 
             uint64_t t_ns = get_time_ns() - t_start;
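For readers skimming the patch, the recurring pattern is: run llama_decode(), treat a non-zero result as failure, log it, and propagate a bool up to the caller. Below is a minimal standalone sketch of that pattern, assuming only the llama.h API that the diff itself uses (llama_decode, llama_batch_get_one, llama_synchronize); the decode_batch_checked name and the error message are illustrative, not part of llama.cpp.

#include <cstdio>
#include <vector>

#include "llama.h"

// Decode one batch of tokens and report whether it succeeded.
// llama_decode() returns 0 on success; a non-zero result (e.g. no free
// KV-cache slot, or a backend error) means any timing taken from this
// run would be meaningless, so the failure is surfaced to the caller.
static bool decode_batch_checked(llama_context * ctx, std::vector<llama_token> & tokens) {
    const int res = llama_decode(ctx, llama_batch_get_one(tokens.data(), (int32_t) tokens.size()));
    if (res != 0) {
        fprintf(stderr, "%s: llama_decode failed, res = %d\n", __func__, res);
        return false;
    }
    // make sure the backend has finished before the caller reads the clock
    llama_synchronize(ctx);
    return true;
}

The caller then sets the failure policy; llama-bench's main() treats a false return as fatal and calls exit(1), a reasonable choice for a benchmark, since continuing would silently fold broken runs into the reported results.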