@@ -1663,7 +1663,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
1663
1663
1664
1664
#if defined(GGML_DEBUG_SAMPLER_INFILL)
1665
1665
for (size_t i = 0 ; i < cur_p->size ; ++i) {
1666
- LLAMA_LOG_DEBUG (" infill: cur_p[%zu ] = { id: %d , p: %f , logit: %f }\n " , i, cur_p->data [i].id , cur_p->data [i].p , cur_p->data [i].logit );
1666
+ LLAMA_LOG_DEBUG (" infill: cur_p[%3zu ] = { id: %6d , p: %.6f , logit: %6.3f }\n " , i, cur_p->data [i].id , cur_p->data [i].p , cur_p->data [i].logit );
1667
1667
}
1668
1668
#endif
1669
1669
@@ -1673,14 +1673,16 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
1673
1673
1674
1674
for (size_t i = 0 ; i < cur_p->size ; ++i) {
1675
1675
p_max = fmaxf (p_max, cur_p->data [i].p );
1676
+
1676
1677
if (llama_token_is_eog_impl (*ctx->vocab , cur_p->data [i].id )) {
1677
1678
p_eog_sum += cur_p->data [i].p ;
1678
1679
} else {
1679
1680
p_txt_sum += cur_p->data [i].p ;
1680
1681
}
1681
1682
}
1682
1683
1683
- const float rat = p_txt_sum / p_eog_sum;
1684
+ const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum;
1685
+
1684
1686
LLAMA_LOG_DEBUG (" infill: p_max = %.2f, p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n " , p_max, p_txt_sum, p_eog_sum, rat, cur_p->size );
1685
1687
1686
1688
if (p_max < 0 .90f && p_eog_sum*cur_p->size > p_txt_sum) {
@@ -1712,48 +1714,50 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
1712
1714
}
1713
1715
1714
1716
if (llama_token_is_prefix_impl (*ctx->vocab , cur_p->data [i].id , cur_p->data [j].id )) {
1715
- if (cur_p->data [i].p > cur_p->data [j].p ) {
1717
+ if (cur_p->data [i].p > cur_p->data [j].p ) {
1716
1718
cur_p->data [i].p += cur_p->data [j].p ;
1717
1719
cur_p->data [j].logit = -INFINITY;
1720
+ cur_p->data [j].p = 0 .0f ;
1718
1721
} else {
1719
1722
cur_p->data [j].p += cur_p->data [i].p ;
1720
1723
cur_p->data [i].logit = -INFINITY;
1724
+ cur_p->data [i].p = 0 .0f ;
1721
1725
}
1722
1726
}
1723
1727
}
1724
1728
}
1725
1729
1726
- // mask non-EOG tokens with prob < 0.2
1727
- for (size_t i = 0 ; i < cur_p->size ; ++i) {
1730
+ const auto size_org = cur_p->size ;
1731
+
1732
+ cur_p->size = 0 ;
1733
+
1734
+ float p_sum = 0 .0f ;
1735
+
1736
+ for (size_t i = 0 ; i < size_org; ++i) {
1737
+ // discard non-EOG tokens with prob < 0.2
1728
1738
if (cur_p->data [i].p < 0.2 && !llama_token_is_eog_impl (*ctx->vocab , cur_p->data [i].id )) {
1729
- cur_p-> data [i]. logit = -INFINITY ;
1739
+ continue ;
1730
1740
}
1731
- }
1732
1741
1733
- // determine the token with max logit
1734
- float l_max = -INFINITY;
1735
- int i_max = -1 ;
1736
- for (size_t i = 0 ; i < cur_p->size ; ++i) {
1737
- if (cur_p->data [i].logit > l_max) {
1738
- l_max = cur_p->data [i].logit ;
1739
- i_max = i;
1740
- }
1742
+ // keep this token
1743
+ p_sum += cur_p->data [i].p ;
1744
+
1745
+ cur_p->data [cur_p->size ++] = cur_p->data [i];
1741
1746
}
1742
1747
1743
1748
// if no tokens are left after filtering -> reduce cur_p to a single EOT token
1744
- if (i_max == - 1 ) {
1749
+ if (cur_p-> size == 0 ) {
1745
1750
cur_p->size = 1 ;
1746
1751
cur_p->data [0 ].id = llama_token_eot_impl (*ctx->vocab );
1747
1752
cur_p->data [0 ].logit = 1 .0f ;
1748
1753
1749
1754
return ;
1750
1755
}
1751
1756
1752
- // pick the best token
1753
- cur_p->size = 1 ;
1754
- cur_p->data [0 ] = cur_p->data [i_max];
1755
-
1757
+ // normalize probs
1756
1758
for (size_t i = 0 ; i < cur_p->size ; ++i) {
1759
+ cur_p->data [i].p /= p_sum;
1760
+
1757
1761
LLAMA_LOG_DEBUG (" after : cur_p[%zu] = { id: %d, p: %f, logit: %f }\n " , i, cur_p->data [i].id , cur_p->data [i].p , cur_p->data [i].logit );
1758
1762
}
1759
1763
}
0 commit comments