Commit a1867e0

sync : llama.cpp
1 parent e67dfbc commit a1867e0

122 files changed: +14246 additions, -13306 deletions


examples/talk-llama/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
@@ -2,6 +2,8 @@ if (WHISPER_SDL2)
     set(CMAKE_CXX_STANDARD 17)
     set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
+    file(GLOB SRC_MODELS models/*.cpp)
+
     set(TARGET whisper-talk-llama)
     add_executable(${TARGET} talk-llama.cpp
         llama.cpp
@@ -29,7 +31,8 @@ if (WHISPER_SDL2)
         llama-sampling.cpp
         llama-vocab.cpp
         unicode.cpp
-        unicode-data.cpp)
+        unicode-data.cpp
+        ${SRC_MODELS})
     target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
 
     target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})

examples/talk-llama/llama-arch.cpp

Lines changed: 108 additions & 0 deletions
@@ -32,6 +32,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_QWEN2VL,     "qwen2vl"        },
     { LLM_ARCH_QWEN3,       "qwen3"          },
     { LLM_ARCH_QWEN3MOE,    "qwen3moe"       },
+    { LLM_ARCH_QWEN3VL,     "qwen3vl"        },
+    { LLM_ARCH_QWEN3VLMOE,  "qwen3vlmoe"     },
     { LLM_ARCH_PHI2,        "phi2"           },
     { LLM_ARCH_PHI3,        "phi3"           },
     { LLM_ARCH_PHIMOE,      "phimoe"         },
@@ -103,6 +105,9 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_SEED_OSS,    "seed_oss"       },
     { LLM_ARCH_GROVEMOE,    "grovemoe"       },
     { LLM_ARCH_APERTUS,     "apertus"        },
+    { LLM_ARCH_MINIMAX_M2,  "minimax-m2"     },
+    { LLM_ARCH_COGVLM,      "cogvlm"         },
+    { LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
     { LLM_ARCH_UNKNOWN,     "(unknown)"      },
 };
 
@@ -145,6 +150,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERTS_PER_GROUP,      "%s.experts_per_group"      },
     { LLM_KV_MOE_EVERY_N_LAYERS,     "%s.moe_every_n_layers"     },
     { LLM_KV_NEXTN_PREDICT_LAYERS,   "%s.nextn_predict_layers"   },
+    { LLM_KV_NUM_DEEPSTACK_LAYERS,   "%s.n_deepstack_layers"     },
     { LLM_KV_POOLING_TYPE,           "%s.pooling_type"           },
     { LLM_KV_LOGIT_SCALE,            "%s.logit_scale"            },
     { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
@@ -779,6 +785,45 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_EXPS,    "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_QWEN3VL,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,     "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,    "output_norm" },
+            { LLM_TENSOR_OUTPUT,         "output" },
+            { LLM_TENSOR_ATTN_NORM,      "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,         "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,    "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,         "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,    "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,         "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,       "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,       "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,       "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,         "blk.%d.ffn_up" },
+        },
+    },
+    {
+        LLM_ARCH_QWEN3VLMOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,     "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,    "output_norm" },
+            { LLM_TENSOR_OUTPUT,         "output" },
+            { LLM_TENSOR_ATTN_NORM,      "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,         "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,    "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,         "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,    "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,         "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,       "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP,   "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,  "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,  "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,    "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_PHI2,
         {
@@ -2312,6 +2357,64 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_CHEXPS,   "blk.%d.ffn_up_chexps" },
         },
     },
+    {
+        LLM_ARCH_MINIMAX_M2,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
+        },
+    },
+    {
+        LLM_ARCH_PANGU_EMBED,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
+    {
+        LLM_ARCH_COGVLM,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_VISEXP_ATTN_QKV, "blk.%d.vis_attn_qkv" },
+            { LLM_TENSOR_VISEXP_ATTN_OUT, "blk.%d.vis_attn_output" },
+            { LLM_TENSOR_VISEXP_FFN_GATE, "blk.%d.vis_gate" },
+            { LLM_TENSOR_VISEXP_FFN_DOWN, "blk.%d.vis_down" },
+            { LLM_TENSOR_VISEXP_FFN_UP,   "blk.%d.vis_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -2488,6 +2591,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_SHORTCONV_CONV,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
     {LLM_TENSOR_SHORTCONV_INPROJ,   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_SHORTCONV_OUTPROJ,  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_VISEXP_ATTN_QKV,    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_VISEXP_ATTN_OUT,    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_VISEXP_FFN_GATE,    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_VISEXP_FFN_DOWN,    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_VISEXP_FFN_UP,      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     // NextN/MTP tensors are currently ignored (reserved for future MTP support)
     // These tensors only exist in the last layer(s) and are treated as output tensors
     {LLM_TENSOR_NEXTN_EH_PROJ,      {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
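
Note: the per-layer entries in the tensor-name tables above are printf-style format strings, so "blk.%d.vis_attn_qkv" and similar get the layer index substituted when a tensor is addressed by name. A minimal sketch of that convention follows (the format_tensor_name helper is hypothetical and only illustrates the naming pattern; it is not the lookup code llama.cpp itself uses):

    #include <cstdio>
    #include <string>

    // Hypothetical helper mirroring the "blk.%d.<suffix>" convention of the
    // LLM_TENSOR_NAMES tables above.
    static std::string format_tensor_name(const char * fmt, int layer) {
        char buf[256];
        std::snprintf(buf, sizeof(buf), fmt, layer);
        return buf;
    }

    int main() {
        // e.g. the new LLM_ARCH_COGVLM entry maps LLM_TENSOR_VISEXP_ATTN_QKV to "blk.%d.vis_attn_qkv"
        std::printf("%s\n", format_tensor_name("blk.%d.vis_attn_qkv", 12).c_str()); // blk.12.vis_attn_qkv
        std::printf("%s\n", format_tensor_name("blk.%d.ffn_up_exps",  3).c_str());  // blk.3.ffn_up_exps
        return 0;
    }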

examples/talk-llama/llama-arch.h

Lines changed: 11 additions & 0 deletions
@@ -36,6 +36,8 @@ enum llm_arch {
     LLM_ARCH_QWEN2VL,
     LLM_ARCH_QWEN3,
     LLM_ARCH_QWEN3MOE,
+    LLM_ARCH_QWEN3VL,
+    LLM_ARCH_QWEN3VLMOE,
     LLM_ARCH_PHI2,
     LLM_ARCH_PHI3,
     LLM_ARCH_PHIMOE,
@@ -107,6 +109,9 @@ enum llm_arch {
     LLM_ARCH_SEED_OSS,
     LLM_ARCH_GROVEMOE,
     LLM_ARCH_APERTUS,
+    LLM_ARCH_MINIMAX_M2,
+    LLM_ARCH_COGVLM,
+    LLM_ARCH_PANGU_EMBED,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -149,6 +154,7 @@ enum llm_kv {
     LLM_KV_EXPERTS_PER_GROUP,
     LLM_KV_MOE_EVERY_N_LAYERS,
     LLM_KV_NEXTN_PREDICT_LAYERS,
+    LLM_KV_NUM_DEEPSTACK_LAYERS,
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,
@@ -455,6 +461,11 @@ enum llm_tensor {
     LLM_TENSOR_SHORTCONV_CONV,
     LLM_TENSOR_SHORTCONV_INPROJ,
     LLM_TENSOR_SHORTCONV_OUTPROJ,
+    LLM_TENSOR_VISEXP_ATTN_QKV,
+    LLM_TENSOR_VISEXP_ATTN_OUT,
+    LLM_TENSOR_VISEXP_FFN_GATE,
+    LLM_TENSOR_VISEXP_FFN_DOWN,
+    LLM_TENSOR_VISEXP_FFN_UP,
     LLM_TENSOR_NEXTN_EH_PROJ,
     LLM_TENSOR_NEXTN_EMBED_TOKENS,
     LLM_TENSOR_NEXTN_ENORM,

examples/talk-llama/llama-batch.cpp

Lines changed: 63 additions & 31 deletions
@@ -215,6 +215,7 @@ bool llama_batch_allocr::init(
         /*.n_seq_tokens =*/ (uint32_t) 1,
         /*.n_seqs       =*/ (uint32_t) batch.n_tokens,
         /*.n_seqs_unq   =*/ (uint32_t) this->seq_id_unq.size(),
+        /*.n_pos        =*/ n_pos_per_embd,
         /*.token        =*/ batch.token,
         /*.embd         =*/ batch.embd,
         /*.pos          =*/ batch.pos,
@@ -251,46 +252,72 @@
     // consistency checks
     //
 
-    for (uint32_t s = 0; s < n_seq_max; ++s) {
-        if (seq_pos[s].empty()) {
-            continue;
+    if (n_pos_per_embd > 1) {
+        // M-RoPE case: allow position to "jump" forward only (non-continuous positions are allowed)
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
+            if (seq_pos[s].empty()) {
+                continue;
+            }
+
+            const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
+
+            if (batch.token) {
+                if (p0 >= 0 && p0 >= seq_pos_min(s)) {
+                    LLAMA_LOG_ERROR(
+                            "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+                            " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+                            " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+                            " for M-RoPE, it is required that the position satisfies: X < Y\n",
+                            __func__, s, s, p0, s, seq_pos_min(s));
+
+                    return false;
+                }
+            } else {
+                // embedding inputs can have overlapping positions
+                if (p0 >= 0 && p0 > seq_pos_min(s)) {
+                    LLAMA_LOG_ERROR(
+                            "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+                            " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+                            " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+                            " for M-RoPE, it is required that the position satisfies: X <= Y\n",
+                            __func__, s, s, p0, s, seq_pos_min(s));
+
+                    return false;
+                }
+            }
         }
+    } else {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
+            if (seq_pos[s].empty()) {
+                continue;
+            }
 
-        const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
+            const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
 
-        if (p0 >= 0) {
-            bool ok = true;
+            if (p0 >= 0) {
+                bool ok = true;
 
-            if (batch.token) {
                 if (seq_pos_min(s) != p0 + 1) {
                     ok = false;
                 }
-            } else {
-                assert(batch.embd);
 
-                // for embeddings (typically used as vision input), we allow them to have repeating positions
-                // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
-                if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
-                    ok = false;
+                if (!ok) {
+                    LLAMA_LOG_ERROR(
+                            "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+                            " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+                            " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+                            " it is required that the sequence positions remain consecutive: Y = X + 1\n",
+                            __func__, s, s, p0, s, seq_pos_min(s));
+
+                    return false;
                 }
             }
 
-            if (!ok) {
-                LLAMA_LOG_ERROR(
-                        "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
-                        " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
-                        " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
-                        " it is required that the sequence positions remain consecutive: Y = X + 1\n",
-                        __func__, s, s, p0, s, seq_pos_min(s));
-
+            if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
+                LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
                 return false;
             }
         }
-
-        if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
-            LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
-            return false;
-        }
     }
 
     if (memory) {
@@ -389,6 +416,7 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
         /*.n_seq_tokens =*/ n_seq_tokens,
         /*.n_seqs       =*/ n_seqs,
         /*.n_seqs_unq   =*/ n_seqs,
+        /*.n_pos        =*/ n_pos_per_embd,
 
         /*.token        =*/ udata->token.data(),
        /*.embd         =*/ nullptr,
@@ -655,10 +683,8 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
 
     auto udata = std::make_shared<llama_ubatch::data_t>();
 
-    const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
-
     const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
-    const int64_t n_pos_all  = (int64_t) n_tokens*n_pos_cur;
+    const int64_t n_pos_all  = (int64_t) n_tokens*n_pos_per_embd;
 
     udata->token   .resize(n_tokens);
     udata->embd    .resize(n_embd_all);
@@ -680,8 +706,13 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
             memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
         }
 
-        for (int j = 0; j < n_pos_cur; ++j) {
-            udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
+        for (size_t j = 0; j < (size_t)n_pos_per_embd; ++j) {
+            // if we are using M-RoPE
+            //   if the current batch is text, we need to broadcast the same position across all RoPE sections
+            //   otherwise, the input batch is image embeddings, we copy the positions as-is
+            // if we are not using M-RoPE, there is only one position per token (this loop runs only once)
+            size_t src_off = batch.token ? 0 : j*batch.n_tokens;
+            udata->pos[j*n_tokens + i] = batch.pos[src_off + idxs[i]];
         }
 
         udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
@@ -710,6 +741,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
         /*.n_seq_tokens =*/ n_tokens/n_seqs,
         /*.n_seqs       =*/ n_seqs,
         /*.n_seqs_unq   =*/ (uint32_t) udata->seq_id_unq.size(),
+        /*.n_pos        =*/ n_pos_per_embd,
 
         /*.token        =*/ batch.token ? udata->token.data() : nullptr,
         /*.embd         =*/ batch.embd ? udata->embd.data() : nullptr,
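
Note: in ubatch_add above, positions are laid out as n_pos planes of n_tokens values each, so component j of token i lives at pos[j*n_tokens + i]; for text batches the single input position is broadcast across all M-RoPE sections, while embedding batches already carry one value per section. A standalone sketch of the text-broadcast case (buffer names and sizes below are illustrative, not taken from the commit; only the indexing mirrors the copy loop above):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    using llama_pos = int32_t;

    int main() {
        const uint32_t n_pos_per_embd = 4; // M-RoPE: 4 position components per token
        const uint32_t n_batch_tokens = 3; // tokens in a hypothetical incoming batch

        // text case: one position per token; it gets broadcast across all components
        const std::vector<llama_pos> batch_pos = { 5, 6, 7 };
        const bool is_text = true;

        // destination laid out as [n_pos][n_tokens], matching the ubatch convention
        std::vector<llama_pos> ubatch_pos(n_pos_per_embd * n_batch_tokens);

        for (uint32_t i = 0; i < n_batch_tokens; ++i) {
            for (uint32_t j = 0; j < n_pos_per_embd; ++j) {
                // text: read plane 0 for every component; embeddings would read plane j
                const size_t src_off = is_text ? 0 : (size_t) j*n_batch_tokens;
                ubatch_pos[j*n_batch_tokens + i] = batch_pos[src_off + i];
            }
        }

        // each component plane now repeats the sequential positions 5 6 7
        for (uint32_t j = 0; j < n_pos_per_embd; ++j) {
            for (uint32_t i = 0; i < n_batch_tokens; ++i) {
                std::printf("%d ", ubatch_pos[j*n_batch_tokens + i]);
            }
            std::printf("\n");
        }
        return 0;
    }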

examples/talk-llama/llama-batch.h

Lines changed: 12 additions & 1 deletion
@@ -17,6 +17,16 @@ struct llama_ubatch {
         return b_equal_seqs != 0;
     }
 
+    // typical for M-RoPE cases:
+    //   0 - sequantial position of the tokens/embeddings in the sequence
+    //   1 - y position in the image
+    //   2 - x position in the image
+    //   3 - other
+    bool is_pos_2d() const {
+        // TODO @ngxson : we may need to check for model arch when more models use >1 positions
+        return n_pos >= 3;
+    }
+
     uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
                            // otherwise address sanitizer complains
     // TODO: whole_seqs for embeddings?
@@ -25,6 +35,7 @@ struct llama_ubatch {
     uint32_t n_seq_tokens; // tokens per sequence set
     uint32_t n_seqs;       // sequence sets in the ubatch
     uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
+    uint32_t n_pos;        // number of position inputs for each token/embedding
 
     // seq_id_unq: unique sequence ids in the ubatch
     // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
@@ -33,7 +44,7 @@ struct llama_ubatch {
     //                          // size               | idx | val
     llama_token  *  token;      // [n_tokens]         | i   | id, token
     float        *  embd;       // [n_embd, n_tokens] | i   | embd
-    llama_pos    *  pos;        // [n_tokens]         | i   | pos
+    llama_pos    *  pos;        // [n_tokens*n_pos]   | i   | pos
     int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
     llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
     llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
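
Note: given the [n_tokens*n_pos] layout of pos and the component order documented above (0 = sequential, 1 = y, 2 = x, 3 = other), a consumer can read a token's 2-D image coordinates plane by plane. A hedged sketch follows (the pos_component helper and the example values are hypothetical; only the plane layout and the n_pos >= 3 test mirror the header above):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    using llama_pos = int32_t;

    // Hypothetical accessor: component c of token i, with pos stored as n_pos
    // planes of n_tokens values each (plane 0 = sequential, 1 = y, 2 = x, 3 = other).
    static llama_pos pos_component(const llama_pos * pos, uint32_t n_tokens, uint32_t i, uint32_t c) {
        return pos[(size_t) c*n_tokens + i];
    }

    int main() {
        const uint32_t n_tokens = 2;
        const uint32_t n_pos    = 4;

        // two image-embedding tokens with (y, x) = (0, 0) and (0, 1); plane 3 unused here
        const std::vector<llama_pos> pos = {
            10, 11, // plane 0: sequential position
             0,  0, // plane 1: y
             0,  1, // plane 2: x
             0,  0, // plane 3: other
        };

        const bool pos_is_2d = n_pos >= 3; // same test as is_pos_2d() above

        if (pos_is_2d) {
            for (uint32_t i = 0; i < n_tokens; ++i) {
                std::printf("token %u: y = %d, x = %d\n", i,
                        pos_component(pos.data(), n_tokens, i, 1),
                        pos_component(pos.data(), n_tokens, i, 2));
            }
        }
        return 0;
    }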
