Skip to content

Commit 92c0b38

Browse files
authored
grammar : fix integer overflow (#17381)
* Fix DoS / integer overflow
* Remove optional, use INT64_MAX instead as placeholder value (it's technically -1, so it fits :)
* White space
* Actually, since it's unsigned, use UINT64_MAX
1 parent 2286a36 commit 92c0b38

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

src/llama-grammar.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66

77
#include <cmath>
88
#include <algorithm>
9+
#include <cstdint>
910
#include <stdexcept>
1011

12+
#define MAX_REPETITION_THRESHOLD 2000
1113
//
1214
// helpers
1315
//
@@ -345,7 +347,9 @@ const char * llama_grammar_parser::parse_sequence(
345347
size_t last_sym_start = rule.size();
346348
const char * pos = src;
347349

348-
auto handle_repetitions = [&](int min_times, int max_times) {
350+
// use UINT64_MAX as the empty value: the counts now use the proper unsigned long type, so -1 can no longer be used
351+
// (though it's technically the same as -1 now)
352+
auto handle_repetitions = [&](unsigned long min_times, unsigned long max_times) {
349353

350354
if (last_sym_start == rule.size()) {
351355
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
@@ -373,20 +377,20 @@ const char * llama_grammar_parser::parse_sequence(
373377
rule.resize(last_sym_start);
374378
} else {
375379
// Repeat the previous elements (min_times - 1) times
376-
for (int i = 1; i < min_times; i++) {
380+
for (unsigned long i = 1; i < min_times; i++) {
377381
rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
378382
}
379383
}
380384

381385
uint32_t last_rec_rule_id = 0;
382-
auto n_opt = max_times < 0 ? 1 : max_times - min_times;
386+
auto n_opt = max_times == UINT64_MAX ? 1 : max_times - min_times;
383387

384388
llama_grammar_rule rec_rule(prev_rule);
385-
for (int i = 0; i < n_opt; i++) {
389+
for (unsigned long i = 0; i < n_opt; i++) {
386390
rec_rule.resize(prev_rule.size());
387391
uint32_t rec_rule_id = generate_symbol_id( rule_name);
388-
if (i > 0 || max_times < 0) {
389-
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
392+
if (i > 0 || max_times == UINT64_MAX) {
393+
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times == UINT64_MAX ? rec_rule_id : last_rec_rule_id});
390394
}
391395
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
392396
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
@@ -478,10 +482,10 @@ const char * llama_grammar_parser::parse_sequence(
478482
throw std::runtime_error(std::string("expecting an int at ") + pos);
479483
}
480484
const char * int_end = parse_int(pos);
481-
int min_times = std::stoul(std::string(pos, int_end - pos));
485+
unsigned long min_times = std::stoul(std::string(pos, int_end - pos));
482486
pos = parse_space(int_end, is_nested);
483487

484-
int max_times = -1;
488+
unsigned long max_times = UINT64_MAX;
485489

486490
if (*pos == '}') {
487491
max_times = min_times;
@@ -502,6 +506,9 @@ const char * llama_grammar_parser::parse_sequence(
502506
} else {
503507
throw std::runtime_error(std::string("expecting ',' at ") + pos);
504508
}
509+
if (min_times > MAX_REPETITION_THRESHOLD || (max_times != UINT64_MAX && max_times > MAX_REPETITION_THRESHOLD)) {
510+
throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions"));
511+
}
505512
handle_repetitions(min_times, max_times);
506513
} else {
507514
break;

0 commit comments

Comments
 (0)