diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index 515fd0080fc..63d783c3332 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -150,6 +150,16 @@ def forward(
         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
 
 
+def _create_causal_mask_for_ring_buffer(
+    cache_positions, window_size, start_pos, seq_len
+):
+    pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
+    delta = pos_q - cache_positions
+    attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < window_size)
+    attn_mask = torch.where(attn_mask == True, 0, float("-inf"))  # noqa E712
+    return attn_mask
+
+
 class CacheUpdateStrategy(Enum):
     RING_BUFFER = "RingBuffer"
     INVALID = "Invalid"
@@ -283,12 +293,10 @@ def __init__(
         self.is_ring_buffer = True
 
     def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
-        pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
         cache_positions = self.cache_positions_manager.cache_positions
-        delta = pos_q - cache_positions
-        attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < self.window_size)
-        attn_mask = torch.where(attn_mask == True, 0, float("-inf"))  # noqa E712
-        return attn_mask
+        return _create_causal_mask_for_ring_buffer(
+            cache_positions, self.window_size, start_pos, seq_len
+        )
 
     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
diff --git a/examples/models/llama/source_transformation/custom_kv_cache.py b/examples/models/llama/source_transformation/custom_kv_cache.py
index 4674074f8a5..ffe6732dd53 100644
--- a/examples/models/llama/source_transformation/custom_kv_cache.py
+++ b/examples/models/llama/source_transformation/custom_kv_cache.py
@@ -10,7 +10,12 @@
 
 import torch
 import torch.nn as nn
 
-from executorch.examples.models.llama.attention import KVCache
+from executorch.examples.models.llama.attention import (
+    _create_causal_mask_for_ring_buffer,
+    CachePositionsManager,
+    KVCache,
+    RingKVCache,
+)
 
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
@@ -75,6 +80,7 @@ def __init__(
         self.register_buffer(
             "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int8)
         )
+        self.cache_type = cache_type
 
     def _quantize(self, value):
         (
@@ -209,6 +215,7 @@ def update(self, input_pos, k_val, v_val, indices=None):
         However the storage is [B, S, H, D] so we incur transpose in, transpose out
         This shall be removed by subsequent post-export graph pass
         """
+
         k_val = k_val.transpose(1, 2)
         v_val = v_val.transpose(1, 2)
 
@@ -382,3 +389,185 @@ def _replace_kv_cache_with_custom_kv_cache(module):
         else:
             _replace_kv_cache_with_custom_kv_cache(child)
     return module
+
+
+class QuantizedRingKVCache(QuantizedKVCache):
+    def __init__(
+        self,
+        max_batch_size,
+        max_context_length,
+        n_heads,
+        head_dim,
+        cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
+        use_custom_update_cache_op: bool = False,
+    ):
+        # Look at attention.py for explanation on why max_context_length * 2
+        super().__init__(
+            max_batch_size,
+            max_context_length * 2,
+            n_heads,
+            head_dim,
+            cache_type,
+            use_custom_update_cache_op,
+        )
+        self.cache_positions_manager = CachePositionsManager(self.max_context_length)
+        self.is_ring_buffer = True
+        self.window_size = max_context_length
+
+    def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
+        cache_positions = self.cache_positions_manager.cache_positions
+        return _create_causal_mask_for_ring_buffer(
+            cache_positions, self.window_size, start_pos, seq_len
+        )
+
+    def update(self, input_pos, k_val, v_val):
+        """
+        k_val, v_val: [B, H, S, D]
+        return: [B, H, S, D]
+        However the storage is [B, S, H, D] so we incur transpose in, transpose out
+        This shall be removed by subsequent post-export graph pass
+        """
+        # Need to transpose for two reasons
+        # 1. kv cache is stored as [B, S, H, D]
+        # 2. If seq_len = k_val.size(2), we won't be able to optimize
+        #    away the transpose at the output of the k, v projection
+        seq_len = k_val.transpose(1, 2).size(1)
+        assert seq_len <= self.k_cache.size(
+            1
+        ), f"Update sequence length({seq_len}) for kv cache must not exceed the cache size({self.k_cache.size(1)})"
+        indices = self.cache_positions_manager.calculate_positions_and_update_indices(
+            input_pos, seq_len
+        )
+        indices = indices.unsqueeze(0)
+
+        return super().update(input_pos, k_val, v_val, indices)
+
+    @classmethod
+    def from_quantized_kv_cache(
+        cls,
+        kv_cache,
+        sliding_window_size,
+    ):
+        assert isinstance(
+            kv_cache, QuantizedKVCache
+        ), "For QuantizedRingKVCache expect QuantizedKVCache as input kv_cache"
+        max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
+        return cls(
+            max_batch_size,
+            sliding_window_size,
+            n_heads,
+            head_dim,
+            kv_cache.cache_type,
+            kv_cache.use_custom_update_cache_op,
+        )
+
+
+class CustomRingKVCache(CustomKVCache):
+    def __init__(
+        self,
+        max_batch_size,
+        max_context_length,
+        n_heads,
+        head_dim,
+        dtype=torch.float32,
+    ):
+        # Look at attention.py for explanation on why max_context_length * 2
+        super().__init__(
+            max_batch_size, max_context_length * 2, n_heads, head_dim, dtype
+        )
+        self.cache_positions_manager = CachePositionsManager(self.max_context_length)
+        self.is_ring_buffer = True
+        self.window_size = max_context_length
+
+    def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
+        cache_positions = self.cache_positions_manager.cache_positions
+        return _create_causal_mask_for_ring_buffer(
+            cache_positions, self.window_size, start_pos, seq_len
+        )
+
+    def update(self, input_pos, k_val, v_val):
+        """
+        k_val, v_val: [B, H, S, D]
+        return: [B, H, S, D]
+        However the storage is [B, S, H, D] so we incur transpose in, transpose out
+        This shall be removed by subsequent post-export graph pass
+        """
+        # Need to transpose for two reasons
+        # 1. kv cache is stored as [B, S, H, D]
+        # 2. If seq_len = k_val.size(2), we won't be able to optimize
+        #    away the transpose at the output of the k, v projection
+        seq_len = k_val.transpose(1, 2).size(1)
+        assert seq_len <= self.k_cache.size(
+            1
+        ), f"Update sequence length({seq_len}) for kv cache must not exceed the cache size({self.k_cache.size(1)})"
+        indices = self.cache_positions_manager.calculate_positions_and_update_indices(
+            input_pos, seq_len
+        )
+        indices = indices.unsqueeze(0)
+
+        return super().update(input_pos, k_val, v_val, indices)
+
+    @classmethod
+    def from_custom_kv_cache(
+        cls,
+        kv_cache,
+        sliding_window_size,
+    ):
+        max_batch_size, n_heads, _, head_dim = kv_cache.k_cache.shape
+        if isinstance(kv_cache, CustomKVCache):
+            # If replacing custom kv cache, then the shape is [B, S, H, D]
+            max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
+        return cls(
+            max_batch_size,
+            sliding_window_size,
+            n_heads,
+            head_dim,
+            dtype=kv_cache.k_cache.dtype,
+        )
+
+
+def _replace_kv_cache_with_ring_kv_cache(attention, layer_size):
+    sliding_window_size = layer_size
+    assert (
+        getattr(attention, "kv_cache", None) is not None
+    ), "Attention module must have kv_cache module"
+    kv_cache = attention.kv_cache
+    if isinstance(kv_cache, KVCache):
+        attention.kv_cache = RingKVCache(
+            kv_cache.max_batch_size,
+            sliding_window_size,
+            kv_cache.n_heads,
+            kv_cache.head_dim,
+            kv_cache.enable_dynamic_shape,
+            kv_cache.k_cache.dtype,
+        )
+    elif isinstance(kv_cache, CustomKVCache):
+        attention.kv_cache = CustomRingKVCache.from_custom_kv_cache(
+            kv_cache, layer_size
+        )
+    elif isinstance(kv_cache, QuantizedKVCache):
+        attention.kv_cache = QuantizedRingKVCache.from_quantized_kv_cache(
+            kv_cache, layer_size
+        )
+
+
+def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
+    # This is needed to ensure that custom ops are registered
+    from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
+
+    logging.info(
+        "Replacing kv cache with ring kv cache. This modifies the model in place."
+    )
+    assert len(layer_sizes) == len(
+        module.layers
+    ), f"Length of layer sizes {len(layer_sizes)} must match the number of layers in the module {len(module.layers)}."
+    for i, transformer_block in enumerate(module.layers):
+        sliding_window_size = layer_sizes[i]
+        if sliding_window_size == 0:
+            continue
+        assert (
+            getattr(transformer_block, "attention", None) is not None
+        ), f"Transformer block must have attention module. Transformer block {transformer_block}"
+        attention = transformer_block.attention
+        _replace_kv_cache_with_ring_kv_cache(attention, sliding_window_size)
+    return module
diff --git a/examples/models/llama/tests/TARGETS b/examples/models/llama/tests/TARGETS
index 0d52cfa19d3..40ab6653c60 100644
--- a/examples/models/llama/tests/TARGETS
+++ b/examples/models/llama/tests/TARGETS
@@ -55,8 +55,33 @@ python_unittest(
     srcs = [
         "test_ring_attention.py",
     ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+        "//executorch/kernels/quantized:aot_lib",
+    ],
     deps = [
         "//caffe2:torch",
+        "//executorch/examples/models/llama:export_library",
+        "//executorch/examples/models/llama:llama_transformer",
+        "//executorch/examples/models/llama:custom_kv_cache",
+        "//executorch/examples/models/llama:sdpa",
+    ],
+)
+
+python_unittest(
+    name = "test_replace_kv_cache",
+    srcs = [
+        "test_replace_kv_cache.py",
+    ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+        "//executorch/kernels/quantized:aot_lib",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/examples/models/llama:export_library",
         "//executorch/examples/models/llama:llama_transformer",
+        "//executorch/examples/models/llama:custom_kv_cache",
+        "//executorch/examples/models/llama:sdpa",
     ],
 )
diff --git a/examples/models/llama/tests/test_replace_kv_cache.py b/examples/models/llama/tests/test_replace_kv_cache.py
new file mode 100644
index 00000000000..8d7171633b2
--- /dev/null
+++ b/examples/models/llama/tests/test_replace_kv_cache.py
@@ -0,0 +1,158 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+from typing import List
+
+import torch.nn as nn
+
+from executorch.examples.models.llama.attention import (
+    Attention,
+    AttentionMHA,
+    KVCache,
+    RingKVCache,
+    Rope,
+)
+from executorch.examples.models.llama.model_args import ModelArgs
+from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
+    CustomKVCache,
+    CustomRingKVCache,
+    QuantizedKVCache,
+    QuantizedRingKVCache,
+    replace_kv_cache_with_custom_kv_cache,
+    replace_kv_cache_with_quantized_kv_cache,
+    replace_kv_cache_with_ring_kv_cache,
+)
+
+
+class MockTransformerBlock(nn.Module):
+    def __init__(self, attention: Attention):
+        super().__init__()
+        self.attention = attention
+
+
+class TestReplaceKVCache(unittest.TestCase):
+    def setUp(self):
+        # Common parameters for creating attention modules
+        self.batch_size = 2
+        self.seq_len = 10
+        self.dim = 32
+        self.n_heads = 4
+        self.n_kv_heads = 2
+        self.head_dim = 8
+        self.max_context_len = 16
+        self.enable_dynamic_shape = True
+
+        # Create model args
+        self.args = ModelArgs(
+            dim=self.dim,
+            n_heads=self.n_heads,
+            n_kv_heads=self.n_kv_heads,
+            head_dim=self.head_dim,
+            max_batch_size=self.batch_size,
+            max_context_len=self.max_context_len,
+            use_kv_cache=True,
+            enable_dynamic_shape=self.enable_dynamic_shape,
+        )
+
+        # Create a rope instance
+        self.rope = Rope(self.args)
+
+    def _create_attention_with_kv_cache(self) -> Attention:
+        """Create an attention module with KVCache."""
+        return AttentionMHA(self.args, layer_id=0, rope=self.rope)
+
+    def _create_mock_model(self, attention_modules: List[Attention]) -> nn.Module:
+        """Create a mock model with transformer blocks containing the given attention modules."""
+        model = nn.Module()
+        model.layers = nn.ModuleList(
+            [MockTransformerBlock(attention) for attention in attention_modules]
+        )
+        return model
+
+    def test_replace_kv_cache_with_ring_kv_cache(self):
+        """Test replacing KVCache with RingKVCache."""
+        # Create a model with KVCache
+        attention = self._create_attention_with_kv_cache()
+        model = self._create_mock_model([attention])
+
+        # Verify that the model has KVCache
+        self.assertIsInstance(model.layers[0].attention.kv_cache, KVCache)
+        self.assertNotIsInstance(model.layers[0].attention.kv_cache, RingKVCache)
+
+        # Replace KVCache with RingKVCache
+        layer_sizes = [8]  # Sliding window size for each layer
+        replace_kv_cache_with_ring_kv_cache(model, layer_sizes)
+
+        # Verify that KVCache has been replaced with RingKVCache
+        self.assertIsInstance(model.layers[0].attention.kv_cache, RingKVCache)
+
+        # Verify that the sliding window size is set correctly
+        self.assertEqual(model.layers[0].attention.kv_cache.window_size, layer_sizes[0])
+
+    def test_replace_custom_kv_cache_with_custom_ring_kv_cache(self):
+        """Test replacing CustomKVCache with CustomRingKVCache."""
+        # Create a model with KVCache
+        attention = self._create_attention_with_kv_cache()
+        model = self._create_mock_model([attention])
+
+        # Replace KVCache with CustomKVCache
+        replace_kv_cache_with_custom_kv_cache(model)
+
+        # Verify that the model has CustomKVCache
+        self.assertIsInstance(model.layers[0].attention.kv_cache, CustomKVCache)
+        self.assertNotIsInstance(model.layers[0].attention.kv_cache, CustomRingKVCache)
+
+        # Replace CustomKVCache with CustomRingKVCache
+        layer_sizes = [8]  # Sliding window size for each layer
+        replace_kv_cache_with_ring_kv_cache(model, layer_sizes)
+
+        # Verify that CustomKVCache has been replaced with CustomRingKVCache
+        self.assertIsInstance(model.layers[0].attention.kv_cache, CustomRingKVCache)
+
+    def test_replace_quantized_kv_cache_with_quantized_ring_kv_cache(self):
+        """Test replacing QuantizedKVCache with QuantizedRingKVCache."""
+        # Create a model with KVCache
+        attention = self._create_attention_with_kv_cache()
+        model = self._create_mock_model([attention])
+
+        # Replace KVCache with QuantizedKVCache
+        replace_kv_cache_with_quantized_kv_cache(model)
+
+        # Verify that the model has QuantizedKVCache
+        self.assertIsInstance(model.layers[0].attention.kv_cache, QuantizedKVCache)
+        self.assertNotIsInstance(
+            model.layers[0].attention.kv_cache, QuantizedRingKVCache
+        )
+
+        # Replace QuantizedKVCache with QuantizedRingKVCache
+        layer_sizes = [8]  # Sliding window size for each layer
+        replace_kv_cache_with_ring_kv_cache(model, layer_sizes)
+
+        # Verify that QuantizedKVCache has been replaced with QuantizedRingKVCache
+        self.assertIsInstance(model.layers[0].attention.kv_cache, QuantizedRingKVCache)
+
+    def test_multiple_layers_with_different_window_sizes(self):
+        """Test replacing KV caches in multiple layers with different window sizes."""
+        # Create a model with multiple layers
+        attention1 = self._create_attention_with_kv_cache()
+        attention2 = self._create_attention_with_kv_cache()
+        attention3 = self._create_attention_with_kv_cache()
+        model = self._create_mock_model([attention1, attention2, attention3])
+
+        # Replace KVCache with RingKVCache with different window sizes
+        layer_sizes = [4, 8, 16]  # Different sliding window sizes for each layer
+        replace_kv_cache_with_ring_kv_cache(model, layer_sizes)
+
+        # Verify that each layer has the correct window size
+        self.assertIsInstance(model.layers[0].attention.kv_cache, RingKVCache)
+        self.assertEqual(model.layers[0].attention.kv_cache.window_size, layer_sizes[0])
+
+        self.assertIsInstance(model.layers[1].attention.kv_cache, RingKVCache)
+        self.assertEqual(model.layers[1].attention.kv_cache.window_size, layer_sizes[1])
+
+        self.assertIsInstance(model.layers[2].attention.kv_cache, RingKVCache)
+        self.assertEqual(model.layers[2].attention.kv_cache.window_size, layer_sizes[2])
diff --git a/examples/models/llama/tests/test_ring_attention.py b/examples/models/llama/tests/test_ring_attention.py
index 064be7f04e0..df0d0733033 100644
--- a/examples/models/llama/tests/test_ring_attention.py
+++ b/examples/models/llama/tests/test_ring_attention.py
@@ -6,14 +6,29 @@
 
 import copy
 import unittest
+from enum import Enum
 
 import torch
 from executorch.examples.models.llama.attention import AttentionMHA, RingKVCache
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope
+from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
+    CustomKVCache,
+    CustomRingKVCache,
+    QuantizedKVCache,
+    QuantizedRingKVCache,
+    replace_kv_cache_with_custom_kv_cache,
+    replace_kv_cache_with_quantized_kv_cache,
+)
 from torch.nn.attention import SDPBackend
 
 
+class KVCacheType(Enum):
+    REGULAR = "regular"
+    QUANTIZED = "quantized"
+    CUSTOM = "custom"
+
+
 class TestRingAttention(unittest.TestCase):
     def setUp(self):
         # Common test parameters
@@ -28,7 +43,9 @@ def setUp(self):
         self.dtype = torch.float32
         self.device = "cpu"
 
-    def _create_baseline_attention(self, seq_len: int):
+    def _create_baseline_attention(
+        self, seq_len: int, kv_cache_type: KVCacheType = KVCacheType.REGULAR
+    ):
         """Create baseline attention with regular KV cache."""
         # Create model args
         self.args = ModelArgs(
@@ -50,24 +67,54 @@ def _create_baseline_attention(self, seq_len: int):
                 seq_len, self.max_context_len, self.sliding_window
             )
 
-        return attention
-
-    def _create_ring_attention(self, attention):
+        # Replace the KV cache with the specified type
+        if kv_cache_type == KVCacheType.QUANTIZED:
+            # Create a copy to avoid modifying the original attention
+            attention_copy = copy.deepcopy(attention)
+            # Replace KVCache with QuantizedKVCache
+            replace_kv_cache_with_quantized_kv_cache(attention_copy)
+            return attention_copy
+        elif kv_cache_type == KVCacheType.CUSTOM:
+            # Create a copy to avoid modifying the original attention
+            attention_copy = copy.deepcopy(attention)
+            # Replace KVCache with CustomKVCache
+            replace_kv_cache_with_custom_kv_cache(attention_copy)
+            return attention_copy
+        else:
+            return attention
+
+    def _create_ring_attention(
+        self, attention, kv_cache_type: KVCacheType = KVCacheType.REGULAR
+    ):
         """Create attention with ring buffer KV cache."""
         assert self.sliding_window is not None
         # Create RoPE instance
         self.rope = Rope(self.args)
 
         baseline_attention = copy.deepcopy(attention)
-        # Replace the KV cache with a ring buffer KV cache
-        baseline_attention.kv_cache = RingKVCache(
-            self.args.max_batch_size,
-            self.sliding_window,
-            self.n_kv_heads,
-            self.head_dim,
-            self.args.enable_dynamic_shape,
-            self.dtype,
-        )
+        # Replace the KV cache with a ring buffer KV cache based on the type
+        if isinstance(baseline_attention.kv_cache, QuantizedKVCache):
+            # Replace QuantizedKVCache with QuantizedRingKVCache
+            baseline_attention.kv_cache = QuantizedRingKVCache.from_quantized_kv_cache(
+                baseline_attention.kv_cache,
+                self.sliding_window,
+            )
+        elif isinstance(baseline_attention.kv_cache, CustomKVCache):
+            # Replace CustomKVCache with CustomRingKVCache
+            baseline_attention.kv_cache = CustomRingKVCache.from_custom_kv_cache(
+                baseline_attention.kv_cache,
+                self.sliding_window,
+            )
+        else:
+            # Replace regular KVCache with RingKVCache
+            baseline_attention.kv_cache = RingKVCache(
+                self.args.max_batch_size,
+                self.sliding_window,
+                self.n_kv_heads,
+                self.head_dim,
+                self.args.enable_dynamic_shape,
+                self.dtype,
+            )
         return baseline_attention
 
     def _create_sliding_window_mask(self, seq_len, context_len, window_size):
@@ -80,12 +127,20 @@ def _create_sliding_window_mask(self, seq_len, context_len, window_size):
             mask[i, start_idx : pos + 1] = 0
         return mask
 
-    def test_single_token_processing(self):
+    def _run_test_with_kv_cache_type(self, test_func, kv_cache_type: KVCacheType):
+        """Run a test with the specified KV cache type."""
+        original_test_name = test_func.__name__
+        print(f"\nRunning {original_test_name} with {kv_cache_type.value} KV cache")
+        test_func(kv_cache_type)
+
+    def test_single_token_processing(
+        self, kv_cache_type: KVCacheType = KVCacheType.REGULAR
+    ):
         """Test that ring buffer and baseline produce the same output for single token processing."""
         seq_len = 10
         self.sliding_window = 4
-        baseline_attn = self._create_baseline_attention(seq_len)
-        ring_attn = self._create_ring_attention(baseline_attn)
+        baseline_attn = self._create_baseline_attention(seq_len, kv_cache_type)
+        ring_attn = self._create_ring_attention(baseline_attn, kv_cache_type)
 
         # Process tokens one by one
         with torch.nn.attention.sdpa_kernel(
@@ -113,17 +168,31 @@ def test_single_token_processing(self):
                 f"Outputs differ at position {pos}",
             )
 
-    def test_sliding_window_attention(self):
+    def test_single_token_processing_quantized(self):
+        """Test single token processing with QuantizedKVCache."""
+        self._run_test_with_kv_cache_type(
+            self.test_single_token_processing, KVCacheType.QUANTIZED
+        )
+
+    def test_single_token_processing_custom(self):
+        """Test single token processing with CustomKVCache."""
+        self._run_test_with_kv_cache_type(
+            self.test_single_token_processing, KVCacheType.CUSTOM
+        )
+
+    def test_sliding_window_attention(
+        self, kv_cache_type: KVCacheType = KVCacheType.REGULAR
+    ):
         """Test that ring buffer with sliding window size produces the same output
         as baseline with sliding window mask."""
         self.sliding_window = 4
         self.max_context_len = 16
         seq_len = 10
         # Create baseline attention with full context length
-        baseline_attn = self._create_baseline_attention(seq_len)
+        baseline_attn = self._create_baseline_attention(seq_len, kv_cache_type)
 
         # Create ring attention with sliding window size
-        ring_attn = self._create_ring_attention(baseline_attn)
+        ring_attn = self._create_ring_attention(baseline_attn, kv_cache_type)
 
         # Process tokens one by one
         with torch.nn.attention.sdpa_kernel(
@@ -150,16 +219,32 @@ def test_sliding_window_attention(self):
                 f"Outputs differ at position {pos}",
             )
 
-    def test_ring_buffer_wrapping(self):
+    def test_sliding_window_attention_quantized(self):
+        """Test sliding window attention with QuantizedKVCache."""
+        self._run_test_with_kv_cache_type(
+            self.test_sliding_window_attention, KVCacheType.QUANTIZED
+        )
+
+    def test_sliding_window_attention_custom(self):
+        """Test sliding window attention with CustomKVCache."""
+        self._run_test_with_kv_cache_type(
+            self.test_sliding_window_attention, KVCacheType.CUSTOM
+        )
+
+    def test_ring_buffer_wrapping(
+        self, kv_cache_type: KVCacheType = KVCacheType.REGULAR
+    ):
         """Test that ring buffer correctly wraps around and maintains correct attention patterns."""
         self.sliding_window = 3
         self.max_context_len = 15
 
         # Create baseline attention with full context length
-        baseline_attn = self._create_baseline_attention(self.max_context_len)
+        baseline_attn = self._create_baseline_attention(
+            self.max_context_len, kv_cache_type
+        )
 
         # Create ring attention with sliding window size
-        ring_attn = self._create_ring_attention(baseline_attn)
+        ring_attn = self._create_ring_attention(baseline_attn, kv_cache_type)
 
         # Process enough tokens to cause wrapping
         seq_len = 1
@@ -198,7 +283,21 @@ def test_ring_buffer_wrapping(self):
                 f"Expected positions {expected_positions}, got {cache_positions}",
             )
 
-    def test_large_context_with_sliding_window(self):
+    def test_ring_buffer_wrapping_quantized(self):
+        """Test ring buffer wrapping with QuantizedKVCache."""
+        self._run_test_with_kv_cache_type(
+            self.test_ring_buffer_wrapping, KVCacheType.QUANTIZED
+        )
+
+    def test_ring_buffer_wrapping_custom(self):
+        """Test ring buffer wrapping with CustomKVCache."""
+        self._run_test_with_kv_cache_type(
+            self.test_ring_buffer_wrapping, KVCacheType.CUSTOM
+        )
+
+    def test_large_context_with_sliding_window(
+        self, kv_cache_type: KVCacheType = KVCacheType.REGULAR
+    ):
         """Test with a large context length and compare baseline with sliding window to ring buffer."""
         # Use a larger context length and sliding window for this test
         self.max_context_len = 64
@@ -207,10 +306,10 @@ def test_large_context_with_sliding_window(self):
         token_lens = [8, 1, 3, 2, 1, 1, 1, 1, 7, 1, 5, 1, 1, 1, 4, 1, 1, 2, 1, 1]
         seq_len = sum(token_lens)
         # Create baseline attention with full context length
-        baseline_attn = self._create_baseline_attention(seq_len)
+        baseline_attn = self._create_baseline_attention(seq_len, kv_cache_type)
 
         # Create ring attention with sliding window size
-        ring_attn = self._create_ring_attention(baseline_attn)
+        ring_attn = self._create_ring_attention(baseline_attn, kv_cache_type)
 
         pos = 0
         with torch.nn.attention.sdpa_kernel(
@@ -239,3 +338,15 @@ def test_large_context_with_sliding_window(self):
                 f"Outputs differ at position {pos} with max difference {(baseline_out - ring_out).abs().max()}",
             )
             pos += token_len
+
+    def test_large_context_with_sliding_window_quantized(self):
+        """Test large context with sliding window with QuantizedKVCache."""
+        self._run_test_with_kv_cache_type(
+            self.test_large_context_with_sliding_window, KVCacheType.QUANTIZED
+        )
+
+    def test_large_context_with_sliding_window_custom(self):
+        """Test large context with sliding window with CustomKVCache."""
+        self._run_test_with_kv_cache_type(
+            self.test_large_context_with_sliding_window, KVCacheType.CUSTOM
+        )
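
Usage note (not part of the patch): the sketch below shows one way the new pass could be driven end to end. It mirrors the MockTransformerBlock pattern from test_replace_kv_cache.py, since replace_kv_cache_with_ring_kv_cache only needs a module that exposes layers[i].attention.kv_cache. The Block class and the specific ModelArgs values are illustrative assumptions rather than APIs introduced by this change, and the custom-ops library must be importable (the new TARGETS entries preload it for the tests).

    import torch.nn as nn

    from executorch.examples.models.llama.attention import AttentionMHA, Rope
    from executorch.examples.models.llama.model_args import ModelArgs
    from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
        replace_kv_cache_with_custom_kv_cache,
        replace_kv_cache_with_ring_kv_cache,
    )


    class Block(nn.Module):
        # Hypothetical stand-in for a transformer block; only .attention is used
        # by the replacement pass.
        def __init__(self, attention):
            super().__init__()
            self.attention = attention


    # Illustrative model args; with use_kv_cache=True, AttentionMHA builds a
    # plain KVCache per layer.
    args = ModelArgs(
        dim=32,
        n_heads=4,
        n_kv_heads=2,
        head_dim=8,
        max_batch_size=1,
        max_context_len=16,
        use_kv_cache=True,
        enable_dynamic_shape=True,
    )
    rope = Rope(args)
    model = nn.Module()
    model.layers = nn.ModuleList(
        [Block(AttentionMHA(args, layer_id=i, rope=rope)) for i in range(4)]
    )

    # Optional: switch to CustomKVCache first; the ring pass accepts plain,
    # custom, and quantized caches and installs the matching ring variant.
    replace_kv_cache_with_custom_kv_cache(model)

    # One entry per layer: 0 keeps that layer's existing (global) cache,
    # a positive value installs a ring buffer with that sliding window.
    replace_kv_cache_with_ring_kv_cache(model, [4, 0, 8, 0])

Because the ring variants allocate max_context_length * 2 slots (see the comment referenced in attention.py), a layer converted this way holds roughly twice its sliding window rather than the full context length, which is where the memory savings for local-attention layers come from.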