
Commit c7d8904

llama : improve infill sampler

ggml-ci

1 parent 8343eeb

File tree: 2 files changed (+28 / -23 lines)


include/llama.h (4 additions & 3 deletions)
@@ -1152,9 +1152,9 @@ extern "C" {
                   const llama_logit_bias * logit_bias);

     // this sampler is meant to be used for fill-in-the-middle infilling
-    // it's supposed to be used after top_k sampling and will leave a single candidate token
+    // it's supposed to be used after top_k sampling
     //
-    // 1. if there is a high-prob token (>= 0.9f) -> pick it
+    // 1. if there is a high-prob token (>= 0.9f) -> skip step 2
     // 2. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG
     // 3. combine probs of tokens that have the same prefix
     //
@@ -1170,7 +1170,8 @@ extern "C" {
     // "hel": 0.8
     // "dummy": 0.1
     //
-    // 4. pick the token with the highest probability
+    // 4. discard non-EOG tokens with low prob (< 0.2)
+    // 5. if no tokens are left -> pick EOT
     //
     LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);

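Note: to make the updated header comment concrete, below is a minimal standalone sketch of the five documented steps on a plain candidate list. It is a simulation under stated assumptions, not the sampler itself: token_cand, is_prefix() and the "<EOG>" placeholder are hypothetical stand-ins for llama_token_data, llama_token_is_prefix_impl and the vocab's EOG tokens, and the input values are illustrative, chosen to be consistent with the "hel"/"dummy" example above.

    // sketch only: token_cand, is_prefix() and "<EOG>" are hypothetical stand-ins
    // for llama_token_data, llama_token_is_prefix_impl and the vocab's EOG tokens
    #include <cmath>
    #include <cstdio>
    #include <string>
    #include <vector>

    struct token_cand {
        std::string text; // token piece (used here for the prefix check)
        float       p;    // probability after top_k + softmax
        bool        eog;  // end-of-generation token?
    };

    static bool is_prefix(const std::string & a, const std::string & b) {
        return a.size() <= b.size() && b.compare(0, a.size(), a) == 0;
    }

    static std::vector<token_cand> infill_filter(std::vector<token_cand> cands) {
        float p_max = 0.0f, p_eog_sum = 0.0f, p_txt_sum = 0.0f;
        for (const auto & c : cands) {
            p_max = std::fmax(p_max, c.p);
            (c.eog ? p_eog_sum : p_txt_sum) += c.p;
        }

        // steps 1 + 2: without a high-prob token, pick EOG when it dominates
        if (p_max < 0.90f && p_eog_sum*cands.size() > p_txt_sum) {
            return { { "<EOG>", 1.0f, true } };
        }

        // step 3: combine probs of tokens where one is a prefix of the other,
        // moving all mass onto the higher-prob token of each pair
        for (size_t i = 0; i < cands.size(); ++i) {
            for (size_t j = i + 1; j < cands.size(); ++j) {
                if (!is_prefix(cands[i].text, cands[j].text) &&
                    !is_prefix(cands[j].text, cands[i].text)) {
                    continue;
                }
                auto & hi = cands[i].p > cands[j].p ? cands[i] : cands[j];
                auto & lo = cands[i].p > cands[j].p ? cands[j] : cands[i];
                hi.p += lo.p;
                lo.p  = 0.0f;
            }
        }

        // steps 4 + 5: drop low-prob non-EOG tokens; fall back to EOG if none remain
        std::vector<token_cand> kept;
        float p_sum = 0.0f;
        for (const auto & c : cands) {
            if (c.p < 0.2f && !c.eog) {
                continue;
            }
            p_sum += c.p;
            kept.push_back(c);
        }
        if (kept.empty()) {
            return { { "<EOG>", 1.0f, true } };
        }

        for (auto & c : kept) {
            c.p /= p_sum; // renormalize the surviving candidates
        }
        return kept;
    }

    int main() {
        // illustrative "before" values matching the header's prefix example
        const auto out = infill_filter({
            { "hel", 0.5f, false }, { "hell",  0.2f, false },
            { "hello", 0.1f, false }, { "dummy", 0.1f, false },
        });
        for (const auto & c : out) {
            printf("%-8s p = %.2f\n", c.text.c_str(), c.p); // "hel" p = 1.00
        }
    }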

src/llama-sampling.cpp (24 additions & 20 deletions)
@@ -1663,7 +1663,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_

 #if defined(GGML_DEBUG_SAMPLER_INFILL)
     for (size_t i = 0; i < cur_p->size; ++i) {
-        LLAMA_LOG_DEBUG("infill: cur_p[%zu] = { id: %d, p: %f, logit: %f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+        LLAMA_LOG_DEBUG("infill: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
     }
 #endif

@@ -1673,14 +1673,16 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_

     for (size_t i = 0; i < cur_p->size; ++i) {
         p_max = fmaxf(p_max, cur_p->data[i].p);
+
         if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
             p_eog_sum += cur_p->data[i].p;
         } else {
             p_txt_sum += cur_p->data[i].p;
         }
     }

-    const float rat = p_txt_sum / p_eog_sum;
+    const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum;
+
     LLAMA_LOG_DEBUG("infill: p_max = %.2f, p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", p_max, p_txt_sum, p_eog_sum, rat, cur_p->size);

     if (p_max < 0.90f && p_eog_sum*cur_p->size > p_txt_sum) {
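Note: the new guard on rat matters when top_k leaves no EOG candidates at all: p_eog_sum is then 0.0f and the old expression divides by zero, printing inf or nan in the debug log (in the lines shown, rat only feeds LLAMA_LOG_DEBUG; the decision below uses the sums directly). A minimal sketch with illustrative values:

    #include <cmath>
    #include <cstdio>

    int main() {
        const float p_txt_sum = 0.85f; // illustrative: all top_k survivors are text tokens
        const float p_eog_sum = 0.00f;

        const float rat_old = p_txt_sum / p_eog_sum;                               // inf (nan if both sums are 0)
        const float rat_new = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; // explicit, always well-defined

        printf("old: %f, new: %f\n", rat_old, rat_new);
    }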
@@ -1712,48 +1714,50 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
             }

             if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
-               if (cur_p->data[i].p > cur_p->data[j].p) {
+                if (cur_p->data[i].p > cur_p->data[j].p) {
                     cur_p->data[i].p += cur_p->data[j].p;
                     cur_p->data[j].logit = -INFINITY;
+                    cur_p->data[j].p = 0.0f;
                 } else {
                     cur_p->data[j].p += cur_p->data[i].p;
                     cur_p->data[i].logit = -INFINITY;
+                    cur_p->data[i].p = 0.0f;
                 }
             }
         }
     }

-    // mask non-EOG tokens with prob < 0.2
-    for (size_t i = 0; i < cur_p->size; ++i) {
+    const auto size_org = cur_p->size;
+
+    cur_p->size = 0;
+
+    float p_sum = 0.0f;
+
+    for (size_t i = 0; i < size_org; ++i) {
+        // discard non-EOG tokens with prob < 0.2
         if (cur_p->data[i].p < 0.2 && !llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
-            cur_p->data[i].logit = -INFINITY;
+            continue;
         }
-    }

-    // determine the token with max logit
-    float l_max = -INFINITY;
-    int   i_max = -1;
-    for (size_t i = 0; i < cur_p->size; ++i) {
-        if (cur_p->data[i].logit > l_max) {
-            l_max = cur_p->data[i].logit;
-            i_max = i;
-        }
+        // keep this token
+        p_sum += cur_p->data[i].p;
+
+        cur_p->data[cur_p->size++] = cur_p->data[i];
     }

     // if all probs are -INFINITY -> reduce cur_p to single EOG token
-    if (i_max == -1) {
+    if (cur_p->size == 0) {
         cur_p->size = 1;
         cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
         cur_p->data[0].logit = 1.0f;

         return;
     }

-    // pick the best token
-    cur_p->size = 1;
-    cur_p->data[0] = cur_p->data[i_max];
-
+    // normalize probs
     for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= p_sum;
+
         LLAMA_LOG_DEBUG("after : cur_p[%zu] = { id: %d, p: %f, logit: %f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
     }
 }
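Note: since the sampler now keeps several renormalized candidates instead of collapsing to a single best token, it composes with a downstream distribution sampler. A minimal usage sketch, assuming a loaded model and a context already evaluated up to the sampling position; the chain order follows the header comment (top_k first), while k = 40 and the seed are illustrative choices:

    #include "llama.h"

    // sketch: assumes `model` is loaded and `ctx` has been evaluated
    llama_token sample_infill_token(struct llama_model * model, struct llama_context * ctx) {
        struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

        llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));                // infill expects top_k first
        llama_sampler_chain_add(chain, llama_sampler_init_infill(model));            // filters + renormalizes
        llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); // draw from what is left

        const llama_token id = llama_sampler_sample(chain, ctx, -1);

        llama_sampler_free(chain);
        return id;
    }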
