@@ -68,7 +68,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
68
68
return it->second ;
69
69
};
70
70
71
- cells_arr [KV_CELLS_TYPE_BASE].reset (new kv_cells (kv_size));
71
+ cells_map [KV_CELLS_TYPE_BASE].reset (new kv_cells (kv_size));
72
72
73
73
layers.resize (n_layer);
74
74
@@ -116,7 +116,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
116
116
ggml_format_name (k, " cache_k_l%d" , i);
117
117
ggml_format_name (v, " cache_v_l%d" , i);
118
118
119
- layer.cells = cells_arr[ KV_CELLS_TYPE_BASE] .get ();
119
+ layer.cells = cells_map. at ( KV_CELLS_TYPE_BASE) .get ();
120
120
121
121
layer.k = k;
122
122
layer.v = v;
@@ -168,7 +168,7 @@ void llama_kv_cache_unified::kv_cells::clear() {
168
168
}
169
169
170
170
void llama_kv_cache_unified::clear () {
171
- for (auto & cells : cells_arr ) {
171
+ for (auto & [_, cells] : cells_map ) {
172
172
if (!cells) {
173
173
continue ;
174
174
}
@@ -227,7 +227,7 @@ bool llama_kv_cache_unified::kv_cells::seq_rm(llama_seq_id seq_id, llama_pos p0,
227
227
bool llama_kv_cache_unified::seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
228
228
bool res = true ;
229
229
230
- for (auto & cells : cells_arr ) {
230
+ for (auto & [_, cells] : cells_map ) {
231
231
if (!cells) {
232
232
continue ;
233
233
}
@@ -262,7 +262,7 @@ void llama_kv_cache_unified::kv_cells::seq_cp(llama_seq_id seq_id_src, llama_seq
262
262
}
263
263
264
264
void llama_kv_cache_unified::seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
265
- for (auto & cells : cells_arr ) {
265
+ for (auto & [_, cells] : cells_map ) {
266
266
if (!cells) {
267
267
continue ;
268
268
}
@@ -299,7 +299,7 @@ void llama_kv_cache_unified::kv_cells::seq_keep(llama_seq_id seq_id) {
299
299
}
300
300
301
301
void llama_kv_cache_unified::seq_keep (llama_seq_id seq_id) {
302
- for (auto & cells : cells_arr ) {
302
+ for (auto & [_, cells] : cells_map ) {
303
303
if (!cells) {
304
304
continue ;
305
305
}
@@ -358,7 +358,7 @@ bool llama_kv_cache_unified::kv_cells::seq_add(llama_seq_id seq_id, llama_pos p0
358
358
}
359
359
360
360
void llama_kv_cache_unified::seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
361
- auto & cells = cells_arr[ KV_CELLS_TYPE_BASE] ;
361
+ auto & cells = cells_map. at ( KV_CELLS_TYPE_BASE) ;
362
362
363
363
has_shift = cells->seq_add (seq_id, p0, p1, delta);
364
364
}
@@ -399,7 +399,7 @@ bool llama_kv_cache_unified::kv_cells::seq_div(llama_seq_id seq_id, llama_pos p0
399
399
}
400
400
401
401
void llama_kv_cache_unified::seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
402
- auto & cells = cells_arr[ KV_CELLS_TYPE_BASE] ;
402
+ auto & cells = cells_map. at ( KV_CELLS_TYPE_BASE) ;
403
403
404
404
has_shift = cells->seq_div (seq_id, p0, p1, d);
405
405
}
@@ -417,7 +417,7 @@ llama_pos llama_kv_cache_unified::kv_cells::seq_pos_max(llama_seq_id seq_id) con
417
417
}
418
418
419
419
llama_pos llama_kv_cache_unified::seq_pos_max (llama_seq_id seq_id) const {
420
- auto & cells = cells_arr[ KV_CELLS_TYPE_BASE] ;
420
+ auto & cells = cells_map. at ( KV_CELLS_TYPE_BASE) ;
421
421
422
422
return cells->seq_pos_max (seq_id);
423
423
}
@@ -450,7 +450,7 @@ void llama_kv_cache_unified::kv_cells::restore() {
450
450
}
451
451
452
452
void llama_kv_cache_unified::restore () {
453
- for (auto & cells : cells_arr ) {
453
+ for (auto & [_, cells] : cells_map ) {
454
454
if (!cells) {
455
455
continue ;
456
456
}
@@ -470,7 +470,7 @@ void llama_kv_cache_unified::kv_cells::commit() {
470
470
}
471
471
472
472
void llama_kv_cache_unified::commit () {
473
- for (auto & cells : cells_arr ) {
473
+ for (auto & [_, cells] : cells_map ) {
474
474
if (!cells) {
475
475
continue ;
476
476
}
@@ -509,7 +509,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
509
509
}
510
510
511
511
{
512
- auto & cells = cells_arr[ KV_CELLS_TYPE_BASE] ;
512
+ auto & cells = cells_map. at ( KV_CELLS_TYPE_BASE) ;
513
513
514
514
has_shift = false ;
515
515
@@ -545,7 +545,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
545
545
}
546
546
547
547
void llama_kv_cache_unified::defrag_sched (float thold) {
548
- auto & cells = cells_arr[ KV_CELLS_TYPE_BASE] ;
548
+ auto & cells = cells_map. at ( KV_CELLS_TYPE_BASE) ;
549
549
550
550
// - do not defrag small contexts (i.e. < 2048 tokens)
551
551
// - count the padding towards the number of used tokens
@@ -560,7 +560,7 @@ void llama_kv_cache_unified::defrag_sched(float thold) {
560
560
}
561
561
562
562
void llama_kv_cache_unified::set_full () {
563
- for (auto & cells : cells_arr ) {
563
+ for (auto & [_, cells] : cells_map ) {
564
564
if (!cells) {
565
565
continue ;
566
566
}
@@ -653,7 +653,7 @@ bool llama_kv_cache_unified::kv_cells::find_slot(const llama_ubatch & ubatch, ui
653
653
bool llama_kv_cache_unified::find_slot (const llama_ubatch & ubatch) {
654
654
bool res = true ;
655
655
656
- for (auto & cells : cells_arr ) {
656
+ for (auto & [it, cells] : cells_map ) {
657
657
if (!cells) {
658
658
continue ;
659
659
}
@@ -665,7 +665,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
665
665
}
666
666
667
667
int32_t llama_kv_cache_unified::get_n_tokens () const {
668
- const auto & cells = cells_arr[ KV_CELLS_TYPE_BASE] ;
668
+ const auto & cells = cells_map. at ( KV_CELLS_TYPE_BASE) ;
669
669
670
670
int32_t result = 0 ;
671
671
@@ -677,7 +677,7 @@ int32_t llama_kv_cache_unified::get_n_tokens() const {
677
677
}
678
678
679
679
int32_t llama_kv_cache_unified::get_used_cells () const {
680
- const auto & cells = cells_arr[ KV_CELLS_TYPE_BASE] ;
680
+ const auto & cells = cells_map. at ( KV_CELLS_TYPE_BASE) ;
681
681
682
682
return cells->used ;
683
683
}
@@ -691,12 +691,12 @@ const llama_kv_cache_unified::kv_layer & llama_kv_cache_unified::get_layer(int32
691
691
}
692
692
693
693
uint32_t llama_kv_cache_unified::n_base () const {
694
- return cells_arr[ KV_CELLS_TYPE_BASE] ->n ;
694
+ return cells_map. at ( KV_CELLS_TYPE_BASE) ->n ;
695
695
}
696
696
697
697
uint32_t llama_kv_cache_unified::n_swa () const {
698
698
#pragma messages("FIX MEEEEEEEEEEEEEEEEEE")
699
- return cells_arr[ KV_CELLS_TYPE_BASE] ->n ;
699
+ return cells_map. at ( KV_CELLS_TYPE_BASE) ->n ;
700
700
}
701
701
702
702
void llama_kv_cache_unified::set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
@@ -707,7 +707,7 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
707
707
GGML_ASSERT (ggml_backend_buffer_is_host (dst->buffer ));
708
708
float * data = (float *) dst->data ;
709
709
710
- const auto & cells = cells_arr[ KV_CELLS_TYPE_BASE] ;
710
+ const auto & cells = cells_map. at ( KV_CELLS_TYPE_BASE) ;
711
711
712
712
const int64_t n_kv = cells->n ;
713
713
@@ -772,7 +772,7 @@ void llama_kv_cache_unified::set_input_kq_mask_swa(ggml_tensor * dst, const llam
772
772
float * data_swa = (float *) dst->data ;
773
773
774
774
#pragma messages("FIX MEEEEEEEEEEEEEEEEEE")
775
- const auto & cells = cells_arr[ KV_CELLS_TYPE_BASE] ;
775
+ const auto & cells = cells_map. at ( KV_CELLS_TYPE_BASE) ;
776
776
777
777
const int64_t n_kv = cells->n ;
778
778
@@ -831,7 +831,7 @@ void llama_kv_cache_unified::set_input_kq_mask_swa(ggml_tensor * dst, const llam
831
831
void llama_kv_cache_unified::set_input_k_shift (ggml_tensor * dst) const {
832
832
GGML_ASSERT (ggml_backend_buffer_is_host (dst->buffer ));
833
833
834
- const auto & cells = cells_arr[ KV_CELLS_TYPE_BASE] ;
834
+ const auto & cells = cells_map. at ( KV_CELLS_TYPE_BASE) ;
835
835
836
836
int32_t * data = (int32_t *) dst->data ;
837
837
@@ -848,7 +848,7 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
848
848
849
849
int32_t * data = (int32_t *) dst->data ;
850
850
851
- const auto & cells = cells_arr[ KV_CELLS_TYPE_BASE] ;
851
+ const auto & cells = cells_map. at ( KV_CELLS_TYPE_BASE) ;
852
852
853
853
const int64_t n_kv = cells->n ;
854
854
@@ -862,7 +862,7 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
862
862
}
863
863
864
864
llama_pos llama_kv_cache_unified::get_pos_max () const {
865
- const auto & cells = cells_arr[ KV_CELLS_TYPE_BASE] ;
865
+ const auto & cells = cells_map. at ( KV_CELLS_TYPE_BASE) ;
866
866
867
867
llama_pos pos_max = -1 ;
868
868
@@ -1166,7 +1166,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
1166
1166
bool llama_kv_cache_unified::defrag_prepare (int32_t n_max_nodes) {
1167
1167
const uint32_t n_layer = hparams.n_layer ;
1168
1168
1169
- auto & cells = cells_arr[ KV_CELLS_TYPE_BASE] ;
1169
+ const auto & cells = cells_map. at ( KV_CELLS_TYPE_BASE) ;
1170
1170
1171
1171
const uint32_t n_kv = cells->cell_max ();
1172
1172
const uint32_t n_used = cells->used ;
0 commit comments