
Commit bc03bd1

[Executorch][llm] Enable leveraging ring kv cache via module swap
Differential Revision: D73891426
Pull Request resolved: #10611
1 parent ec0cfcc commit bc03bd1

5 files changed: +522 -31 lines changed

examples/models/llama/attention.py

+13-5
@@ -150,6 +150,16 @@ def forward(
         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


+def _create_causal_mask_for_ring_buffer(
+    cache_positions, window_size, start_pos, seq_len
+):
+    pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
+    delta = pos_q - cache_positions
+    attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < window_size)
+    attn_mask = torch.where(attn_mask == True, 0, float("-inf"))  # noqa E712
+    return attn_mask
+
+
 class CacheUpdateStrategy(Enum):
     RING_BUFFER = "RingBuffer"
     INVALID = "Invalid"
@@ -283,12 +293,10 @@ def __init__(
         self.is_ring_buffer = True

     def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
-        pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
         cache_positions = self.cache_positions_manager.cache_positions
-        delta = pos_q - cache_positions
-        attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < self.window_size)
-        attn_mask = torch.where(attn_mask == True, 0, float("-inf"))  # noqa E712
-        return attn_mask
+        return _create_causal_mask_for_ring_buffer(
+            cache_positions, self.window_size, start_pos, seq_len
+        )

     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
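
To make the new helper's semantics concrete, here is a minimal standalone sketch. The helper body is copied from the hunk above; the window size, buffer contents, and the assumption that unwritten ring-buffer slots hold -1 are illustrative and not taken from the commit.

import torch

# Copied from the diff above so the example is self-contained.
def _create_causal_mask_for_ring_buffer(
    cache_positions, window_size, start_pos, seq_len
):
    pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
    delta = pos_q - cache_positions
    attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < window_size)
    attn_mask = torch.where(attn_mask == True, 0, float("-inf"))  # noqa E712
    return attn_mask

# Toy ring buffer of size 8 (2x a window of 4); slots 6-7 are assumed unwritten (-1),
# which the (cache_positions >= 0) term masks out.
cache_positions = torch.tensor([0, 1, 2, 3, 4, 5, -1, -1])
mask = _create_causal_mask_for_ring_buffer(cache_positions, window_size=4, start_pos=5, seq_len=1)
print(mask)  # 0 at the slots holding positions 2-5 (causal and inside the window), -inf elsewhere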

examples/models/llama/source_transformation/custom_kv_cache.py

+190-1
@@ -10,7 +10,12 @@

 import torch
 import torch.nn as nn
-from executorch.examples.models.llama.attention import KVCache
+from executorch.examples.models.llama.attention import (
+    _create_causal_mask_for_ring_buffer,
+    CachePositionsManager,
+    KVCache,
+    RingKVCache,
+)

 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401

@@ -75,6 +80,7 @@ def __init__(
         self.register_buffer(
             "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int8)
         )
+        self.cache_type = cache_type

     def _quantize(self, value):
         (
@@ -209,6 +215,7 @@ def update(self, input_pos, k_val, v_val, indices=None):
         However the storage is [B, S, H, D] so we incur transpose in, transpose out
         This shall be removed by subsequent post-export graph pass
         """
+
         k_val = k_val.transpose(1, 2)
         v_val = v_val.transpose(1, 2)

@@ -382,3 +389,185 @@ def _replace_kv_cache_with_custom_kv_cache(module):
         else:
             _replace_kv_cache_with_custom_kv_cache(child)
     return module
+
+
+class QuantizedRingKVCache(QuantizedKVCache):
+    def __init__(
+        self,
+        max_batch_size,
+        max_context_length,
+        n_heads,
+        head_dim,
+        cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
+        use_custom_update_cache_op: bool = False,
+    ):
+        # Look at attention.py for explanation on why max_context_length * 2
+        super().__init__(
+            max_batch_size,
+            max_context_length * 2,
+            n_heads,
+            head_dim,
+            cache_type,
+            use_custom_update_cache_op,
+        )
+        self.cache_positions_manager = CachePositionsManager(self.max_context_length)
+        self.is_ring_buffer = True
+        self.window_size = max_context_length
+
+    def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
+        cache_positions = self.cache_positions_manager.cache_positions
+        return _create_causal_mask_for_ring_buffer(
+            cache_positions, self.window_size, start_pos, seq_len
+        )
+
+    def update(self, input_pos, k_val, v_val):
+        """
+        k_val, v_val: [B, H, S, D]
+        return: [B, H, S, D]
+        However the storage is [B, S, H, D] so we incur transpose in, transpose out
+        This shall be removed by subsequent post-export graph pass
+        """
+        # Need to transpose for two reasons
+        # 1. kv cache is stored as [B, S, H, D]
+        # 2. If seq_len = k_val.size(2), we won't be able to optimize
+        # away transpose at the output of k, v projection
+        seq_len = k_val.transpose(1, 2).size(1)
+        assert seq_len <= self.k_cache.size(
+            1
+        ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(1)})"
+        indices = self.cache_positions_manager.calculate_positions_and_update_indices(
+            input_pos, seq_len
+        )
+        indices = indices.unsqueeze(0)
+
+        return super().update(input_pos, k_val, v_val, indices)
+
+    @classmethod
+    def from_quantized_kv_cache(
+        cls,
+        kv_cache,
+        sliding_window_size,
+    ):
+        assert isinstance(
+            kv_cache, QuantizedKVCache
+        ), "For QuantizedRingKVCache expect QuantizedKVCache as input kv_cache"
+        max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
+        return cls(
+            max_batch_size,
+            sliding_window_size,
+            n_heads,
+            head_dim,
+            kv_cache.cache_type,
+            kv_cache.use_custom_update_cache_op,
+        )
+
+
+class CustomRingKVCache(CustomKVCache):
+    def __init__(
+        self,
+        max_batch_size,
+        max_context_length,
+        n_heads,
+        head_dim,
+        dtype=torch.float32,
+    ):
+        # Look at attention.py for explanation on why max_context_length * 2
+        super().__init__(
+            max_batch_size, max_context_length * 2, n_heads, head_dim, dtype
+        )
+        self.cache_positions_manager = CachePositionsManager(self.max_context_length)
+        self.is_ring_buffer = True
+        self.window_size = max_context_length
+
+    def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
+        cache_positions = self.cache_positions_manager.cache_positions
+        return _create_causal_mask_for_ring_buffer(
+            cache_positions, self.window_size, start_pos, seq_len
+        )
+
+    def update(self, input_pos, k_val, v_val):
+        """
+        k_val, v_val: [B, H, S, D]
+        return: [B, H, S, D]
+        However the storage is [B, S, H, D] so we incur transpose in, transpose out
+        This shall be removed by subsequent post-export graph pass
+        """
+        # Need to transpose for two reasons
+        # 1. kv cache is stored as [B, S, H, D]
+        # 2. If seq_len = k_val.size(2), we won't be able to optimize
+        # away transpose at the output of k, v projection
+        seq_len = k_val.transpose(1, 2).size(1)
+        assert seq_len <= self.k_cache.size(
+            1
+        ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(1)})"
+        indices = self.cache_positions_manager.calculate_positions_and_update_indices(
+            input_pos, seq_len
+        )
+        indices = indices.unsqueeze(0)
+
+        return super().update(input_pos, k_val, v_val, indices)
+
+    @classmethod
+    def from_custom_kv_cache(
+        cls,
+        kv_cache,
+        sliding_window_size,
+    ):
+        max_batch_size, n_heads, _, head_dim = kv_cache.k_cache.shape
+        if isinstance(kv_cache, CustomKVCache):
+            # If replacing custom kv cache, then the shape is [B, S, H, D]
+            max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
+        return cls(
+            max_batch_size,
+            sliding_window_size,
+            n_heads,
+            head_dim,
+            dtype=kv_cache.k_cache.dtype,
+        )
+
+
+def _replace_kv_cache_with_ring_kv_cache(attention, layer_size):
+    sliding_window_size = layer_size
+    assert (
+        getattr(attention, "kv_cache", None) is not None
+    ), "Attention module must have kv_cache module"
+    kv_cache = attention.kv_cache
+    if isinstance(kv_cache, KVCache):
+        attention.kv_cache = RingKVCache(
+            kv_cache.max_batch_size,
+            sliding_window_size,
+            kv_cache.n_heads,
+            kv_cache.head_dim,
+            kv_cache.enable_dynamic_shape,
+            kv_cache.k_cache.dtype,
+        )
+    elif isinstance(kv_cache, CustomKVCache):
+        attention.kv_cache = CustomRingKVCache.from_custom_kv_cache(
+            kv_cache, layer_size
+        )
+    elif isinstance(kv_cache, QuantizedKVCache):
+        attention.kv_cache = QuantizedRingKVCache.from_quantized_kv_cache(
+            kv_cache, layer_size
+        )
+
+
+def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
+    # This is needed to ensure that custom ops are registered
+    from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
+
+    logging.info(
+        "Replacing kv cache with ring kv cache. This modifies the model in place."
+    )
+    assert len(layer_sizes) == len(
+        module.layers
+    ), f"Length of layer sizes {len(layer_sizes)} must match the number of layers in the module {len(module.layers)}."
+    for i, transformer_block in enumerate(module.layers):
+        sliding_window_size = layer_sizes[i]
+        if sliding_window_size == 0:
+            continue
+        assert (
+            getattr(transformer_block, "attention", None) is not None
+        ), f"Transformer block must have attention module. Transformer block {transformer_block}"
+        attention = transformer_block.attention
+        _replace_kv_cache_with_ring_kv_cache(attention, sliding_window_size)
+    return module
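
As a usage sketch for the new module-swap entry point: the function name and module path come from this diff, while the model construction and layer sizes below are hypothetical.

from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
    replace_kv_cache_with_ring_kv_cache,
)

# Assume `model` is a llama Transformer whose blocks live in `model.layers` and whose
# attention modules already hold a KVCache / CustomKVCache / QuantizedKVCache.
model = build_llama_model()  # hypothetical helper, stands in for the usual export-time setup

# One sliding-window size per layer; a 0 entry leaves that layer's cache untouched.
layer_sizes = [512] * len(model.layers)
model = replace_kv_cache_with_ring_kv_cache(model, layer_sizes)

# Each swapped layer now carries a RingKVCache / CustomRingKVCache / QuantizedRingKVCache,
# allocated at 2x the window size as noted in the class comments above.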

examples/models/llama/tests/TARGETS

+25
@@ -55,8 +55,33 @@ python_unittest(
     srcs = [
         "test_ring_attention.py",
     ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+        "//executorch/kernels/quantized:aot_lib",
+    ],
     deps = [
         "//caffe2:torch",
+        "//executorch/examples/models/llama:export_library",
+        "//executorch/examples/models/llama:llama_transformer",
+        "//executorch/examples/models/llama:custom_kv_cache",
+        "//executorch/examples/models/llama:sdpa",
+    ],
+)
+
+python_unittest(
+    name = "test_replace_kv_cache",
+    srcs = [
+        "test_replace_kv_cache.py",
+    ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+        "//executorch/kernels/quantized:aot_lib",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/examples/models/llama:export_library",
         "//executorch/examples/models/llama:llama_transformer",
+        "//executorch/examples/models/llama:custom_kv_cache",
+        "//executorch/examples/models/llama:sdpa",
     ],
 )
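
The TARGETS change only registers the new unit test; test_replace_kv_cache.py itself is not shown in this commit view. Purely as a hypothetical sketch of the kind of check such a test could make against the swap entry point (the tiny module below is invented for illustration and is not the actual test):

import unittest

import torch.nn as nn

from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
    replace_kv_cache_with_ring_kv_cache,
)


class _FakeBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attention = nn.Module()
        self.attention.kv_cache = None  # a real model would carry a KVCache variant here


class ReplaceKVCacheSmokeTest(unittest.TestCase):
    def test_layer_sizes_length_must_match(self):
        model = nn.Module()
        model.layers = nn.ModuleList([_FakeBlock(), _FakeBlock()])
        # A mismatched layer_sizes list should trip the length assertion in the swap.
        with self.assertRaises(AssertionError):
            replace_kv_cache_with_ring_kv_cache(model, [128])


if __name__ == "__main__":
    unittest.main()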
