 
 import torch
 import torch.nn as nn
-from executorch.examples.models.llama.attention import KVCache
+from executorch.examples.models.llama.attention import (
+    _create_causal_mask_for_ring_buffer,
+    CachePositionsManager,
+    KVCache,
+    RingKVCache,
+)
 
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
 
@@ -75,6 +80,7 @@ def __init__(
         self.register_buffer(
             "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int8)
         )
+        self.cache_type = cache_type
 
     def _quantize(self, value):
         (
@@ -209,6 +215,7 @@ def update(self, input_pos, k_val, v_val, indices=None):
         However the storage is [B, S, H, D] so we incur transpose in, transpose out
         This shall be removed by subsequent post-export graph pass
         """
+
         k_val = k_val.transpose(1, 2)
         v_val = v_val.transpose(1, 2)
 
@@ -382,3 +389,185 @@ def _replace_kv_cache_with_custom_kv_cache(module):
         else:
             _replace_kv_cache_with_custom_kv_cache(child)
     return module
+
+
+class QuantizedRingKVCache(QuantizedKVCache):
+    def __init__(
+        self,
+        max_batch_size,
+        max_context_length,
+        n_heads,
+        head_dim,
+        cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
+        use_custom_update_cache_op: bool = False,
+    ):
+        # Look at attention.py for explanation on why max_context_length * 2
+        super().__init__(
+            max_batch_size,
+            max_context_length * 2,
+            n_heads,
+            head_dim,
+            cache_type,
+            use_custom_update_cache_op,
+        )
+        self.cache_positions_manager = CachePositionsManager(self.max_context_length)
+        self.is_ring_buffer = True
+        self.window_size = max_context_length
+
+    def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
+        cache_positions = self.cache_positions_manager.cache_positions
+        return _create_causal_mask_for_ring_buffer(
+            cache_positions, self.window_size, start_pos, seq_len
+        )
+
+    def update(self, input_pos, k_val, v_val):
+        """
+        k_val, v_val: [B, H, S, D]
+        return: [B, H, S, D]
+        However the storage is [B, S, H, D] so we incur transpose in, transpose out
+        This shall be removed by subsequent post-export graph pass
+        """
+        # Need to transpose for two reasons:
+        # 1. kv cache is stored as [B, S, H, D]
+        # 2. If seq_len = k_val.size(2), we won't be able to optimize
+        #    away transpose at the output of k, v projection
+        seq_len = k_val.transpose(1, 2).size(1)
+        assert seq_len <= self.k_cache.size(
+            1
+        ), f"Update sequence length({seq_len}) for kv cache must not exceed the cache size({self.k_cache.size(1)})"
+        indices = self.cache_positions_manager.calculate_positions_and_update_indices(
+            input_pos, seq_len
+        )
+        indices = indices.unsqueeze(0)
+
+        return super().update(input_pos, k_val, v_val, indices)
+
+    @classmethod
+    def from_quantized_kv_cache(
+        cls,
+        kv_cache,
+        sliding_window_size,
+    ):
+        assert isinstance(
+            kv_cache, QuantizedKVCache
+        ), "For QuantizedRingKVCache expect QuantizedKVCache as input kv_cache"
+        max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
+        return cls(
+            max_batch_size,
+            sliding_window_size,
+            n_heads,
+            head_dim,
+            kv_cache.cache_type,
+            kv_cache.use_custom_update_cache_op,
+        )
+
+
+class CustomRingKVCache(CustomKVCache):
+    def __init__(
+        self,
+        max_batch_size,
+        max_context_length,
+        n_heads,
+        head_dim,
+        dtype=torch.float32,
+    ):
+        # Look at attention.py for explanation on why max_context_length * 2
+        super().__init__(
+            max_batch_size, max_context_length * 2, n_heads, head_dim, dtype
+        )
+        self.cache_positions_manager = CachePositionsManager(self.max_context_length)
+        self.is_ring_buffer = True
+        self.window_size = max_context_length
+
+    def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
+        cache_positions = self.cache_positions_manager.cache_positions
+        return _create_causal_mask_for_ring_buffer(
+            cache_positions, self.window_size, start_pos, seq_len
+        )
+
+    def update(self, input_pos, k_val, v_val):
+        """
+        k_val, v_val: [B, H, S, D]
+        return: [B, H, S, D]
+        However the storage is [B, S, H, D] so we incur transpose in, transpose out
+        This shall be removed by subsequent post-export graph pass
+        """
+        # Need to transpose for two reasons:
+        # 1. kv cache is stored as [B, S, H, D]
+        # 2. If seq_len = k_val.size(2), we won't be able to optimize
+        #    away transpose at the output of k, v projection
+        seq_len = k_val.transpose(1, 2).size(1)
+        assert seq_len <= self.k_cache.size(
+            1
+        ), f"Update sequence length({seq_len}) for kv cache must not exceed the cache size({self.k_cache.size(1)})"
+        indices = self.cache_positions_manager.calculate_positions_and_update_indices(
+            input_pos, seq_len
+        )
+        indices = indices.unsqueeze(0)
+
+        return super().update(input_pos, k_val, v_val, indices)
+
+    @classmethod
+    def from_custom_kv_cache(
+        cls,
+        kv_cache,
+        sliding_window_size,
+    ):
+        max_batch_size, n_heads, _, head_dim = kv_cache.k_cache.shape
+        if isinstance(kv_cache, CustomKVCache):
+            # If replacing custom kv cache, then the shape is [B, S, H, D]
+            max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
+        return cls(
+            max_batch_size,
+            sliding_window_size,
+            n_heads,
+            head_dim,
+            dtype=kv_cache.k_cache.dtype,
+        )
+
+
+def _replace_kv_cache_with_ring_kv_cache(attention, layer_size):
+    sliding_window_size = layer_size
+    assert (
+        getattr(attention, "kv_cache", None) is not None
+    ), "Attention module must have kv_cache module"
+    kv_cache = attention.kv_cache
+    if isinstance(kv_cache, KVCache):
+        attention.kv_cache = RingKVCache(
+            kv_cache.max_batch_size,
+            sliding_window_size,
+            kv_cache.n_heads,
+            kv_cache.head_dim,
+            kv_cache.enable_dynamic_shape,
+            kv_cache.k_cache.dtype,
+        )
+    elif isinstance(kv_cache, CustomKVCache):
+        attention.kv_cache = CustomRingKVCache.from_custom_kv_cache(
+            kv_cache, layer_size
+        )
+    elif isinstance(kv_cache, QuantizedKVCache):
+        attention.kv_cache = QuantizedRingKVCache.from_quantized_kv_cache(
+            kv_cache, layer_size
+        )
+
+
+def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
+    # This is needed to ensure that custom ops are registered
+    from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
+
+    logging.info(
+        "Replacing kv cache with ring kv cache. This modifies the model in place."
+    )
+    assert len(layer_sizes) == len(
+        module.layers
+    ), f"Length of layer sizes {len(layer_sizes)} must match the number of layers in the module {len(module.layers)}."
+    for i, transformer_block in enumerate(module.layers):
+        sliding_window_size = layer_sizes[i]
+        if sliding_window_size == 0:
+            continue
+        assert (
+            getattr(transformer_block, "attention", None) is not None
+        ), f"Transformer block must have attention module. Transformer block {transformer_block}"
+        attention = transformer_block.attention
+        _replace_kv_cache_with_ring_kv_cache(attention, sliding_window_size)
+    return module
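
Usage sketch (not part of the diff above): a minimal, hypothetical example of wiring up the new replace_kv_cache_with_ring_kv_cache helper. It assumes a Llama-style module whose transformer blocks live in module.layers and whose attention modules already carry a KVCache, CustomKVCache, or QuantizedKVCache; build_llama_model and the layer_sizes pattern are placeholders, not part of this change.

# Hypothetical sketch: `build_llama_model` stands in for however the model is
# constructed; only `replace_kv_cache_with_ring_kv_cache` comes from the diff.
model = build_llama_model()  # exposes model.layers[i].attention.kv_cache

# One sliding-window size per transformer block; 0 leaves that layer's
# existing global KV cache untouched (the helper skips zero entries).
layer_sizes = [512 if i % 2 == 0 else 0 for i in range(len(model.layers))]

# Replaces each eligible layer's cache with a ring-buffer variant in place
# and returns the same module.
model = replace_kv_cache_with_ring_kv_cache(model, layer_sizes)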