@@ -3369,25 +3369,18 @@ static struct ggml_cgraph * llm_build_baichaun(
 
 static struct ggml_cgraph * llm_build_refact(
          llama_context & lctx,
-     const llama_token * tokens,
-           const float * embd,
-                   int   n_tokens,
-                   int   n_past) {
-
-    GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT
-
-    const int N = n_tokens;
-
+     const llama_batch & batch) {
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
+    const auto & cparams = lctx.cparams;
 
     const auto & kv_self = lctx.kv_self;
 
     GGML_ASSERT(!!kv_self.ctx);
 
     const int64_t n_embd      = hparams.n_embd;
     const int64_t n_layer     = hparams.n_layer;
-    const int64_t n_ctx       = hparams.n_ctx;
+    const int64_t n_ctx       = cparams.n_ctx;
     const int64_t n_head      = hparams.n_head;
     const int64_t n_head_kv   = hparams.n_head_kv;
     const int64_t n_embd_head = hparams.n_embd_head();
@@ -3397,6 +3390,12 @@ static struct ggml_cgraph * llm_build_refact(
 
     const int n_gpu_layers = model.n_gpu_layers;
 
+    const int32_t n_tokens = batch.n_tokens;
+    const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n;
+    const int32_t kv_head  = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
+
+    //printf("n_kv = %d\n", n_kv);
+
     auto & buf_compute = lctx.buf_compute;
 
     struct ggml_init_params params = {
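Aside: the two allocator modes above size the graph very differently. During measurement the worst case is assumed, so the cache is treated as full and the batch as landing at its very end; at run time the actual cache state is used. A minimal standalone sketch of that selection, with all sizes made up for illustration:

```cpp
#include <cstdio>

int main() {
    // assumed sizes, for illustration only
    const int n_ctx    = 512; // context length
    const int n_tokens = 32;  // tokens in the current batch

    // measure pass: size the graph for the worst case
    int n_kv    = n_ctx;            // attend over the whole cache
    int kv_head = n_ctx - n_tokens; // new keys/values land at the very end

    printf("measure: n_kv = %d, kv_head = %d\n", n_kv, kv_head);

    // run pass: use the actual cache state (made-up values here)
    const int kv_self_n = 96, kv_self_head = 64;
    n_kv    = kv_self_n;
    kv_head = kv_self_head;

    printf("run:     n_kv = %d, kv_head = %d\n", n_kv, kv_head);
}
```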
@@ -3414,12 +3413,12 @@ static struct ggml_cgraph * llm_build_refact(
     struct ggml_tensor * cur;
     struct ggml_tensor * inpL;
 
-    if (tokens) {
-        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    if (batch.token) {
+        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
 
         ggml_allocr_alloc(lctx.alloc, inp_tokens);
         if (!ggml_allocr_is_measure(lctx.alloc)) {
-            memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
+            memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens));
         }
         ggml_set_name(inp_tokens, "inp_tokens");
 
@@ -3429,11 +3428,11 @@ static struct ggml_cgraph * llm_build_refact(
         GGML_ASSERT(false && "not implemented");
 #endif
 
-        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
+        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
 
         ggml_allocr_alloc(lctx.alloc, inpL);
         if (!ggml_allocr_is_measure(lctx.alloc)) {
-            memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
+            memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL));
         }
     }
 
@@ -3442,9 +3441,6 @@ static struct ggml_cgraph * llm_build_refact(
 
     // offload functions set the tensor output backend to GPU
     // tensors are GPU-accelerated if any input or the output has been offloaded
-    //
-    // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal
-    // in that case ggml_cuda_assign_buffers has no effect
     offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
     offload_func_t offload_func_kq = llama_nop;
     offload_func_t offload_func_v  = llama_nop;
@@ -3461,12 +3457,36 @@ static struct ggml_cgraph * llm_build_refact(
     }
 #endif // GGML_USE_CUBLAS
 
+    // KQ_scale
     struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+    ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
     ggml_allocr_alloc(lctx.alloc, KQ_scale);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
-        ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
+        ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head)));
+    }
+
+    // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+    offload_func_kq(KQ_mask);
+    ggml_set_name(KQ_mask, "KQ_mask");
+    ggml_allocr_alloc(lctx.alloc, KQ_mask);
+    if (!ggml_allocr_is_measure(lctx.alloc)) {
+        float * data = (float *) KQ_mask->data;
+        memset(data, 0, ggml_nbytes(KQ_mask));
+
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                const llama_pos    pos    = batch.pos[j];
+                const llama_seq_id seq_id = batch.seq_id[j];
+
+                for (int i = 0; i < n_kv; ++i) {
+                    if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+                    }
+                }
+            }
+        }
     }
-    ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
 
     for (int il = 0; il < n_layer; ++il) {
         ggml_format_name(inpL, "layer_inp_%d", il);
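The new KQ_mask replaces the old ggml_diag_mask_inf step: instead of assuming one contiguous past of length n_past, every batch token gets its own row of additive biases, hiding cache cells that belong to another sequence or lie in that token's future. A self-contained sketch of the same loop on toy data (toy_cell and all values here are illustrative stand-ins, not llama.cpp API):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// toy stand-in for a KV cache cell: the position it holds and its sequence
struct toy_cell {
    int pos;
    int seq;
    bool has_seq_id(int s) const { return seq == s; }
};

int main() {
    const int n_kv = 4, n_tokens = 2;

    std::vector<toy_cell> cells  = { {0, 0}, {1, 0}, {2, 0}, {3, 1} }; // cached keys
    std::vector<int>      pos    = { 1, 3 };                           // position of each batch token
    std::vector<int>      seq_id = { 0, 0 };                           // sequence of each batch token

    std::vector<float> mask(n_kv*n_tokens, 0.0f);
    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_kv; ++i) {
            // hide keys from other sequences and keys in the future of token j
            if (!cells[i].has_seq_id(seq_id[j]) || cells[i].pos > pos[j]) {
                mask[j*n_kv + i] = -INFINITY;
            }
        }
    }

    // row j is the additive mask for batch token j
    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_kv; ++i) {
            printf("%5s", std::isinf(mask[j*n_kv + i]) ? " -inf" : "    0");
        }
        printf("\n");
    }
}
```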
@@ -3504,36 +3524,33 @@ static struct ggml_cgraph * llm_build_refact(
             offload_func_kq(tmpq);
             ggml_set_name(tmpq, "tmpq");
 
-            struct ggml_tensor * Kcur;
-            struct ggml_tensor * Qcur;
-            Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N);
-            Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N);
-
+            struct ggml_tensor * Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens);
             offload_func_kq(Kcur);
             ggml_set_name(Kcur, "Kcur");
 
+            struct ggml_tensor * Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens);
             offload_func_kq(Qcur);
             ggml_set_name(Qcur, "Qcur");
 
             // store key and value to memory
             {
-                // compute the transposed [N, n_embd] V matrix
+                // compute the transposed [n_tokens, n_embd] V matrix
 
                 struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
                 offload_func_v(tmpv);
                 ggml_set_name(tmpv, "tmpv");
 
-                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N));
+                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens));
                 offload_func_v(Vcur);
                 ggml_set_name(Vcur, "Vcur");
 
-                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
+                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
                 offload_func_kq(k);
                 ggml_set_name(k, "k");
 
-                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
+                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
                         (   n_ctx)*ggml_element_size(kv_self.v),
-                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
+                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
                 offload_func_v(v);
                 ggml_set_name(v, "v");
 
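The k view above writes this batch's keys starting at kv_head (the first free cell) rather than at n_past; the offset is plain row-major indexing into a per-layer slab of the cache. A sketch of that arithmetic with made-up sizes:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    // assumed sizes, for illustration only
    const size_t elt        = 2;   // bytes per element, e.g. f16
    const size_t n_embd_gqa = 256; // elements per cache cell
    const size_t n_ctx      = 512; // cells per layer
    const size_t il         = 3;   // layer index
    const size_t kv_head    = 64;  // first free cell in the cache
    const size_t n_tokens   = 32;  // keys being written

    // byte offset of the view: skip il full layers, then kv_head cells
    const size_t offs = (elt*n_embd_gqa)*(il*n_ctx + kv_head);
    const size_t len  = elt*n_embd_gqa*n_tokens;

    printf("k view: %zu bytes at offset %zu\n", len, offs);
}
```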
@@ -3547,7 +3564,7 @@ static struct ggml_cgraph * llm_build_refact(
 
             struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
-                        n_embd_head, n_past + N, n_head_kv,
+                        n_embd_head, n_kv, n_head_kv,
                         ggml_element_size(kv_self.k)*n_embd_gqa,
                         ggml_element_size(kv_self.k)*n_embd_head,
                         ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
@@ -3560,25 +3577,28 @@ static struct ggml_cgraph * llm_build_refact(
             ggml_set_name(KQ, "KQ");
 
             // KQ_scaled = KQ / sqrt(n_embd_head)
-            // KQ_scaled shape [n_past + N, N, n_head, 1]
-            struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
+            // KQ_scaled shape [n_kv, n_tokens, n_head, 1]
+            struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
             offload_func_kq(KQ_scaled);
             ggml_set_name(KQ_scaled, "KQ_scaled");
 
-            struct ggml_tensor * KQ_masked;
-            struct ggml_tensor * KQ_scaled_alibi;
-
-            KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, n_past, n_head, 8);
+            // KQ_masked = mask_past(KQ_scaled)
+            struct ggml_tensor * KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, /*n_past*/ 0, n_head, 8);
             ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi");
-            KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past);
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+
+            struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask);
+            offload_func_kq(KQ_masked);
+            ggml_set_name(KQ_masked, "KQ_masked");
+
+            // KQ = soft_max(KQ_masked)
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
             offload_func_v(KQ_soft_max);
             ggml_set_name(KQ_soft_max, "KQ_soft_max");
 
             // split cached V into n_head heads
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
-                        n_past + N, n_embd_head, n_head_kv,
+                        n_kv, n_embd_head, n_head_kv,
                         ggml_element_size(kv_self.v)*n_ctx,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
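Because the mask is now applied with ggml_add before the softmax, masked logits reach the softmax as -INFINITY and come out as exact zeros, so no separate diagonal-masking step is needed. A reference sketch of that numerical behavior (plain C++, not ggml code):

```cpp
#include <cmath>
#include <cstdio>

// reference softmax over one attention row; the scores already contain the
// additive mask (0 for visible keys, -INFINITY for hidden ones)
static void softmax_row(float * x, int n) {
    float mx = -INFINITY;
    for (int i = 0; i < n; ++i) { if (x[i] > mx) mx = x[i]; }
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) { x[i] = expf(x[i] - mx); sum += x[i]; }
    for (int i = 0; i < n; ++i) { x[i] /= sum; }
}

int main() {
    float row[4] = { 0.3f, 1.2f, -INFINITY, -INFINITY }; // last two keys masked
    softmax_row(row, 4);
    for (int i = 0; i < 4; ++i) {
        printf("%.3f ", row[i]); // masked entries come out as exactly 0
    }
    printf("\n");
}
```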
@@ -3593,7 +3613,7 @@ static struct ggml_cgraph * llm_build_refact(
             // make V contiguous in memory to speed up the matmul, however we waste time on the copy
             // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
             // is there a better way?
-            struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head));
+            struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_ctx, n_embd_head, n_head));
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
 #endif
 
@@ -3602,10 +3622,8 @@ static struct ggml_cgraph * llm_build_refact(
             offload_func_v(KQV_merged);
             ggml_set_name(KQV_merged, "KQV_merged");
 
-            // cur = KQV_merged.contiguous().view(n_embd, N)
-            cur = ggml_cpy(ctx0,
-                    KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+            // cur = KQV_merged.contiguous().view(n_embd, n_tokens)
+            cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
             offload_func_v(cur);
             ggml_set_name(cur, "KQV_merged_contiguous");
 
@@ -4338,7 +4356,7 @@ static struct ggml_cgraph * llama_build_graph(
             } break;
         case LLM_ARCH_REFACT:
             {
-                result = llm_build_refact(lctx, tokens, embd, n_tokens, n_past);
+                result = llm_build_refact(lctx, batch);
             } break;
         default:
             GGML_ASSERT(false);
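The call site now hands the builder the whole batch instead of (tokens, embd, n_tokens, n_past). A hedged sketch of what a caller fills in, using only the fields this diff reads; the struct below is a toy subset for illustration, not the real llama_batch definition:

```cpp
#include <cstdint>
#include <vector>

using llama_token  = int32_t;
using llama_pos    = int32_t;
using llama_seq_id = int32_t;

// illustrative subset of the batch; the real llama_batch carries more fields
struct toy_batch {
    int32_t        n_tokens;
    llama_token  * token;  // token ids (or nullptr when embd is used)
    float        * embd;   // raw embeddings (or nullptr when token is used)
    llama_pos    * pos;    // absolute position of each token (replaces n_past)
    llama_seq_id * seq_id; // sequence each token belongs to
};

int main() {
    std::vector<llama_token>  tok = { 1, 15043, 3186 }; // made-up token ids
    std::vector<llama_pos>    pos = { 0, 1, 2 };        // positions within the sequence
    std::vector<llama_seq_id> seq = { 0, 0, 0 };        // all tokens in sequence 0

    toy_batch batch = {
        /*n_tokens*/ (int32_t) tok.size(),
        /*token   */ tok.data(),
        /*embd    */ nullptr,
        /*pos     */ pos.data(),
        /*seq_id  */ seq.data(),
    };

    (void) batch; // in context, this is what llm_build_refact(lctx, batch) consumes
}
```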