Commit 4feadaa

cont : prepare for alternative approach
ggml-ci
1 parent 5a80cbc commit 4feadaa

File tree

3 files changed: +402 -574 lines changed


src/llama-graph.cpp (+6 -6)
@@ -1023,7 +1023,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
 
-    const auto n_kv = kv_self->n_base();
+    const auto n_kv = kv_self->n;
 
     auto & cur = inp->pos_bucket;
 
@@ -1240,7 +1240,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
     {
-        const auto n_kv = kv_self->n_base();
+        const auto n_kv = kv_self->n;
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1252,7 +1252,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     if (hparams.n_swa_pattern > 1) {
         GGML_ASSERT(hparams.n_swa > 0);
 
-        const auto n_kv = kv_self->n_swa();
+        const auto n_kv = kv_self->n;
 
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
@@ -1297,9 +1297,9 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        const auto kv_head = kv_layer.cells->head;
+        const auto kv_head = kv_self->head;
 
-        GGML_ASSERT(kv_layer.cells->size == n_ctx);
+        GGML_ASSERT(kv_self->size == n_ctx);
 
         ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_layer.k, n_tokens*n_embd_k_gqa, ggml_row_size(kv_layer.k->type, n_embd_k_gqa)*kv_head);
         //cb(k_cache_view, "k_cache_view", il);
@@ -1331,7 +1331,7 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
-    const auto n_kv = kv_layer.cells->n;
+    const auto n_kv = kv_self->n;
 
     const auto & n_embd_head_k = hparams.n_embd_head_k;
     const auto & n_embd_head_v = hparams.n_embd_head_v;
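
For reference, every hunk in this diff points the graph-construction code back at the unified kv_self object instead of the per-layer (kv_layer.cells) and SWA-specific (n_base() / n_swa()) views. Below is a minimal C++ sketch of the cache state the changed lines rely on; only the field names n, head, and size appear in the diff itself, while the struct name and comments are illustrative assumptions, not the actual llama.cpp definition.

    #include <cstdint>

    // Hypothetical sketch of the unified KV cache state read by the diff.
    struct llama_kv_cache_unified_sketch {
        uint32_t head = 0; // next cell to write new K/V rows into (kv_self->head)
        uint32_t size = 0; // total cell count; asserted equal to n_ctx (kv_self->size)
        uint32_t n    = 0; // cells the current graph attends over; sets the KQ mask width (kv_self->n)
    };

With a single shared n, both self_kq_mask and self_kq_mask_swa are allocated with the same width n_kv, consistent with the commit title's stated direction of preparing a single unified cache state for the alternative approach.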
