
Commit 9438c70

cont : keep cells meta info in a map [no ci]
1 parent 1e10743 commit 9438c70
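
In short: the fixed-size std::array of per-type cell groups is replaced with a std::map keyed by kv_cells_type, which drops the KV_CELLS_TYPE_COUNT sentinel and makes it explicit which cell groups actually exist. A minimal compilable sketch of the idea (kv_cells here is a placeholder, not the real llama.cpp struct):

#include <cstdio>
#include <map>
#include <memory>

struct kv_cells {
    explicit kv_cells(int size) : size(size) {}
    int size;
};

enum kv_cells_type {
    KV_CELLS_TYPE_BASE = 0,
    KV_CELLS_TYPE_SWA,   // no KV_CELLS_TYPE_COUNT sentinel needed anymore
};

int main() {
    // before: std::array<std::unique_ptr<kv_cells>, KV_CELLS_TYPE_COUNT> cells_arr;
    std::map<kv_cells_type, std::unique_ptr<kv_cells>> cells_map;

    // only the cell groups that are actually allocated get an entry
    cells_map[KV_CELLS_TYPE_BASE].reset(new kv_cells(4096));

    // iteration now yields (type, cells) pairs, hence the structured
    // bindings in the .cpp diff below; the null check mirrors the diff
    for (auto & [type, cells] : cells_map) {
        if (!cells) {
            continue;
        }
        std::printf("cells type %d, size %d\n", (int) type, cells->size);
    }
    return 0;
}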


2 files changed: +27 -28 lines changed


src/llama-kv-cache.cpp  +25 -25

@@ -68,7 +68,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         return it->second;
     };

-    cells_arr[KV_CELLS_TYPE_BASE].reset(new kv_cells(kv_size));
+    cells_map[KV_CELLS_TYPE_BASE].reset(new kv_cells(kv_size));

     layers.resize(n_layer);

@@ -116,7 +116,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);

-        layer.cells = cells_arr[KV_CELLS_TYPE_BASE].get();
+        layer.cells = cells_map.at(KV_CELLS_TYPE_BASE).get();

         layer.k = k;
         layer.v = v;
@@ -168,7 +168,7 @@ void llama_kv_cache_unified::kv_cells::clear() {
 }

 void llama_kv_cache_unified::clear() {
-    for (auto & cells : cells_arr) {
+    for (auto & [_, cells] : cells_map) {
         if (!cells) {
             continue;
         }
@@ -227,7 +227,7 @@ bool llama_kv_cache_unified::kv_cells::seq_rm(llama_seq_id seq_id, llama_pos p0,
 bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
     bool res = true;

-    for (auto & cells : cells_arr) {
+    for (auto & [_, cells] : cells_map) {
         if (!cells) {
             continue;
         }
@@ -262,7 +262,7 @@ void llama_kv_cache_unified::kv_cells::seq_cp(llama_seq_id seq_id_src, llama_seq
 }

 void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
-    for (auto & cells : cells_arr) {
+    for (auto & [_, cells] : cells_map) {
         if (!cells) {
             continue;
         }
@@ -299,7 +299,7 @@ void llama_kv_cache_unified::kv_cells::seq_keep(llama_seq_id seq_id) {
 }

 void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
-    for (auto & cells : cells_arr) {
+    for (auto & [_, cells] : cells_map) {
         if (!cells) {
             continue;
         }
@@ -358,7 +358,7 @@ bool llama_kv_cache_unified::kv_cells::seq_add(llama_seq_id seq_id, llama_pos p0
 }

 void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
-    auto & cells = cells_arr[KV_CELLS_TYPE_BASE];
+    auto & cells = cells_map.at(KV_CELLS_TYPE_BASE);

     has_shift = cells->seq_add(seq_id, p0, p1, delta);
 }
@@ -399,7 +399,7 @@ bool llama_kv_cache_unified::kv_cells::seq_div(llama_seq_id seq_id, llama_pos p0
 }

 void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
-    auto & cells = cells_arr[KV_CELLS_TYPE_BASE];
+    auto & cells = cells_map.at(KV_CELLS_TYPE_BASE);

     has_shift = cells->seq_div(seq_id, p0, p1, d);
 }
@@ -417,7 +417,7 @@ llama_pos llama_kv_cache_unified::kv_cells::seq_pos_max(llama_seq_id seq_id) con
 }

 llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
-    auto & cells = cells_arr[KV_CELLS_TYPE_BASE];
+    auto & cells = cells_map.at(KV_CELLS_TYPE_BASE);

     return cells->seq_pos_max(seq_id);
 }
@@ -450,7 +450,7 @@ void llama_kv_cache_unified::kv_cells::restore() {
 }

 void llama_kv_cache_unified::restore() {
-    for (auto & cells : cells_arr) {
+    for (auto & [_, cells] : cells_map) {
         if (!cells) {
             continue;
         }
@@ -470,7 +470,7 @@ void llama_kv_cache_unified::kv_cells::commit() {
 }

 void llama_kv_cache_unified::commit() {
-    for (auto & cells : cells_arr) {
+    for (auto & [_, cells] : cells_map) {
         if (!cells) {
             continue;
         }
@@ -509,7 +509,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
     }

     {
-        auto & cells = cells_arr[KV_CELLS_TYPE_BASE];
+        auto & cells = cells_map.at(KV_CELLS_TYPE_BASE);

         has_shift = false;

@@ -545,7 +545,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
 }

 void llama_kv_cache_unified::defrag_sched(float thold) {
-    auto & cells = cells_arr[KV_CELLS_TYPE_BASE];
+    auto & cells = cells_map.at(KV_CELLS_TYPE_BASE);

     // - do not defrag small contexts (i.e. < 2048 tokens)
     // - count the padding towards the number of used tokens
@@ -560,7 +560,7 @@ void llama_kv_cache_unified::defrag_sched(float thold) {
 }

 void llama_kv_cache_unified::set_full() {
-    for (auto & cells : cells_arr) {
+    for (auto & [_, cells] : cells_map) {
         if (!cells) {
             continue;
         }
@@ -653,7 +653,7 @@ bool llama_kv_cache_unified::kv_cells::find_slot(const llama_ubatch & ubatch, ui
 bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
     bool res = true;

-    for (auto & cells : cells_arr) {
+    for (auto & [it, cells] : cells_map) {
         if (!cells) {
             continue;
         }
@@ -665,7 +665,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
 }

 int32_t llama_kv_cache_unified::get_n_tokens() const {
-    const auto & cells = cells_arr[KV_CELLS_TYPE_BASE];
+    const auto & cells = cells_map.at(KV_CELLS_TYPE_BASE);

     int32_t result = 0;

@@ -677,7 +677,7 @@ int32_t llama_kv_cache_unified::get_n_tokens() const {
 }

 int32_t llama_kv_cache_unified::get_used_cells() const {
-    const auto & cells = cells_arr[KV_CELLS_TYPE_BASE];
+    const auto & cells = cells_map.at(KV_CELLS_TYPE_BASE);

     return cells->used;
 }
@@ -691,12 +691,12 @@ const llama_kv_cache_unified::kv_layer & llama_kv_cache_unified::get_layer(int32
 }

 uint32_t llama_kv_cache_unified::n_base() const {
-    return cells_arr[KV_CELLS_TYPE_BASE]->n;
+    return cells_map.at(KV_CELLS_TYPE_BASE)->n;
 }

 uint32_t llama_kv_cache_unified::n_swa() const {
 #pragma messages("FIX MEEEEEEEEEEEEEEEEEE")
-    return cells_arr[KV_CELLS_TYPE_BASE]->n;
+    return cells_map.at(KV_CELLS_TYPE_BASE)->n;
 }

 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
@@ -707,7 +707,7 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     float * data = (float *) dst->data;

-    const auto & cells = cells_arr[KV_CELLS_TYPE_BASE];
+    const auto & cells = cells_map.at(KV_CELLS_TYPE_BASE);

     const int64_t n_kv = cells->n;

@@ -772,7 +772,7 @@ void llama_kv_cache_unified::set_input_kq_mask_swa(ggml_tensor * dst, const llam
     float * data_swa = (float *) dst->data;

 #pragma messages("FIX MEEEEEEEEEEEEEEEEEE")
-    const auto & cells = cells_arr[KV_CELLS_TYPE_BASE];
+    const auto & cells = cells_map.at(KV_CELLS_TYPE_BASE);

     const int64_t n_kv = cells->n;

@@ -831,7 +831,7 @@ void llama_kv_cache_unified::set_input_kq_mask_swa(ggml_tensor * dst, const llam
 void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));

-    const auto & cells = cells_arr[KV_CELLS_TYPE_BASE];
+    const auto & cells = cells_map.at(KV_CELLS_TYPE_BASE);

     int32_t * data = (int32_t *) dst->data;

@@ -848,7 +848,7 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama

     int32_t * data = (int32_t *) dst->data;

-    const auto & cells = cells_arr[KV_CELLS_TYPE_BASE];
+    const auto & cells = cells_map.at(KV_CELLS_TYPE_BASE);

     const int64_t n_kv = cells->n;

@@ -862,7 +862,7 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
 }

 llama_pos llama_kv_cache_unified::get_pos_max() const {
-    const auto & cells = cells_arr[KV_CELLS_TYPE_BASE];
+    const auto & cells = cells_map.at(KV_CELLS_TYPE_BASE);

     llama_pos pos_max = -1;

@@ -1166,7 +1166,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
 bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     const uint32_t n_layer = hparams.n_layer;

-    auto & cells = cells_arr[KV_CELLS_TYPE_BASE];
+    const auto & cells = cells_map.at(KV_CELLS_TYPE_BASE);

     const uint32_t n_kv = cells->cell_max();
     const uint32_t n_used = cells->used;

src/llama-kv-cache.h  +2 -3

@@ -7,6 +7,7 @@

 #include "ggml-cpp.h"

+#include <map>
 #include <set>
 #include <vector>

@@ -237,14 +238,12 @@ class llama_kv_cache_unified : public llama_kv_cache {
     enum kv_cells_type {
         KV_CELLS_TYPE_BASE = 0,
         KV_CELLS_TYPE_SWA,
-        KV_CELLS_TYPE_COUNT,
     };

-    std::array<std::unique_ptr<kv_cells>, KV_CELLS_TYPE_COUNT> cells_arr;
+    std::map<kv_cells_type, std::unique_ptr<kv_cells>> cells_map;

     std::vector<kv_layer> layers;

-private:
     const llama_model & model;
     const llama_hparams & hparams;

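A detail worth noting in the diff above: all read paths switch from cells_arr[...] to cells_map.at(...). On std::map, operator[] default-constructs an entry for a missing key and is not callable on a const map, while .at() is const-safe and throws std::out_of_range instead of inserting. A simplified sketch of the pattern (placeholder types, not the actual llama.cpp declarations):

#include <cstdint>
#include <map>
#include <memory>

struct kv_cells { uint32_t n = 0; };
enum kv_cells_type { KV_CELLS_TYPE_BASE = 0, KV_CELLS_TYPE_SWA };

using cells_map_t = std::map<kv_cells_type, std::unique_ptr<kv_cells>>;

uint32_t n_base(const cells_map_t & cells_map) {
    // cells_map[KV_CELLS_TYPE_BASE] would not compile here (const map); on a
    // mutable map it would silently insert a null entry for a missing key
    return cells_map.at(KV_CELLS_TYPE_BASE)->n;
}

The one place the commit keeps operator[] is the constructor, where inserting the entry is exactly the intent: cells_map[KV_CELLS_TYPE_BASE].reset(new kv_cells(kv_size)).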