diff --git a/.azure/gpu-test.yml b/.azure/gpu-test.yml index 26409cf958..25ca5436c7 100644 --- a/.azure/gpu-test.yml +++ b/.azure/gpu-test.yml @@ -34,6 +34,7 @@ jobs: CUDA_VERSION: "12.6.3" TORCH_VERSION: "2.7.1" CUDNN_FRONTEND_VERSION: "1.10.0" + CUBLAS_WORKSPACE_CONFIG: ":4096:8" container: # image: "pytorchlightning/pytorch_lightning:base-cuda-py$(PYTHON_VERSION)-torch$(TORCH_VERSION)-cuda$(CUDA_VERSION)" # pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-cudnn-fe1.5.0-py3.10-pt_main-dev diff --git a/extensions/xla/finetune/adapter.py b/extensions/xla/finetune/adapter.py index 051baea75f..af27f4c1cd 100644 --- a/extensions/xla/finetune/adapter.py +++ b/extensions/xla/finetune/adapter.py @@ -241,9 +241,9 @@ def validate( encoded = tokenizer.encode(prompt, device=fabric.device) with fabric.init_tensor(): # do not set `max_seq_length=max_returned_token` because memory is not a concern here - model.set_kv_cache(batch_size=1) + model.set_kv_caches(batch_size=1) output = generate(model, encoded, max_returned_tokens=len(encoded) + eval_max_new_tokens, temperature=0.8) - model.clear_kv_cache() + model.clear_kv_caches() output = tokenizer.decode(output) rank_print(fabric, output) diff --git a/extensions/xla/generate/adapter.py b/extensions/xla/generate/adapter.py index e8358349ed..32559bf65d 100644 --- a/extensions/xla/generate/adapter.py +++ b/extensions/xla/generate/adapter.py @@ -102,7 +102,7 @@ def main( # set the max_seq_length to limit the memory usage to what we need model.max_seq_length = max_returned_tokens # enable the kv cache - model.set_kv_cache(batch_size=1) + model.set_kv_caches(batch_size=1) t0 = time.perf_counter() y = generate( diff --git a/extensions/xla/generate/base.py b/extensions/xla/generate/base.py index fa09c80c9a..eea9205b80 100644 --- a/extensions/xla/generate/base.py +++ b/extensions/xla/generate/base.py @@ -164,7 +164,7 @@ def main( for i in range(num_samples): with fabric.init_tensor(): # enable the kv cache - model.set_kv_cache(batch_size=1) + model.set_kv_caches(batch_size=1) t0 = time.perf_counter() y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k) diff --git a/litgpt/adapter.py b/litgpt/adapter.py index 5297df4eb3..084b93008a 100644 --- a/litgpt/adapter.py +++ b/litgpt/adapter.py @@ -15,7 +15,9 @@ import torch.nn as nn from typing_extensions import Self +from litgpt.attention import DefaultKeysAndValues, MultiHeadSelfAttention from litgpt.config import Config as BaseConfig +from litgpt.kvcache.base import KVCache from litgpt.model import GPT as BaseModel from litgpt.model import Block as BaseBlock from litgpt.model import CausalSelfAttention as BaseCausalSelfAttention @@ -29,12 +31,16 @@ class Config(BaseConfig): class GPT(BaseModel): # Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here. 
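For context on the API rename applied throughout the call sites above (`set_kv_cache`/`clear_kv_cache` becoming `set_kv_caches`/`clear_kv_caches`), here is a minimal usage sketch of the updated helpers. It assumes a `GPT` model and tokenizer are already loaded; the keyword arguments mirror the call sites in this diff.

```python
import torch

from litgpt.generate.base import generate


def answer(model, tokenizer, prompt: str, max_new_tokens: int = 100) -> str:
    # Sketch only: `model` is a litgpt GPT instance, `tokenizer` a litgpt Tokenizer.
    encoded = tokenizer.encode(prompt, device=next(model.parameters()).device)
    max_returned_tokens = encoded.size(0) + max_new_tokens
    model.max_seq_length = max_returned_tokens
    # New API: the plural form creates one dense KV cache per transformer block.
    model.set_kv_caches(batch_size=1, max_seq_length=max_returned_tokens)
    try:
        y = generate(model, encoded, max_returned_tokens, temperature=0.8)
    finally:
        # New API: drops all per-block caches again.
        model.clear_kv_caches()
    return tokenizer.decode(y)
```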
- def __init__(self, config: Config) -> None: + def __init__(self, config: Config, **mha_kwargs) -> None: nn.Module.__init__(self) assert config.padded_vocab_size is not None self.config = config - self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias) + self.lm_head = nn.Linear( + config.n_embd, + config.padded_vocab_size, + bias=config.lm_head_bias, + ) self.transformer = nn.ModuleDict( dict( wte=nn.Embedding(config.padded_vocab_size, config.n_embd), @@ -42,8 +48,11 @@ def __init__(self, config: Config) -> None: ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), ) ) - self.mask_cache: Optional[torch.Tensor] = None + self.mha = MultiHeadSelfAttention(config, **mha_kwargs) self.max_seq_length = self.config.block_size + self._start_of_layer_hook = config.start_of_layer_hook + # Have dense KV caches been created by `set_kv_caches`? + self._default_kv_cache = False @classmethod def from_name(cls, name: str, **kwargs: Any) -> Self: @@ -57,18 +66,33 @@ def _init_weights(self, module: nn.Module) -> None: class Block(BaseBlock): - def __init__(self, config: Config, block_idx: int) -> None: - super().__init__(config, block_idx) - self.attn = CausalSelfAttention(config, block_idx) + def __init__( + self, + config: Config, + block_idx: int, + kv_cache: Optional[KVCache] = None, + ) -> None: + super().__init__(config, block_idx, kv_cache) + self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache) class CausalSelfAttention(BaseCausalSelfAttention): """A modification of `litgpt.model.CausalSelfAttention` that adds the attention over the adaption prompt.""" - def __init__(self, config: Config, block_idx: int) -> None: - super().__init__(config, block_idx) - if block_idx >= config.adapter_start_layer: + def __init__( + self, + config: Config, + block_idx: int, + kv_cache: Optional[KVCache] = None, + ) -> None: + super().__init__( + config=config, + block_idx=block_idx, + kv_cache=kv_cache, + ) + self._extend_forward = block_idx >= config.adapter_start_layer + if self._extend_forward: # adapter embedding layer self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd) # gate for adaption @@ -76,37 +100,45 @@ def __init__(self, config: Config, block_idx: int) -> None: # kv cache for inference self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - def scaled_dot_product_attention( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None + def _transform_output( + self, + y: torch.Tensor, + query: torch.Tensor, + mha: MultiHeadSelfAttention, ) -> torch.Tensor: - y = super().scaled_dot_product_attention(q, k, v, mask) - if self.block_idx < self.config.adapter_start_layer: - return y - - aT = self.config.adapter_prompt_length - if self.adapter_kv_cache is not None: - # since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av - # are the same every call - ak, av = self.adapter_kv_cache - else: - prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd) - aqkv = self.qkv(prefix) - q_per_kv = self.config.n_head // self.config.n_query_groups - aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size) - aqkv = aqkv.permute(0, 2, 3, 1, 4) - _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2) - if self.config.n_query_groups != 1: - # for MHA this is a no-op - ak = ak.repeat_interleave(q_per_kv, dim=2) - av = av.repeat_interleave(q_per_kv, dim=2) - ak = ak.view(1, -1, aT, self.config.head_size) # (1, 
nh_ak, aT, hs) - av = av.view(1, -1, aT, self.config.head_size) # (1, nh_av, aT, hs) - self.adapter_kv_cache = (ak, av) - - T = q.size(2) - amask = torch.ones(T, aT, dtype=torch.bool, device=q.device) - ay = super().scaled_dot_product_attention(q, ak, av, amask) - return y + self.gating_factor * ay + if self._extend_forward: + B, T, _ = y.shape + y = y.view(B, T, self.config.n_head, self.config.head_size) + aT = self.config.adapter_prompt_length + if self.adapter_kv_cache is not None: + # since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av + # are the same every call + ak, av = self.adapter_kv_cache + else: + prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd) + aqkv = self.qkv(prefix) + q_per_kv = self.config.n_head // self.config.n_query_groups + aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size) + aqkv = aqkv.permute(0, 2, 3, 1, 4) + _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2) + if self.config.n_query_groups != 1: + # for MHA this is a no-op + ak = ak.repeat_interleave(q_per_kv, dim=2) + av = av.repeat_interleave(q_per_kv, dim=2) + ak = ak.view(1, -1, aT, self.config.head_size) # (1, nh_ak, aT, hs) + av = av.view(1, -1, aT, self.config.head_size) # (1, nh_av, aT, hs) + self.adapter_kv_cache = (ak, av) + + amask = torch.ones(T, aT, dtype=torch.bool, device=query.device) + a_k_and_v = DefaultKeysAndValues(keys=ak, values=av) + ay, _ = mha.scaled_dot_product_attention( + query=query, + k_and_v=a_k_and_v, + mask=amask, + ) + y = (y + self.gating_factor * ay).view(B, T, -1) + + return y def reset_parameters(self) -> None: if hasattr(self, "gating_factor"): diff --git a/litgpt/adapter_v2.py b/litgpt/adapter_v2.py index fb4d12ae08..b8caea8c19 100644 --- a/litgpt/adapter_v2.py +++ b/litgpt/adapter_v2.py @@ -19,6 +19,8 @@ from litgpt.adapter import GPT as BaseModel from litgpt.adapter import CausalSelfAttention as BaseCausalSelfAttention from litgpt.adapter import Config as BaseConfig +from litgpt.attention import MultiHeadSelfAttention +from litgpt.kvcache.base import KVCache from litgpt.model import Block as BaseBlock from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble from litgpt.utils import map_old_state_dict_weights @@ -69,7 +71,11 @@ def __init__(self, config: Config) -> None: assert config.padded_vocab_size is not None self.config = config - self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias) + self.lm_head = AdapterV2Linear( + config.n_embd, + config.padded_vocab_size, + bias=config.lm_head_bias, + ) self.transformer = nn.ModuleDict( dict( wte=nn.Embedding(config.padded_vocab_size, config.n_embd), @@ -77,8 +83,11 @@ def __init__(self, config: Config) -> None: ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), ) ) - self.mask_cache: Optional[torch.Tensor] = None + self.mha = MultiHeadSelfAttention(config) self.max_seq_length = self.config.block_size + self._start_of_layer_hook = config.start_of_layer_hook + # Have dense KV caches been created by `set_kv_cache`? 
+ self._default_kv_cache = False @classmethod def from_name(cls, name: str, **kwargs: Any) -> Self: @@ -98,9 +107,14 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa class Block(BaseBlock): - def __init__(self, config: Config, block_idx: int) -> None: - super().__init__(config, block_idx) - self.attn = CausalSelfAttention(config, block_idx) + def __init__( + self, + config: Config, + block_idx: int, + kv_cache: Optional[KVCache] = None, + ) -> None: + super().__init__(config, block_idx, kv_cache) + self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache) self.mlp = config.mlp_class(config) @@ -108,14 +122,24 @@ class CausalSelfAttention(BaseCausalSelfAttention): """A modification of `litgpt.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class""" # Copy&paste from :class:`model.CausalSelfAttention` - def __init__(self, config: Config, block_idx: int) -> None: - super().__init__(config, block_idx) + def __init__( + self, + config: Config, + block_idx: int, + kv_cache: Optional[KVCache] = None, + ) -> None: + super().__init__(config, block_idx, kv_cache) # key, query, value projections for all heads, but in a batch shape = (config.n_head + 2 * config.n_query_groups) * config.head_size self.qkv = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias) # output projection self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) + @property + def device(self) -> Optional[torch.device]: + w = self.qkv.linear.weight + return None if w is None else w.device + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base and/or legacy checkpoints.""" mapping = { diff --git a/litgpt/api.py b/litgpt/api.py index 32cc196603..9ddcd4a783 100644 --- a/litgpt/api.py +++ b/litgpt/api.py @@ -45,7 +45,6 @@ def __init__( checkpoint_dir: Path = None, fabric: L.Fabric = None, generate_strategy: Optional[Literal["sequential", "tensor_parallel"]] = None, - kv_cache_initialized: bool = False, fixed_kv_cache_size: Union[int, Literal["max_model_supported"], None] = None, ) -> None: super().__init__() @@ -57,7 +56,6 @@ def __init__( self.checkpoint_dir = checkpoint_dir self.fabric = fabric self.generate_strategy = generate_strategy - self.kv_cache_initialized = kv_cache_initialized self.fixed_kv_cache_size = fixed_kv_cache_size self.prev_generated_seq_length = 0 @@ -82,6 +80,9 @@ def state_dict(self, destination=None, prefix="", keep_vars=False): def load_state_dict(self, state_dict, strict=True): return self.model.load_state_dict(state_dict, strict=strict) + def kv_cache_initialized(self) -> bool: + return all(block.attn.kv_cache is not None for block in self.model.transformer.h) + def forward( self, input_ids: torch.Tensor, @@ -249,7 +250,6 @@ def load( checkpoint_dir=checkpoint_dir, fabric=fabric, generate_strategy=None, - kv_cache_initialized=False, fixed_kv_cache_size=False, ) @@ -367,7 +367,6 @@ def distribute( print("Fabric launched", file=sys.stderr) - self.kv_cache_initialized = False if generate_strategy is None: with fabric.init_module(empty_init=(total_devices > 1)): model = GPT(self.config) @@ -383,8 +382,10 @@ def distribute( kv_cache_size = model.max_seq_length else: kv_cache_size = fixed_kv_cache_size - model.set_kv_cache(batch_size=1, max_seq_length=kv_cache_size, device=fabric.device) - self.kv_cache_initialized = True + model.set_kv_caches( + batch_size=1, + max_seq_length=kv_cache_size, + 
) self.fixed_kv_cache_size = fixed_kv_cache_size elif generate_strategy in ("sequential", "tensor_parallel"): @@ -437,7 +438,7 @@ def distribute( # the rope cache which is on meta device model.cos, model.sin = model.rope_cache() # enable the kv cache - model.set_kv_cache(batch_size=1) + model.set_kv_caches(batch_size=1) model.eval() model = fabric.to_device(model) @@ -448,8 +449,6 @@ def distribute( if fabric.global_rank == 0: pbar.close() - self.kv_cache_initialized = True - else: raise ValueError(f"Unsupported generate_strategy: {generate_strategy}") @@ -508,20 +507,23 @@ def generate( prompt_length = input_ids.size(0) max_returned_tokens = prompt_length + max_new_tokens - if not self.kv_cache_initialized: - if self.fabric is not None: - device = self.fabric.device - else: - device = self.preprocessor.device - self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=device) - self.kv_cache_initialized = True + if self.fabric is not None: + device = self.fabric.device + else: + device = self.preprocessor.device + if not self.kv_cache_initialized(): + self.model.set_kv_caches( + batch_size=1, + max_seq_length=max_returned_tokens, + ) # Dynamically grow the kv cache size if necessary if not self.fixed_kv_cache_size and self.prev_generated_seq_length < max_returned_tokens: - tmp_device = self.model.mask_cache.device - self.model.clear_kv_cache() - self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=tmp_device) - + self.model.clear_kv_caches() + self.model.set_kv_caches( + batch_size=1, + max_seq_length=max_returned_tokens, + ) else: for block in self.model.transformer.h: block.attn.kv_cache.reset_parameters() diff --git a/litgpt/attention.py b/litgpt/attention.py new file mode 100644 index 0000000000..708584fa12 --- /dev/null +++ b/litgpt/attention.py @@ -0,0 +1,337 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import math +from typing import List, Optional, Tuple, Union + +import torch +from torch.nn import functional as F +from torch.nn.attention import SDPBackend, sdpa_kernel + +from litgpt.attention_utils import ( + attention_compute_scores, + attention_compute_weighted_values, + build_mask_cache, + build_mask_slice, + filter_sdpa_kernels, +) +from litgpt.config import Config + +# Currently, `torch.nn.functional.scaled_dot_product_attention` does not +# properly support the case `enabla_gqa=True` (i.e., keys and values have +# less heads than queries). In this case, it is best to extend keys and +# values, which requires extra memory, but allows for efficient kernels to +# be used. +# Once PyTorch supports `enabla_gqa=True` properly at least with some fused +# kernels (such as flash attention), this flag can be switched to `False`. +FUSED_SDPA_DOES_NOT_SUPPORT_ENABLE_GQA = True + + +class KeysAndValues: + """ + Object passed to :meth:`MultiHeadSelfAttention.__call__`. Allows to access + keys or values, but (in general) not both at the same time. Implementations + may use the same buffer to return them in the methods below. + + However, if :meth:`both_in_parallel` returns `True`, the tensors returned + by :meth:`keys` and :meth:`values` may be used in parallel, since they are + supported by separate buffers. 
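To make the contract described in this docstring concrete, a hypothetical single-buffer implementation could look as follows. It materializes keys and values into the same scratch tensor on demand, which is exactly the situation where `both_in_parallel()` must stay `False`. This class is illustrative only and is not part of the patch.

```python
import torch

from litgpt.attention import KeysAndValues


class SingleBufferKeysAndValues(KeysAndValues):
    """Illustrative only: keys and values share one scratch buffer."""

    def __init__(self, packed: torch.Tensor, scratch: torch.Tensor):
        # `packed` stacks K and V along dim 0:
        # shape (2, eff_batch_size, n_query_groups, T, head_size)
        self._packed = packed
        self._scratch = scratch

    def keys(self) -> torch.Tensor:
        # Overwrites the scratch buffer, invalidating a previous `values()` result.
        return self._scratch.copy_(self._packed[0])

    def values(self) -> torch.Tensor:
        return self._scratch.copy_(self._packed[1])

    def both_in_parallel(self) -> bool:
        # The same buffer backs both methods, so a caller must not use the
        # tensors returned by `keys()` and `values()` at the same time.
        return False
```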
+ + """ + + def keys(self) -> torch.Tensor: + """ + Returns: + keys tensor, shape `(eff_batch_size, n_query_groups, T, head_size)`, + where `T <= cache_length` is the current cache length) + + """ + raise NotImplementedError() + + def values(self) -> torch.Tensor: + """ + Returns: + values tensor, shape `(eff_batch_size, n_query_groups, T, head_size)`, + where `T <= cache_length` is the current cache length) + + """ + raise NotImplementedError() + + def both_in_parallel(self) -> bool: + """ + Returns: + Can use both `keys` and `values` in parallel? Otherwise, can only + use one of them at the same time + """ + return False + + +class DefaultKeysAndValues(KeysAndValues): + def __init__(self, keys: torch.Tensor, values: torch.Tensor): + # The final dimension of K and V can be different (in general) + assert keys.shape[:-1] == values.shape[:-1] and keys.ndim == 4, (keys.shape, values.shape) + self._keys = keys + self._values = values + + def keys(self) -> torch.Tensor: + return self._keys + + def values(self) -> torch.Tensor: + return self._values + + def both_in_parallel(self) -> bool: + """ + Keys and values are supported by different buffers, so they can be + used at the same time. + + """ + return True + + +class MultiHeadSelfAttention: + """ + Maintains code for the inner part of multi-head self-attention which is not + parameterized. This is used both by :class:`CausalSelfAttention` and by the + default KV cache implementation :class:`DefaultKVCache`. + + Kernels to be used for SDPA can be restricted by `sdpa_kernels`. By + default, the choice is down to the method itself. If GPU memory is a + concern (e.g., if MHA is used in training mode, to compute gradients), + `sdpa_kernels=SDPBackend.EFFICIENT_ATTENTION` is recommended. + + If `sdpa_kernels` is used, their availabilities are checked upon the + first call, and a warning is printed if some are not available. + + If `use_eager_sdpa_always=True`, + `torch.nn.functional.scaled_dot_product_attention` is never used. + + """ + + def __init__( + self, + config: Config, + sdpa_kernels: Optional[Union[SDPBackend, List[SDPBackend]]] = None, + use_eager_sdpa_always: bool = False, + ) -> None: + self.config = config + self._sdpa_kernels = sdpa_kernels + self._sdpa_kernels_filtered = False + self.use_eager_sdpa_always = use_eager_sdpa_always + + @property + def sdpa_kernels(self) -> Union[SDPBackend, List[SDPBackend]]: + return self._sdpa_kernels if self._sdpa_kernels is not None else [] + + def set_seq_length( + self, + value: int, + device: torch.device, + ) -> None: + pass # Currently, we don't use this + + def __call__( + self, + query: torch.Tensor, + k_and_v: KeysAndValues, + block_idx: int, + input_pos: Optional[int] = None, + return_attn_weights: bool = False, + token_positions: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + + Args: + query: Queries, shape `(batch_size, n_heads, q_len, head_size)` + k_and_v: Access to keys and values, shape + (batch_size, n_query_groups, kv_len, head_size)` + block_idx: Index of block (or layer) in model + input_pos: Position in input sequence. Defaults to 0 + return_attn_weights: If this is `True` and `input_pos > 0`, the + attention weights (or scores) are returned as second argument + token_positions: Required if `input_pos > 0`. Contains token + positions in KV cache. This is needed to select the correct + part of the mask matrix + + Returns: + `attn_output, attn_weights`, where `attn_weights` is `None` if + attention weights are not returned. 
+ + """ + # We need the attention mask if there is sliding window attention + for_prefill = input_pos == 0 + is_causal = input_pos is None or for_prefill + if not is_causal and token_positions is None: + raise ValueError("token_positions must be given if input_pos > 0") + sliding_window_size = self._get_sliding_window_size(block_idx) + B, _, T, _ = query.shape + mask = None + use_eager_sdpa = self._use_eager_sdpa(return_attn_weights, k_and_v) + if use_eager_sdpa or sliding_window_size is not None or not is_causal: + # Build attention mask + mask_dtype = torch.float32 if use_eager_sdpa else query.dtype + if is_causal: + mask = ( + build_mask_cache( + max_seq_length=T, + sliding_window_size=sliding_window_size, + dtype=mask_dtype, + device=query.device, + ) + .view(1, 1, T, T) + .detach() + ) + elif (not use_eager_sdpa) or T > 1: + # We need a mask if T > 1, since inference needs to be causal + # for the new tokens + mask = build_mask_slice( + input_pos=input_pos, + num=T, + token_positions=token_positions, + n_head=self.config.n_head, + dtype=mask_dtype, + device=query.device, + sliding_window_size=sliding_window_size, + ).detach() + + y, scores = self.scaled_dot_product_attention( + query, + k_and_v, + mask, + return_attn_weights, + ) + # Re-assemble all head outputs side by side. + y = y.reshape(B, T, -1) + return y, scores + + def _get_sliding_window_size(self, block_idx: int) -> Optional[int]: + apply_sliding_window_attention = ( + self.config.sliding_window_size is not None and self.config.sliding_window_indices[block_idx] == 1 + ) + return self.config.sliding_window_size if apply_sliding_window_attention else None + + def _use_eager_sdpa( + self, + return_attn_weights: bool, + k_and_v: KeysAndValues, + ) -> bool: + return ( + return_attn_weights + or self.use_eager_sdpa_always + or self.config.attention_logit_softcapping is not None + or not k_and_v.both_in_parallel() + ) + + def _filter_sdpa_kernels( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor], + dropout_p: float, + is_causal: bool, + enable_gqa: bool, + **kwargs, + ): + if self._sdpa_kernels is not None and not self._sdpa_kernels_filtered: + if isinstance(self._sdpa_kernels, list): + kernels = self._sdpa_kernels + else: + kernels = [self._sdpa_kernels] + new_kernels = filter_sdpa_kernels( + sdpa_kernels=kernels, + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + enable_gqa=enable_gqa, + ) + self._sdpa_kernels = new_kernels if new_kernels else None + self._sdpa_kernels_filtered = True + + def _get_scale_factor(self): + return 1.0 / math.sqrt(self.config.attention_scores_scalar or self.config.head_size) + + def scaled_dot_product_attention( + self, + query: torch.Tensor, + k_and_v: KeysAndValues, + mask: Optional[torch.Tensor] = None, + return_scores: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + scale = self._get_scale_factor() + # We cannot call PyTorch scaled_dot_product_attention if: + # - Attention scores need to be returned; or + # - Logit softcapping is required; or + # - We cannot access keys and values from `k_and_v` in parallel + if self._use_eager_sdpa(return_scores, k_and_v): + assert mask is not None or query.shape[2] == 1 + y, scores = scaled_dot_product_attention( + query=query, + k_and_v=k_and_v, + scale=scale, + mask=mask, + attention_logit_softcapping=self.config.attention_logit_softcapping, + ) + if not return_scores: + scores = None + else: + # We need `key` and 
`value` at the same time here. For the training + # use case, this will be the case, since `k_and_v` is the default + # in this case. + key = k_and_v.keys() + value = k_and_v.values() + is_causal = mask is None + enable_gqa = self.config.n_query_groups < self.config.n_head + if enable_gqa and FUSED_SDPA_DOES_NOT_SUPPORT_ENABLE_GQA: + # Some efficient kernels have not implemented + # `enabla_gqa=True`. It is better to extend keys, values in + # this case. + q_per_kv = self.config.n_head // self.config.n_query_groups + key = key.repeat_interleave(q_per_kv, dim=1) + value = value.repeat_interleave(q_per_kv, dim=1) + enable_gqa = False + kwargs = dict( + query=query, + key=key, + value=value, + attn_mask=mask, + dropout_p=0.0, + scale=scale, + is_causal=is_causal, + enable_gqa=enable_gqa, + ) + self._filter_sdpa_kernels(**kwargs) + if self._sdpa_kernels is not None: + with sdpa_kernel(self._sdpa_kernels): + y = F.scaled_dot_product_attention(**kwargs) + else: + y = F.scaled_dot_product_attention(**kwargs) + scores = None + return y.transpose(1, 2), scores + + +def do_softcapping(x: torch.Tensor, thresh: Optional[float]) -> torch.Tensor: + if thresh is not None: + return torch.tanh(x / thresh) * thresh + else: + return x + + +def scaled_dot_product_attention( + query: torch.Tensor, + k_and_v: KeysAndValues, + scale: float, + mask: Optional[torch.Tensor] = None, + attention_logit_softcapping: Optional[float] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + dtype = query.dtype + key = k_and_v.keys().to(torch.float32) + query = query.to(torch.float32) + scores = attention_compute_scores(query, key) * scale + scores = do_softcapping(scores, attention_logit_softcapping) + if mask is not None: + scores = scores + mask.to(torch.float32) + scores = F.softmax(scores, dim=-1) + value = k_and_v.values().to(torch.float32) + return attention_compute_weighted_values(scores, value).to(dtype), scores.to(dtype) diff --git a/litgpt/attention_utils.py b/litgpt/attention_utils.py new file mode 100644 index 0000000000..d9e7179f27 --- /dev/null +++ b/litgpt/attention_utils.py @@ -0,0 +1,243 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
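The eager fallback just above upcasts to float32, optionally applies tanh soft-capping to the attention logits, and adds the additive mask before the softmax. As a quick sanity sketch (assuming the `litgpt.attention` module introduced by this patch), its output should match PyTorch's fused SDPA whenever soft-capping is disabled:

```python
import torch
from torch.nn import functional as F

from litgpt.attention import DefaultKeysAndValues, scaled_dot_product_attention

B, nh, T, hs = 2, 4, 8, 16
q = torch.randn(B, nh, T, hs)
k = torch.randn(B, nh, T, hs)
v = torch.randn(B, nh, T, hs)
scale = 1.0 / (hs**0.5)
# Additive causal mask: 0 on allowed positions, -inf above the diagonal
# (similar in spirit to what `build_mask_cache` produces).
mask = torch.zeros(T, T).masked_fill(torch.ones(T, T, dtype=torch.bool).triu(1), float("-inf"))

y_eager, scores = scaled_dot_product_attention(
    query=q,
    k_and_v=DefaultKeysAndValues(k, v),
    scale=scale,
    mask=mask,
    attention_logit_softcapping=None,  # with soft-capping on, the fused path cannot be used
)
y_fused = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=scale)
torch.testing.assert_close(y_eager, y_fused, rtol=1e-4, atol=1e-5)
print(scores.shape)  # (B, nh, T, T): the eager path also returns the attention weights
```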
+ +from typing import List, Optional + +import torch +from torch.backends.cuda import ( + can_use_cudnn_attention, + can_use_efficient_attention, + can_use_flash_attention, +) +from torch.nn.attention import SDPAParams, SDPBackend + + +def filter_sdpa_kernels( + sdpa_kernels: List[SDPBackend], + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor], + dropout_p: float, + is_causal: bool, + enable_gqa: bool, + **kwargs, +) -> List[SDPBackend]: + params = SDPAParams(query, key, value, attn_mask, dropout_p, is_causal, enable_gqa) + new_kernels = [] + for kernel in sdpa_kernels: + if kernel == SDPBackend.FLASH_ATTENTION and not can_use_flash_attention(params): + continue + elif kernel == SDPBackend.EFFICIENT_ATTENTION and not can_use_efficient_attention(params): + continue + elif kernel == SDPBackend.CUDNN_ATTENTION and not can_use_cudnn_attention(params): + continue + new_kernels.append(kernel) + return new_kernels + + +def attention_compute_scores( + query: torch.Tensor, + key: torch.Tensor, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert query.ndim == key.ndim == 4 + assert query.shape[0] == key.shape[0] and query.shape[3] == key.shape[3] + nh_q = query.shape[1] + nh_k = key.shape[1] + assert nh_q % nh_k == 0 + # - query: (bs, nh_q, T_q, hs) + # - key: (bs, nh_k, T_k, hs) + q_per_kv = nh_q // nh_k + key_transposed = key.mT # (bs, nh_k, hs, T_k) + if q_per_kv == 1: + out = torch.matmul(query, key_transposed, out=out) + else: + assert q_per_kv > 1 + q_shape = query.shape[:1] + (nh_k, q_per_kv) + query.shape[2:] + _query = query.view(*q_shape) + key_transposed = key_transposed.unsqueeze(2) + # At this point: + # - _query: (bs, nh_k, q_per_kv, T_q, hs) + # - key_transposed: (bs, nh_k, 1, hs, T_k) + # - scores: (bs, nh_k, q_per_kv, T_q, T_k) + if out is not None: + out = out.view(_query.shape[:-1] + (key.shape[2],)) + out = torch.matmul(_query, key_transposed, out=out) + s_shape = query.shape[:-1] + (key.shape[2],) + out = out.view(*s_shape) + return out + + +def attention_compute_weighted_values( + scores: torch.Tensor, + value: torch.Tensor, +) -> torch.Tensor: + assert scores.ndim == value.ndim == 4 + assert scores.shape[0] == scores.shape[0] and scores.shape[3] == value.shape[2] + nh_q = scores.shape[1] + nh_k = value.shape[1] + assert nh_q % nh_k == 0 + # - scores: (bs, nh_q, T_q, T_k) + # - value: (bs, nh_k, T_k, hs) + q_per_kv = nh_q // nh_k + if q_per_kv == 1: + return scores @ value + else: + s_shape = scores.shape[:1] + (nh_k, q_per_kv) + scores.shape[2:] + _scores = scores.view(*s_shape) + _value = value.unsqueeze(2) + # At this point: + # - _scores: (bs, nh_k, q_per_kv, T_q, T_k) + # - _value: (bs, nh_k, 1, T_k, hs) + # - result: (bs, nh_k, q_per_kv, T_q, hs) + result = torch.matmul(_scores, _value) + r_shape = scores.shape[:-1] + (value.shape[-1],) + return result.view(*r_shape) + + +def minus_infinity(dtype: torch.dtype) -> float: + return torch.finfo(dtype).min + + +def mask_cache_bool( + max_seq_length: int, + sliding_window_size: Optional[int], + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + # Usual causal mask: + mask = torch.ones( + max_seq_length, + max_seq_length, + device=device, + dtype=dtype, + ).triu(diagonal=1) + if sliding_window_size is not None: + mask += torch.ones_like(mask).tril(diagonal=-sliding_window_size) + return mask + + +def build_mask_cache( + max_seq_length: int, + sliding_window_size: Optional[int], + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + """ 
+ Global Window Sliding window Sliding window + attention mask + bias = attention mask + ┌────────────────────────┐ ┌───────────────────────┐ ┌─────────────────────────┐ + │ True False False False │ │ True True True True │ │ True False False False │ + │ True True False False │ │ True True True True │ │ True True False False │ + │ True True True False │ │ False True True True │ │ False True True False │ + │ True True True True │ │ False False True True │ │ False False True True │ + └────────────────────────┘ └───────────────────────┘ └─────────────────────────┘ + """ + mask = mask_cache_bool(max_seq_length, sliding_window_size, device, dtype) + mask.masked_fill_(mask.bool(), minus_infinity(dtype)) + return mask + + +def mask_slice_bool( + input_pos: int, + num: int, + token_positions: torch.Tensor, + n_head: int, + device: torch.device, + sliding_window_size: Optional[int] = None, +) -> torch.Tensor: + # Build boolean mask, then map False -> 0, True -> -infty + # If (i, j) indexes the complete (seq_len, seq_len) mask matrix, + # causality is given by I(i < j). If `sliding_window_size` is given, + # this translates to I(i >= j + sws) if sws = sliding_window_size. + assert token_positions.ndim == 3 + tp_dtype = token_positions.dtype + batch_size, n_query_groups, _ = token_positions.shape + assert n_head % n_query_groups == 0 and n_head >= n_query_groups + token_positions = ( + token_positions.to(device=device) + .unsqueeze(2) + .expand( + -1, + -1, + num, + -1, + ) + ) + kwargs = dict(device=device, dtype=tp_dtype) + bool_mask = ( + torch.arange( + input_pos, + input_pos + num, + **kwargs, + ) + .view(1, 1, -1, 1) + .expand_as(token_positions) + < token_positions + ) + if sliding_window_size is not None: + extra_mask = ( + torch.arange( + input_pos - sliding_window_size, + input_pos + num - sliding_window_size, + **kwargs, + ) + .view(1, 1, -1, 1) + .expand_as(token_positions) + >= token_positions + ) + bool_mask |= extra_mask + if n_head != n_query_groups: + q_per_kv = n_head // n_query_groups + bool_mask = ( + bool_mask.unsqueeze(2) + .expand( + -1, + -1, + q_per_kv, + -1, + -1, + ) + .reshape(batch_size, n_head, num, -1) + ) + return bool_mask + + +def build_mask_slice( + input_pos: int, + num: int, + token_positions: torch.Tensor, + n_head: int, + dtype: torch.dtype, + device: torch.device, + sliding_window_size: Optional[int] = None, +) -> torch.Tensor: + """ + Returns mask for case `input_pos > 0` in :class:`MultiHeadSelfAttention`. 
+ + Args: + input_pos: Position in input sequence, must be positive + num: Length of query argument `q_len` + token_positions: Token positions in KV cache, shape + `(eff_batch_size, n_query_groups, cache_length)` + n_head: Number of attention heads, must be multiple of + `n_query_groups` + dtype: Data type of the output mask + device: Device of the output mask + sliding_window_size: Size of sliding window (if any) + + Returns: + Mask tensor, shape `(eff_batch_size, n_head, num, cache_length)` + + """ + bool_mask = mask_slice_bool( + input_pos, + num, + token_positions, + n_head, + device, + sliding_window_size, + ) + mask = torch.zeros(bool_mask.shape, dtype=dtype, device=device) + mask.masked_fill_(bool_mask, minus_infinity(dtype)) + return mask diff --git a/litgpt/chat/base.py b/litgpt/chat/base.py index 123028b590..b59c8e8f77 100644 --- a/litgpt/chat/base.py +++ b/litgpt/chat/base.py @@ -30,6 +30,7 @@ def generate( prompt: torch.Tensor, max_returned_tokens: int, *, + prompt_chunksize: int = 16, temperature: float = 1.0, top_k: Optional[int] = None, top_p: float = 1.0, @@ -62,20 +63,31 @@ def generate( from litgpt.generate.base import generate_fn return generate_fn( - include_prompt=False, - include_eos=False, model=model, prompt=prompt, max_returned_tokens=max_returned_tokens, + prompt_chunksize=prompt_chunksize, temperature=temperature, top_k=top_k, top_p=top_p, stop_tokens=stop_tokens, + include_prompt=False, + include_eos=False, ) def process_prompt( - prompt, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens + prompt: str, + model: GPT, + tokenizer, + prompt_style, + fabric, + max_new_tokens: int, + prompt_chunksize: int, + temperature: float, + top_k: Optional[int], + top_p: float, + stop_tokens: Tuple[List[int], ...], ): prompt = prompt_style.apply(prompt=prompt) encoded_prompt = tokenizer.encode(prompt, device=fabric.device) @@ -83,16 +95,16 @@ def process_prompt( if max_new_tokens is None: max_returned_tokens = model.max_seq_length else: - first_turn = model.mask_cache is None max_returned_tokens = encoded_prompt.size(0) + max_new_tokens - if first_turn or max_returned_tokens > model.max_seq_length: + msl = model.max_seq_length + if max_returned_tokens > msl or model.config.block_size == msl: model.max_seq_length = max_returned_tokens - model.set_kv_cache(batch_size=1, device=fabric.device) y: Iterator[torch.Tensor] = generate( - model, - encoded_prompt, - max_returned_tokens, + model=model, + prompt=encoded_prompt, + max_returned_tokens=max_returned_tokens, + prompt_chunksize=prompt_chunksize, temperature=temperature, top_k=top_k, top_p=top_p, @@ -111,8 +123,7 @@ def process_prompt( t = time.perf_counter() - t0 - for block in model.transformer.h: - block.attn.kv_cache.reset_parameters() + model.clear_kv_caches() fabric.print( f"\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec, {tokens_generated} tokens", file=sys.stderr, @@ -120,7 +131,19 @@ def process_prompt( fabric.print() -def interact(multiline, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens): +def interact( + multiline: bool, + model: GPT, + tokenizer, + prompt_style, + fabric, + max_new_tokens: int, + prompt_chunksize: int, + temperature: float, + top_k: Optional[int], + top_p: float, + stop_tokens: Tuple[List[int], ...], +): while True: try: if not multiline: @@ -143,7 +166,17 @@ def interact(multiline, model, tokenizer, prompt_style, fabric, temperature, max break process_prompt( - prompt, model, 
tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens + prompt=prompt, + model=model, + tokenizer=tokenizer, + prompt_style=prompt_style, + fabric=fabric, + max_new_tokens=max_new_tokens, + prompt_chunksize=prompt_chunksize, + temperature=temperature, + top_k=top_k, + top_p=top_p, + stop_tokens=stop_tokens, ) @@ -152,6 +185,7 @@ def main( checkpoint_dir: Path, *, max_new_tokens: int = 50, + prompt_chunksize: int = 16, top_k: Optional[int] = 50, top_p: float = 1.0, temperature: float = 0.8, @@ -167,6 +201,11 @@ def main( checkpoint_dir: A local path to a directory containing the model weights or a valid model name. You can get a list of valid model names via the `litgpt download list` command line argument. max_new_tokens: The number of generation steps to take. + prompt_chunksize: If even the shortest prompt is longer than the KV + cache, prompts are processed in chunks of this size in the + prefill phase. Once the shortest has been processed to the + end, we proceed with chunk size 1. + Defaults to 1, but larger values are recommended for long prompts. top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. In top-p sampling, the next token is sampled from the highest probability tokens @@ -223,12 +262,7 @@ def main( with fabric.init_module(empty_init=True): model = GPT(config) - if compile: - print( - "IMPORTANT: with enabled compilation the KV-cache size is determined by model's maximum context size, which leads to " - "a higher memory consumption. In case of an OOM error, try to set `--compile=False`." - ) - model.set_kv_cache(batch_size=1) + model.set_kv_caches(batch_size=1) load_checkpoint(fabric, model, checkpoint_path) model.eval() @@ -261,8 +295,9 @@ def main( tokenizer=tokenizer, prompt_style=prompt_style, fabric=fabric, - temperature=temperature, max_new_tokens=(None if compile else max_new_tokens), + prompt_chunksize=prompt_chunksize, + temperature=temperature, top_k=top_k, top_p=top_p, stop_tokens=stop_tokens, diff --git a/litgpt/config.py b/litgpt/config.py index 6b7748cf63..3ca40c3c11 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass, field from pathlib import Path -from typing import Any, List, Literal, Optional, Type, Union +from typing import Any, Callable, List, Literal, Optional, Type, Union import yaml from typing_extensions import Self @@ -22,6 +22,13 @@ def find_multiple(n: int, k: int) -> int: return n + k - (n % k) +# See `Config.start_of_layer_hook`. A start of layer hook is called just before +# a layer is computed. The call is `hook(x, block_idx, input_pos)`, where +# `x` is the layer input, `block_idx` the number of the layer, and `input_pos` +# the position in the sequence (see :meth:`GPT.forward`). +StartOfLayerHook = Callable[[Any, int, Optional[int]], None] + + @dataclass class Config: name: str = "" @@ -68,8 +75,13 @@ class Config: n_query_groups: Optional[int] = None attn_bias: bool = False attention_scores_scalar: Optional[int] = None + # If `sliding_window_size` is given, sliding window attention with this + # size is used in layers where `sliding_window_indices` has a 1. The + # default is all 1, so that sliding window attention is used in all + # layers. If `len(sliding_window_indices) > n_layer`, we only use the + # initial part. 
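Since `StartOfLayerHook` is only defined here and exercised deep inside `GPT.forward`, a small sketch of a possible hook may help. The recorder below is hypothetical; it merely follows the `hook(x, block_idx, input_pos)` calling convention documented above.

```python
from typing import Any, Optional


class ActivationNormRecorder:
    """Hypothetical hook: records the mean L2 norm of each layer's input."""

    def __init__(self) -> None:
        self.records: list[tuple[int, Optional[int], float]] = []

    def __call__(self, x: Any, block_idx: int, input_pos: Optional[int]) -> None:
        # `x` is the (detached) layer input; per the field comment below, the hook
        # is also called once more with the final output and block_idx == n_layer.
        self.records.append((block_idx, input_pos, float(x.norm(dim=-1).mean())))


# Usage sketch (the model name is only an example):
# from litgpt.config import Config
# config = Config.from_name("pythia-160m", start_of_layer_hook=ActivationNormRecorder())
```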
sliding_window_size: Optional[int] = None - sliding_window_indices: Optional[List] = None + sliding_window_indices: Optional[List[int]] = None # if `attention_logit_softcapping` is used, cannot use optimized # `torch.nn.functional.scaled_dot_product_attention` (which implements # Flash attention), may result in higher memory and runtime footprint. @@ -94,9 +106,17 @@ class Config: norm_1: bool = True norm_2: bool = True # The base period of the RoPE embeddings for local attention. - # If not provided, rope_theta will be used for both local and global attention. + # If not provided, `rope_base` will be used for both local and global attention. rope_local_base_freq: Optional[float] = None - rope_indices: Optional[List] = None + # If provided, must have `>= n_layer` entries, either 0 or 1. For 0, + # `rope_base` is used, for 1 `rope_local_base_freq` is used. If + # `len(rope_indices) > n_layer`, we only use the initial part. + rope_indices: Optional[List[int]] = None + # This hook is called in `GPT.forward` at the start of each layer, + # passing the (detached) layer input, the layer index, and `input_pos`. + # It is also called with the final layer output (which is the input + # into the head block), passing `n_layer` as second argument. + start_of_layer_hook: Optional[StartOfLayerHook] = None def __post_init__(self): if not self.name: @@ -127,11 +147,19 @@ def __post_init__(self): self.rope_n_elem = int(self.rotary_percentage * self.head_size) - if self.sliding_window_size is not None and self.sliding_window_indices is None: - self.sliding_window_indices = [1] * self.n_layer + if self.sliding_window_size is not None: + self.sliding_window_indices = check_indicator_and_length( + self.sliding_window_indices, + name="sliding_window_indices", + required_length=self.n_layer, + ) - if self.rope_local_base_freq is not None and self.rope_indices is None: - self.rope_indices = [1] * self.n_layer + if self.rope_local_base_freq is not None: + self.rope_indices = check_indicator_and_length( + self.rope_indices, + name="rope_indices", + required_length=self.n_layer, + ) @classmethod def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]: @@ -200,6 +228,25 @@ def norm_class(self) -> Type: return getattr(torch.nn, self.norm_class_name) +def check_indicator_and_length( + params: Optional[List[int]], + name: str, + required_length: int, + use_initial_part: bool = True, + def_val: int = 1, +) -> List[int]: + if params is None: + return [def_val] * required_length + if len(params) != required_length: + if use_initial_part and len(params) > required_length: + params = params[:required_length] + else: + raise ValueError(f"{name} = {params}, must have length {required_length}") + if not set(params).issubset({0, 1}): + raise ValueError(f"{name} = {params}, must only contain 0 and 1") + return params + + ######################## # Stability AI StableLM ######################## diff --git a/litgpt/finetune/adapter.py b/litgpt/finetune/adapter.py index 42479baba5..6cb386a4ac 100644 --- a/litgpt/finetune/adapter.py +++ b/litgpt/finetune/adapter.py @@ -417,18 +417,22 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E with fabric.init_tensor(): # do not set `max_seq_length=max_returned_token` because memory is not a concern here - model.set_kv_cache(batch_size=1) + model.set_kv_caches(batch_size=1) max_returned_tokens = len(encoded) + eval.max_new_tokens if max_returned_tokens < model.max_seq_length: with fabric.init_tensor(): # do not set `max_seq_length=max_returned_token` 
because memory is not a concern here - model.set_kv_cache(batch_size=1) + model.set_kv_caches(batch_size=1) output = generate( - model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id + model=model, + prompt=encoded, + max_returned_tokens=max_returned_tokens, + temperature=0.8, + eos_id=tokenizer.eos_id, ) - model.clear_kv_cache() + model.clear_kv_caches() model.train() output = tokenizer.decode(output) fabric.print(f"{output}\n") diff --git a/litgpt/finetune/adapter_v2.py b/litgpt/finetune/adapter_v2.py index c1ba67521c..c542c993d6 100644 --- a/litgpt/finetune/adapter_v2.py +++ b/litgpt/finetune/adapter_v2.py @@ -447,11 +447,15 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E if max_returned_tokens < model.max_seq_length: with fabric.init_tensor(): # do not set `max_seq_length=max_returned_token` because memory is not a concern here - model.set_kv_cache(batch_size=1) + model.set_kv_caches(batch_size=1) output = generate( - model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id + model=model, + prompt=encoded, + max_returned_tokens=max_returned_tokens, + temperature=0.8, + eos_id=tokenizer.eos_id, ) - model.clear_kv_cache() + model.clear_kv_caches() model.train() output = tokenizer.decode(output) fabric.print(f"{output}\n") diff --git a/litgpt/finetune/full.py b/litgpt/finetune/full.py index 22699b8c5c..c67dfc5bf5 100644 --- a/litgpt/finetune/full.py +++ b/litgpt/finetune/full.py @@ -388,18 +388,22 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E with fabric.init_tensor(): # do not set `max_seq_length=max_returned_token` because memory is not a concern here - model.set_kv_cache(batch_size=1) + model.set_kv_caches(batch_size=1) max_returned_tokens = len(encoded) + eval.max_new_tokens if max_returned_tokens < model.max_seq_length: with fabric.init_tensor(): # do not set `max_seq_length=max_returned_token` because memory is not a concern here - model.set_kv_cache(batch_size=1) + model.set_kv_caches(batch_size=1) output = generate( - model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id + model=model, + prompt=encoded, + max_returned_tokens=max_returned_tokens, + temperature=0.8, + eos_id=tokenizer.eos_id, ) - model.clear_kv_cache() + model.clear_kv_caches() model.train() output = tokenizer.decode(output) fabric.print(f"{output}\n") diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py index 9593e1d4fe..5017fb7631 100644 --- a/litgpt/finetune/lora.py +++ b/litgpt/finetune/lora.py @@ -453,11 +453,15 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E if max_returned_tokens < model.max_seq_length: with fabric.init_tensor(): # do not set `max_seq_length=max_returned_token` because memory is not a concern here - model.set_kv_cache(batch_size=1) + model.set_kv_caches(batch_size=1) output = generate( - model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id + model=model, + prompt=encoded, + max_returned_tokens=max_returned_tokens, + temperature=0.8, + eos_id=tokenizer.eos_id, ) - model.clear_kv_cache() + model.clear_kv_caches() model.train() output = tokenizer.decode(output) fabric.print(f"{output}\n") diff --git a/litgpt/generate/adapter.py b/litgpt/generate/adapter.py index fb7f75c5ba..fc0163e75c 100644 --- a/litgpt/generate/adapter.py +++ b/litgpt/generate/adapter.py @@ -33,6 +33,7 @@ def main( adapter_path: Path = 
Path("out/finetune/adapter/final/lit_model.pth.adapter"), quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None, max_new_tokens: int = 100, + prompt_chunksize: int = 16, top_k: Optional[int] = 50, top_p: float = 1.0, temperature: float = 0.8, @@ -55,6 +56,11 @@ def main( - bnb.int8: 8-bit quantization from bitsandbytes for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md max_new_tokens: The number of generation steps to take. + prompt_chunksize: If even the shortest prompt is longer than the KV + cache, prompts are processed in chunks of this size in the + prefill phase. Once the shortest has been processed to the + end, we proceed with chunk size 1. + Defaults to 1, but larger values are recommended for long prompts. top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. In top-p sampling, the next token is sampled from the highest probability tokens @@ -119,7 +125,7 @@ def main( # set the max_seq_length to limit the memory usage to what we need model.max_seq_length = max_returned_tokens # enable the kv cache - model.set_kv_cache(batch_size=1) + model.set_kv_caches(batch_size=1) model.eval() t0 = time.perf_counter() @@ -134,7 +140,14 @@ def main( L.seed_everything(1234) t0 = time.perf_counter() y = generate( - model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id + model=model, + prompt=encoded, + max_returned_tokens=max_returned_tokens, + prompt_chunksize=prompt_chunksize, + temperature=temperature, + top_k=top_k, + top_p=top_p, + eos_id=tokenizer.eos_id, ) t = time.perf_counter() - t0 diff --git a/litgpt/generate/adapter_v2.py b/litgpt/generate/adapter_v2.py index e7a65fa528..682afb7b89 100644 --- a/litgpt/generate/adapter_v2.py +++ b/litgpt/generate/adapter_v2.py @@ -33,6 +33,7 @@ def main( adapter_path: Path = Path("out/finetune/adapter-v2/final/lit_model.pth.adapter_v2"), quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None, max_new_tokens: int = 100, + prompt_chunksize: int = 16, top_k: Optional[int] = 50, top_p: float = 1.0, temperature: float = 0.8, @@ -55,6 +56,11 @@ def main( - bnb.int8: 8-bit quantization from bitsandbytes for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md max_new_tokens: The number of generation steps to take. + prompt_chunksize: If even the shortest prompt is longer than the KV + cache, prompts are processed in chunks of this size in the + prefill phase. Once the shortest has been processed to the + end, we proceed with chunk size 1. + Defaults to 1, but larger values are recommended for long prompts. top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. 
In top-p sampling, the next token is sampled from the highest probability tokens @@ -119,7 +125,7 @@ def main( # set the max_seq_length to limit the memory usage to what we need model.max_seq_length = max_returned_tokens # enable the kv cache - model.set_kv_cache(batch_size=1) + model.set_kv_caches(batch_size=1) model.eval() t0 = time.perf_counter() @@ -134,7 +140,14 @@ def main( L.seed_everything(1234) t0 = time.perf_counter() y = generate( - model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id + model=model, + prompt=encoded, + max_returned_tokens=max_returned_tokens, + prompt_chunksize=prompt_chunksize, + temperature=temperature, + top_k=top_k, + top_p=top_p, + eos_id=tokenizer.eos_id, ) t = time.perf_counter() - t0 diff --git a/litgpt/generate/base.py b/litgpt/generate/base.py index 565ef08e23..3b817b3c66 100644 --- a/litgpt/generate/base.py +++ b/litgpt/generate/base.py @@ -75,34 +75,50 @@ def sample( def next_token( model: GPT, - input_pos: torch.Tensor, x: torch.Tensor, - input_pos_maxp1: Optional[int] = None, + input_pos: Optional[int], **sample_kwargs: Dict[str, Any], ) -> torch.Tensor: - logits = model(x, input_pos, input_pos_maxp1=input_pos_maxp1) + logits = model(x, input_pos=input_pos) _next = sample(logits, **sample_kwargs).to(dtype=torch.int64) return _next -def batched_sample(logits: list[torch.Tensor], kwargs: list[dict]) -> torch.Tensor: - assert len(logits) == len(kwargs), "logits and kwargs must have the same length." +def batched_sample( + logits_stack: torch.Tensor, + kwargs: Union[dict, list[dict]], +) -> torch.Tensor: + # Unbind the logits stack into a list of logits. + logits = [logits_stack] if logits_stack.ndim == 1 else logits_stack.unbind(0) + logits = [l.unsqueeze(0) for l in logits] + _kwargs = kwargs if isinstance(kwargs, list) else [kwargs] * len(logits) + assert len(logits) == len(_kwargs), "logits and kwargs must have the same length." return torch.stack( - [sample(l, **sample_args).to(dtype=torch.int64) for sample_args, l in zip(kwargs, logits)], dim=0 + [sample(l, **sample_args).to(dtype=torch.int64) for sample_args, l in zip(_kwargs, logits)], dim=0 ) def batched_next_token( - model: GPT, input_pos: torch.Tensor, x: torch.Tensor, kwargs: Union[dict, list[dict]] + model: GPT, + x: torch.Tensor, + input_pos: Optional[int], + kwargs: Union[dict, list[dict]], ) -> torch.Tensor: - # Where: - # input_pos is a 1d tensor of shape [seq_length...] - # x is context tokens to add to the kvcache. - # For prefill, x is a 2d tensor of shape [batch_size, prompt_length]. - # For subsequent tokens, x is a 2d tensor of shape [batch_size, 1]. - # kwargs is a list of dictionaries, each containing the keyword arguments for the sample function. - # If one dictionary is passed, it's repeated for each sample in the batch. + """ + Args: + model: GPT model. If `input_pos` is not `None`, its KV caches must be + assigned + x: Context tokens to be used as input, shape `(batch_size, num)`. When + used to sample new tokens, we have `num == 1`. + input_pos: Position of `x` in the full sequence. See + :meth:`GPT.forward` + kwargs: Sampling parameters (can be different for each batch dimension) + + Returns: + New samples corresponding to inputs `x` + + """ # In the future, we would like input_pos to be a 2d tensor of shape [batch_size, seq_length]. # That way, we can support prompts of different sizes. # This means making the rope cache and kvcache forward() work with batches. Currently, they do not. 
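A short, model-free sketch of how the refactored `batched_sample` consumes a stack of last-position logits with per-element sampling settings (the logits here are random, purely for illustration):

```python
import torch

from litgpt.generate.base import batched_sample

batch_size, vocab_size = 2, 50304
# Last-position logits for each sequence in the batch: (batch_size, 1, vocab_size).
logits_stack = torch.randn(batch_size, 1, vocab_size)

# One kwargs dict per batch element; sampling settings may differ across the batch.
next_tokens = batched_sample(
    logits_stack,
    kwargs=[
        dict(temperature=0.8, top_k=1),  # top_k=1 is effectively greedy
        dict(temperature=0.8, top_k=50, top_p=0.9),
    ],
)
print(next_tokens.shape)  # torch.Size([2, 1])
```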
@@ -113,17 +129,11 @@ def batched_next_token( # After this problem is resolved, there will be another problem. That being, continuous batched prefill. # If you have any ideas on this, let me know. I don't think that padding input_pos is viable. - _kwargs = kwargs if isinstance(kwargs, list) else [kwargs] * x.size(0) - # Run the model on the batch. - logits_stack = model(x, input_pos) - - # Unbind the logits stack into a list of logits. - logits_list = [logits_stack] if logits_stack.ndim == 1 else logits_stack.unbind(0) - logits_list = [l.unsqueeze(0) for l in logits_list] + logits_stack = model(x, input_pos=input_pos) # Return the next token for each sample in the batch. - return batched_sample(logits_list, kwargs=_kwargs) + return batched_sample(logits_stack, kwargs=kwargs) @torch.inference_mode() @@ -131,6 +141,7 @@ def generate_fn( model: GPT, prompt: torch.Tensor, max_returned_tokens: int, + prompt_chunksize: int = 16, *, temperature: float = 1.0, top_k: Optional[int] = None, @@ -146,6 +157,10 @@ def generate_fn( model: The model to use. prompt: The tokenized prompt to generate from. max_returned_tokens: The maximum number of new tokens to return. Does not include the prompt tokens. + prompt_chunksize: If the prompt is longer than the KV cache length, + prompts are processed in chunks of this size in the prefill phase. + The larger, the faster the prompt is processed, but a large chunk + size may lead to suboptimal cache decisions. temperature: The temp to pass to sample(). top_k: The top_k to pass to sample(). top_p: The top_p to pass to sample(). @@ -155,7 +170,13 @@ def generate_fn( """ prompt_size = prompt.size(0) - device = prompt.device + if prompt_size == 0: + raise ValueError("prompt must not be empty") + sample_kwargs = dict( + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) assert max_returned_tokens > prompt_size, ( f"Not enough space for {prompt_size} prompt tokens in a context length of {max_returned_tokens}." @@ -167,31 +188,44 @@ def generate_fn( if include_prompt: yield prompt + # Prompt processing. The first part of the prompt (possibly all of it) + # is processed with a prefill. If the prompt is larger than the KV + # cache length, we need to use sequential processing after that. + max_prefill_length = model.kv_cache_max_prefill_length() + if max_prefill_length is None: + end = prompt_size + else: + end = min(prompt_size, max_prefill_length) + input_pos = 0 + while input_pos < prompt_size: + inputs = prompt[input_pos:end].view(1, -1) + # We may need the last time slice of `all_logits` below: + all_logits = model(inputs, input_pos=input_pos) + input_pos = end + # Note that `max_tokens_forward` can change during the course of + # prompt processing: + chunksize = min((prompt_chunksize, model.kv_cache_max_tokens_forward(), prompt_size - input_pos)) + end += chunksize + + # Generation loop: One token per iteration + tokens = [] stop_progress = [0] * len(stop_tokens) yielded_idx = 0 - - # Generate output tokens. - # The first token generated is the prefill token. - # The input_pos for this token is the width of the entire prompt. - # For subsequent iterations, it's the index in the context for the token that we're generating. - tokens = [] - token = prompt - prefill_token = True - input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64) - # input_pos_maxp1 introduces data-dependent shapes and control flow. - # We want to skip if ThunderModules are involved, either directly or wrapped in LightningModule etc. 
- input_pos_maxp1 = prompt_size if all(m.__class__.__name__ != "ThunderModule" for m in model.modules()) else None for current_idx in range(max_returned_tokens - prompt_size): # Generate the token - token = next_token( - model, - input_pos, - token.view(1, -1), - input_pos_maxp1=input_pos_maxp1, - temperature=temperature, - top_k=top_k, - top_p=top_p, - ) + if current_idx == 0: + # First token sampled from the final logits output for prompt + # processing + token = sample(all_logits, **sample_kwargs).to(dtype=torch.int64) + all_logits = None + else: + token = next_token( + model=model, + x=token.view(1, -1), + input_pos=input_pos, + **sample_kwargs, + ) + input_pos += 1 tokens.append(token) int_token = token.item() @@ -221,15 +255,6 @@ def generate_fn( yield from y_tokens yielded_idx = safe_idx - # Update input_pos for the next iteration. - if prefill_token: - prefill_token = False - input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64) - else: - input_pos.add_(1) - if input_pos_maxp1 is not None: - input_pos_maxp1 += 1 - # Yield any remaining tokens if yielded_idx < len(tokens): yield from tokens[yielded_idx:] @@ -242,23 +267,30 @@ def batched_generate_fn( model: GPT, prompts: torch.Tensor, max_returned_tokens: int, + prompt_chunksize: int = 16, *, sample_args: Union[list[dict], dict], stop_tokens: Tuple[List[int], ...] = (), include_prompt: bool, - include_eos: bool, ) -> Iterator[list[Union[torch.Tensor, None]]]: """ Generates tokens for a batch of prompts. Args: model: The model to use. - prompts: A 2D tensor of shape [batch_size, prompt_length]. - max_returned_tokens: The maximum number of tokens to return, including the prompt tokens. - sample_args: The dictionary of kwargs to pass to sample() for each each token for each index in the batch. - stop_tokens: A tuple of stop sequences. If any of the sequences are generated, the generation stops early before max_returned_tokens. + prompts: A 2D tensor of shape [batch_size, prompt_length]. Note that + all prompts need to have the same length (TODO: Relax this) + max_returned_tokens: The maximum number of tokens to return, including + the prompt tokens. + prompt_chunksize: If the prompt is longer than the KV cache length, + prompts are processed in chunks of this size in the prefill phase. + The larger, the faster the prompt is processed, but a large chunk + size may lead to suboptimal cache decisions. + sample_args: The dictionary of kwargs to pass to sample() for each + token for each index in the batch. + stop_tokens: A tuple of stop sequences. If any of the sequences are + generated, the generation stops early before max_returned_tokens. include_prompt: Whether to output the prompt tokens. - include_eos: Whether to output the stop tokens if generation stops early. Yields: A list of tokens for each prompt in the batch, or None if a stop sequence has already been encountered for that index in the batch. @@ -268,12 +300,10 @@ def batched_generate_fn( prompts = prompts.unsqueeze(0) assert prompts.ndim == 2, "Prompts must be a 2D tensor." - batch_size = prompts.size(0) - max_prompt_size = prompts.size(1) - device = prompts.device + batch_size, max_prompt_size = prompts.shape if isinstance(sample_args, dict): - sample_args = [sample_args] * len(prompts) + sample_args = [sample_args] * batch_size else: assert len(sample_args) == batch_size, "sample_args must have the length as the batch size." 
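The prefill loops introduced in `generate_fn` and `batched_generate_fn` advance in chunks whose size depends on the KV cache. A model-free sketch of the same scheduling logic is shown below; the `max_prefill_length` and `max_tokens_forward` values are made up here just to show how the chunk boundaries evolve.

```python
from typing import Optional


def prefill_schedule(
    prompt_size: int,
    prompt_chunksize: int,
    max_prefill_length: Optional[int],
    max_tokens_forward: int,
) -> list[tuple[int, int]]:
    """Returns the (input_pos, end) ranges a chunked prefill would process."""
    ranges = []
    end = prompt_size if max_prefill_length is None else min(prompt_size, max_prefill_length)
    input_pos = 0
    while input_pos < prompt_size:
        ranges.append((input_pos, end))
        input_pos = end
        # Mirrors the loops above: never exceed what the KV cache accepts per forward.
        chunksize = min(prompt_chunksize, max_tokens_forward, prompt_size - input_pos)
        end += chunksize
    return ranges


# A 100-token prompt against a cache that allows a 64-token prefill and at most
# 16 tokens per subsequent forward pass (hypothetical numbers):
print(prefill_schedule(100, prompt_chunksize=16, max_prefill_length=64, max_tokens_forward=16))
# [(0, 64), (64, 80), (80, 96), (96, 100)]
```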
@@ -290,22 +320,42 @@ def batched_generate_fn( for i in range(max_prompt_size): yield [prompt[i].view(-1) for prompt in prompts] + # Prompt processing. The first part of the prompt (possibly all of it) + # is processed with a prefill. If the prompt is larger than the KV + # cache length, we need to use sequential processing after that. + max_prefill_length = model.kv_cache_max_prefill_length() + if max_prefill_length is None: + end = max_prompt_size + else: + end = min(max_prompt_size, max_prefill_length) + input_pos = 0 + while input_pos < max_prompt_size: + inputs = prompts[:, input_pos:end] + # We may need the last time slice of `all_logits` below: + all_logits = model(inputs, input_pos=input_pos) + input_pos = end + # Note that `max_tokens_forward` can change during the course of + # prompt processing: + chunksize = min((prompt_chunksize, model.kv_cache_max_tokens_forward(), max_prompt_size - input_pos)) + end += chunksize + stop_progresses = [[0] * len(stop_tokens) for _ in range(batch_size)] # [batch_size, ~len(stop_tokens)] stop_idxes = [-1] * batch_size yielded_idx = 0 - # Generate output tokens. - # The first token generated is the prefill token. - # The input_pos for this token is the width of the entire prompt. - # For subsequent iterations, it's the index in the context for the token that we're generating. + # Generation loop: One token per iteration token_lists = [[] for _ in range(batch_size)] - tokens: torch.Tensor = prompts - prefill_token = True - input_pos = torch.arange(0, max_prompt_size, device=device, dtype=torch.int64) for current_idx in range(max_returned_tokens - max_prompt_size): - # Generate the next token for each prompt in the batch. - # This is of shape [batch_size, 1]. - tokens = batched_next_token(model, input_pos, tokens, sample_args) + if current_idx == 0: + tokens = batched_sample(all_logits[:, -1:], kwargs=sample_args) + else: + tokens = batched_next_token( + model=model, + x=tokens, + input_pos=input_pos, + kwargs=sample_args, + ) + input_pos += 1 for i in range(batch_size): token_lists[i].append(tokens[i]) int_tokens = [token.item() for token in tokens] @@ -347,16 +397,6 @@ def batched_generate_fn( yield y_tokens yielded_idx = safe_idx - # Update input_pos for the next iteration. - if prefill_token: - prefill_token = False - - # TODO: Make the model support a batched input_pos of shape [batch_size, 1]. - # The kvcache has been fixed, but the rope cache is still broken. - input_pos = torch.tensor([max_prompt_size], device=device, dtype=torch.int64) - else: - input_pos.add_(1) - # Yield any remaining tokens max_token_lists = max(len(l) for l in token_lists) if yielded_idx < max_token_lists: @@ -375,6 +415,7 @@ def generate( model: GPT, prompt: torch.Tensor, max_returned_tokens: int, + prompt_chunksize: int = 16, *, temperature: float = 1.0, top_k: Optional[int] = None, @@ -390,6 +431,10 @@ def generate( model: The model to use. prompt: Tensor of shape (T) with indices of the prompt sequence. max_returned_tokens: The maximum number of tokens to return (given plus generated). + prompt_chunksize: If the prompt is longer than the KV cache length, + prompts are processed in chunks of this size in the prefill phase. + The larger, the faster the prompt is processed, but a large chunk + size may lead to suboptimal cache decisions. temperature: Scales the predicted logits by 1 / temperature. top_k: If specified, only sample among the tokens with the k highest probabilities. 
top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. @@ -417,6 +462,7 @@ def generate( model=model, prompt=prompt, max_returned_tokens=max_returned_tokens, + prompt_chunksize=prompt_chunksize, temperature=temperature, top_k=top_k, top_p=top_p, @@ -435,6 +481,7 @@ def main( sys_prompt: Optional[str] = None, num_samples: int = 1, max_new_tokens: int = 50, + prompt_chunksize: int = 16, top_k: Optional[int] = 50, top_p: float = 1.0, temperature: float = 0.8, @@ -452,6 +499,10 @@ def main( sys_prompt: The system prompt to use for generating the samples. num_samples: The number of text samples to generate. max_new_tokens: The number of generation steps to take. + prompt_chunksize: If the prompt is longer than the KV cache length, + prompts are processed in chunks of this size in the prefill phase. + The larger, the faster the prompt is processed, but a large chunk + size may lead to suboptimal cache decisions. top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. In top-p sampling, the next token is sampled from the highest probability tokens @@ -520,7 +571,7 @@ def main( # set the max_seq_length to limit the memory usage to what we need model.max_seq_length = max_returned_tokens # enable the kv cache - model.set_kv_cache(batch_size=1) + model.set_kv_caches(batch_size=1) model.eval() if compile: @@ -540,9 +591,10 @@ def main( for i in range(num_samples): t0 = time.perf_counter() y = generate( - model, - encoded, - max_returned_tokens, + model=model, + prompt=encoded, + max_returned_tokens=max_returned_tokens, + prompt_chunksize=prompt_chunksize, temperature=temperature, top_k=top_k, top_p=top_p, diff --git a/litgpt/generate/full.py b/litgpt/generate/full.py index 78cc8dde7d..23f2f57bc3 100644 --- a/litgpt/generate/full.py +++ b/litgpt/generate/full.py @@ -32,6 +32,7 @@ def main( finetuned_path: Path = Path("out/full/alpaca/lit_model_finetuned.pth"), quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None, max_new_tokens: int = 100, + prompt_chunksize: int = 16, top_k: Optional[int] = 50, top_p: float = 1.0, temperature: float = 0.8, @@ -54,7 +55,12 @@ def main( - bnb.int8: 8-bit quantization from bitsandbytes for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md max_new_tokens: The number of generation steps to take. - top_k: The number of top most probable tokens to consider in the sampling process. + prompt_chunksize: If even the shortest prompt is longer than the KV + cache, prompts are processed in chunks of this size in the + prefill phase. Once the shortest has been processed to the + end, we proceed with chunk size 1. + Defaults to 1, but larger values are recommended for long prompts. + top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. In top-p sampling, the next token is sampled from the highest probability tokens whose cumulative probability exceeds the threshold `top_p`. 
When specified, @@ -117,7 +123,7 @@ def main( # set the max_seq_length to limit the memory usage to what we need model.max_seq_length = max_returned_tokens # enable the kv cache - model.set_kv_cache(batch_size=1) + model.set_kv_caches(batch_size=1) model.eval() model = fabric.setup(model) @@ -129,7 +135,14 @@ def main( L.seed_everything(1234) t0 = time.perf_counter() y = generate( - model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id + model=model, + prompt=encoded, + max_returned_tokens=max_returned_tokens, + prompt_chunksize=prompt_chunksize, + temperature=temperature, + top_k=top_k, + top_p=top_p, + eos_id=tokenizer.eos_id, ) t = time.perf_counter() - t0 diff --git a/litgpt/generate/sequentially.py b/litgpt/generate/sequentially.py index 1aecccbfc2..f2201e0a80 100644 --- a/litgpt/generate/sequentially.py +++ b/litgpt/generate/sequentially.py @@ -2,7 +2,6 @@ import itertools import logging -import math import re import sys import time @@ -11,7 +10,7 @@ from functools import partial from pathlib import Path from pprint import pprint -from typing import Literal, Optional, Type +from typing import List, Literal, Optional, Type import lightning as L import torch @@ -23,7 +22,7 @@ import litgpt.generate.base as generate_base from litgpt.config import Config -from litgpt.model import GPT, Block, build_mask_cache +from litgpt.model import GPT, Block from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style from litgpt.tokenizer import Tokenizer from litgpt.utils import check_valid_checkpoint_dir, extend_checkpoint_dir, get_default_supported_precision @@ -37,18 +36,12 @@ def sequential(model: GPT, root: torch.device, max_seq_length: int, devices: int f" n_layer={model.config.n_layer} and devices={devices}." ) - # The last device might get fewer layers if number of layers not evenly divisible by device count - max_layers_per_device = math.ceil(model.config.n_layer / devices) - # dictates where each block should be instantiated - mapping = layer_to_device(model, chunk_on=Block, chunk_size=max_layers_per_device) - - if set(mapping.values()) != set(range(devices)): - # TODO: support smarter partitioning schemes - raise RuntimeError( - f"Not able to distribute the {model.config.n_layer} layers across {devices} devices." - " Try running with a lower number of devices." 
- ) - + # Dictates where each block should be instantiated + mapping = layer_to_device( + model, + chunk_on=Block, + chunk_sizes=chunk_sizes(model.config.n_layer, devices), + ) num_layers_per_device = {i: sum(1 for v in mapping.values() if v == i) for i in range(devices)} # materialize each block on the appropriate device @@ -64,17 +57,12 @@ def sequential(model: GPT, root: torch.device, max_seq_length: int, devices: int replace_device(submodule, replace=torch.device("cpu"), by=target_device) # in case the checkpoint was partial, materialize leftover metas _materialize_meta_tensors(submodule, target_device) - # and build the kv cache - submodule.attn.kv_cache = submodule.attn.build_kv_cache( - 1, max_seq_length, model.cos.size(-1), target_device - ) # rebuild odd ends with root: + # Setting `max_seq_length` forces other members to be built + if model.max_seq_length >= max_seq_length: + model.reset_caches() model.max_seq_length = max_seq_length - # the rope cache which is on meta device - model.cos, model.sin = model.rope_cache() - # the mask cache which cannot be created with `set_kv_cache` because that will set it for all layers - model.mask_cache = build_mask_cache(max_seq_length) # and everything that is not a block in the root _materialize_meta_tensors(model, root) replace_device(model, replace=torch.device("cpu"), by=root) @@ -96,13 +84,25 @@ def sequential(model: GPT, root: torch.device, max_seq_length: int, devices: int return model +def chunk_sizes(num_units: int, devices: int) -> List[int]: + cs = num_units // devices + k = devices * (cs + 1) - num_units + return [cs] * k + [cs + 1] * (devices - k) + + def layer_to_device( - module: torch.nn.Module, chunk_on: Type[torch.nn.Module], chunk_size: int + module: torch.nn.Module, + chunk_on: Type[torch.nn.Module], + chunk_sizes: List[int], ) -> "OrderedDict[str, int]": """Create a mapping from layer (block) to device.""" # this assumes that the definition order is the same as the execution order hits = [name for name, submodule in module.named_modules() if isinstance(submodule, chunk_on)] - return OrderedDict((name, i // chunk_size) for i, name in enumerate(hits)) + if sum(chunk_sizes) != len(hits): + raise ValueError(f"Found {len(hits)} for chunk_on={chunk_on}, not covered by chunk_sizes={chunk_sizes}") + _devices = [[d] * cs for d, cs in enumerate(chunk_sizes)] + devices = [d for lst in _devices for d in lst] + return OrderedDict(zip(hits, devices)) def move_block_input(device: torch.device, module: torch.nn.Module, ins): @@ -141,6 +141,7 @@ def main( sys_prompt: Optional[str] = None, num_samples: int = 1, max_new_tokens: int = 50, + prompt_chunksize: int = 16, top_k: Optional[int] = 50, top_p: float = 1.0, temperature: float = 0.8, @@ -158,6 +159,11 @@ def main( sys_prompt: The system prompt to use for generating the samples. num_samples: The number of text samples to generate. max_new_tokens: The number of generation steps to take. + prompt_chunksize: If even the shortest prompt is longer than the KV + cache, prompts are processed in chunks of this size in the + prefill phase. Once the shortest has been processed to the + end, we proceed with chunk size 1. + Defaults to 1, but larger values are recommended for long prompts. top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. 
In top-p sampling, the next token is sampled from the highest probability tokens @@ -227,6 +233,7 @@ def main( # still, use init_tensor for the precision with fabric.init_tensor(), torch.device("meta"): model = GPT(config) + model.set_kv_caches(batch_size=1) print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) t0 = time.perf_counter() @@ -253,26 +260,25 @@ def main( torch._inductor.config.coordinate_descent_tuning = True # cannot use cudagraphs because it doesn't support multiple device indices # https://github.com/pytorch/pytorch/blob/v2.2.0-rc5/torch/_inductor/compile_fx.py#L371-L375 - generate_base.next_token = torch.compile(generate_base.next_token) L.seed_everything(1234) for i in range(num_samples): t0 = time.perf_counter() y = generate_base.generate( - model, - encoded, - max_returned_tokens, + model=model, + prompt=encoded, + max_returned_tokens=max_returned_tokens, + prompt_chunksize=prompt_chunksize, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id, ) t = time.perf_counter() - t0 - for block in model.transformer.h: - block.attn.kv_cache.reset_parameters() print(tokenizer.decode(y)) tokens_generated = y.size(0) - prompt_length print( f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr ) + model.clear_kv_caches() print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr) diff --git a/litgpt/generate/speculative_decoding.py b/litgpt/generate/speculative_decoding.py index 653583b67f..cdbdba06e7 100644 --- a/litgpt/generate/speculative_decoding.py +++ b/litgpt/generate/speculative_decoding.py @@ -5,7 +5,7 @@ import warnings from pathlib import Path from pprint import pprint -from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple +from typing import List, Literal, Optional, Tuple import lightning as L import torch @@ -16,7 +16,11 @@ from lightning_utilities.core.imports import RequirementCache from litgpt.config import Config -from litgpt.generate.base import multinomial_num_samples_1, next_token, sample_top_p +from litgpt.generate.base import ( + multinomial_num_samples_1, + sample_top_p, +) +from litgpt.kvcache import DenseKVCache from litgpt.model import GPT from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style from litgpt.tokenizer import Tokenizer @@ -35,7 +39,7 @@ def sample( top_k: Optional[int] = None, top_p: float = 1.0, apply_softmax: bool = True, -) -> torch.Tensor: +) -> Tuple[torch.Tensor, torch.Tensor]: if top_p < 0.0 or top_p > 1.0: raise ValueError(f"top_p must be in [0, 1], got {top_p}") logits = logits[0, -1] @@ -57,14 +61,53 @@ def sample( return torch.argmax(logits, dim=-1, keepdim=True), F.softmax(logits, dim=-1) +def support_speculative_decoding(model: GPT) -> bool: + """ + Does this model support speculative decoding? This depends mostly on + the KV caches for this model. + + Args: + model: GPT model + + Returns: + Does model support speculative decoding? + + """ + # Hack to make unit tests in + # tests/test_generate_speculatively.py work, which use DraftModel inplace + # of GLM + if not isinstance(model, GPT): + return True + caches = [block.attn.kv_cache for block in model.transformer.h] + if any(c is None for c in caches): + raise ValueError("Some KV caches are not assigned. 
Use 'model.assign_kv_caches' or 'model.set_kv_caches'") + result = all(isinstance(c, DenseKVCache) for c in caches) + if not result: + print( + "Speculative decoding currently supported only for DenseKVCache " + "KV caches, which support removing most recently inserted " + "information by 'resize'." + ) + return result + + +def _resize_kv_caches(model: GPT, new_length: int): + # Hack to make unit tests in + # tests/test_generate_speculatively.py work, which use DraftModel inplace + # of GLM + if not isinstance(model, GPT): + return True + for block in model.transformer.h: + block.attn.kv_cache.resize(new_length) + + def speculative_decoding( draft_model: GPT, target_model: GPT, token: torch.Tensor, - input_pos: torch.Tensor, - input_pos_maxp1: int, + input_pos: int, speculative_k: int, - **sample_kwargs: Dict[str, Any], + **sample_kwargs, ) -> torch.Tensor: """Performs speculative decoding using a draft and a target model. @@ -84,8 +127,8 @@ def speculative_decoding( draft_model: Smaller/faster model used for initial token predictions target_model: Larger/slower model used for verification token: Current input token tensor of shape [1] - input_pos: Position index of the token tensor for KV-cache - input_pos_maxp1: Maximum position + 1 for managing KV-cache buffer + input_pos: Position in sequence where to start generating tokens. + Required by KV cache speculative_k: Number of tokens to speculatively generate at once sample_kwargs: Additional sampling parameters (temperature, top_k, top_p) @@ -96,20 +139,21 @@ def speculative_decoding( if speculative_k < 1: raise ValueError(f"speculative_k must be >= 1, got {speculative_k}") + if not (support_speculative_decoding(draft_model) and support_speculative_decoding(target_model)): + raise ValueError("Both draft_model and target_model must have DenseKVCache KV caches only") # Step 1: Generate candidate tokens using draft model # The draft model autoregressively generates k tokens, keeping track of probabilities - draft_input_pos = input_pos.clone() - draft_input_pos_maxp1 = input_pos_maxp1 + draft_input_pos = input_pos draft_tokens, draft_probs = [], [] draft_token = token for idx in range(speculative_k): logits = draft_model( - idx=draft_token.unsqueeze(0), input_pos=draft_input_pos, input_pos_maxp1=draft_input_pos_maxp1 + idx=draft_token.unsqueeze(0), + input_pos=draft_input_pos, ) draft_token, draft_prob = sample(logits, **sample_kwargs) - draft_input_pos.add_(1) - draft_input_pos_maxp1 += 1 + draft_input_pos += 1 draft_tokens.append(draft_token) draft_probs.append(draft_prob) draft_tokens = torch.cat(draft_tokens) @@ -117,10 +161,9 @@ def speculative_decoding( # Step 2: Get target model predictions for comparison # Feed both original token and draft tokens to get target probabilities candidate_tokens = torch.cat((token, draft_tokens)) - candidate_input_pos = input_pos + torch.arange(0, speculative_k + 1, device=input_pos.device) - candidate_input_pos_maxp1 = input_pos_maxp1 + speculative_k target_logits = target_model( - idx=candidate_tokens.unsqueeze(0), input_pos=candidate_input_pos, input_pos_maxp1=candidate_input_pos_maxp1 + idx=candidate_tokens.unsqueeze(0), + input_pos=input_pos, ) # Step 3: Convert target logits to probabilities using same sampling params @@ -134,6 +177,7 @@ def speculative_decoding( # Otherwise reject with probability 1 - target_prob / draft_prob. # If rejected, sample from an adjusted distribution: norm(max(0, target_prob_distribution - draft_prob_distribution) instead. 
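In isolation, the accept/reject rule described in the comments above looks as follows. This is a hedged sketch over plain 1D distributions; the actual loop below works with the module's `sample()` helper and batched tensor shapes rather than `torch.multinomial`:

```python
import torch


def accept_or_resample(
    draft_token: torch.Tensor,  # proposed token id, shape (1,)
    draft_dist: torch.Tensor,   # draft model distribution over the vocab, shape (V,)
    target_dist: torch.Tensor,  # target model distribution over the vocab, shape (V,)
):
    p_draft = draft_dist[draft_token]
    p_target = target_dist[draft_token]
    # Accept with probability min(1, p_target / p_draft).
    if torch.rand((), device=draft_token.device) < p_target / p_draft:
        return draft_token, True
    # Rejected: sample from the residual distribution norm(max(0, target - draft)).
    residual = torch.clamp(target_dist - draft_dist, min=0.0)
    residual = residual / residual.sum()
    return torch.multinomial(residual, num_samples=1), False
```

On rejection the speculative round ends and both models' dense KV caches are truncated back with `resize`, which is what `_resize_kv_caches` above is for.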
accepted_tokens = [] + new_token = None for idx in range(len(draft_tokens)): draft_token = draft_tokens[idx].unsqueeze(0) draft_prob = draft_probs[idx][draft_token] @@ -158,22 +202,58 @@ def speculative_decoding( adjusted_distribution = torch.clamp(adjusted_distribution, 0.0) adjusted_distribution = adjusted_distribution / adjusted_distribution.sum() new_token, _ = sample(adjusted_distribution[None, None, ...], apply_softmax=False, **sample_kwargs) - return torch.cat((*accepted_tokens, new_token)) - - # If all draft tokens were accepted: - # 1. Update draft model's key-value cache - # 2. Sample one more token from target model - draft_model(idx=draft_token.unsqueeze(0), input_pos=draft_input_pos, input_pos_maxp1=draft_input_pos_maxp1) - new_token, _ = sample(target_logits, **sample_kwargs) + break + + # At this point, the draft model has advanced to `input_pos + speculative_k`, + # the target model to `input_pos + speculative_k + 1`. This needs to be + # corrected if not all speculative tokens are accepted. + if new_token is None: + # All speculative tokens have been accepted. The draft model has to + # extend one more, and another token has to be sampled + draft_model(idx=draft_token.unsqueeze(0), input_pos=draft_input_pos) + new_token, _ = sample(target_logits, **sample_kwargs) + else: + input_pos += len(accepted_tokens) + 1 + _resize_kv_caches(draft_model, input_pos) + _resize_kv_caches(target_model, input_pos) return torch.cat((*accepted_tokens, new_token)) +def _process_prompt( + model: GPT, + prompt: torch.Tensor, + prompt_chunksize: int, + **sample_kwargs, +) -> torch.Tensor: + prompt_size = prompt.size(0) + assert prompt_size > 0 + max_prefill_length = model.kv_cache_max_prefill_length() + if max_prefill_length is None: + end = prompt_size + else: + end = min(prompt_size, max_prefill_length) + input_pos = 0 + logits = None + while input_pos < prompt_size: + inputs = prompt[input_pos:end].view(1, -1) + logits = model(inputs, input_pos=input_pos) + input_pos = end + # Note that `max_tokens_forward` can change during the course of + # prompt processing: + chunksize = min((prompt_chunksize, model.kv_cache_max_tokens_forward(), prompt_size - input_pos)) + end += chunksize + # Sample single token + token, _ = sample(logits, **sample_kwargs) + return token + + @torch.inference_mode() def generate( draft_model: GPT, target_model: GPT, prompt: torch.Tensor, max_returned_tokens: int, + prompt_chunksize: int = 16, *, temperature: float = 1.0, top_k: Optional[int] = None, @@ -181,7 +261,7 @@ def generate( stop_tokens: Tuple[List[int], ...] = (), include_prompt: bool = True, speculative_k: int, -) -> Iterator[torch.Tensor]: +) -> Tuple[torch.Tensor, float]: """Generates tokens using speculative decoding with a draft and a target model. This function implements token generation using speculative decoding, where a faster draft model @@ -192,6 +272,10 @@ def generate( target_model: Larger/more accurate model used to verify draft predictions prompt: Input tensor of token ids to generate from, shape [sequence_length] max_returned_tokens: Maximum total tokens (prompt + generated) to return + prompt_chunksize: If the prompt is longer than the KV cache length, + prompts are processed in chunks of this size in the prefill phase. + The larger, the faster the prompt is processed, but a large chunk + size may lead to suboptimal cache decisions. 
temperature: Sampling temperature (higher = more random, lower = more deterministic) top_k: If set, only sample from the top k most likely next tokens top_p: If <1.0, only sample from tokens whose cumulative probability exceeds top_p @@ -211,8 +295,8 @@ def generate( 5. Process repeats until max tokens or stop sequence reached """ + prompt = prompt.flatten() prompt_size = prompt.size(0) - device = prompt.device assert max_returned_tokens > prompt_size, ( f"Not enough space for {prompt_size} prompt tokens in a context length of {max_returned_tokens}." @@ -227,39 +311,26 @@ def generate( ) # Step 1: Prefill draft and target models with the prompt. - input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64) - # We want to skip if ThunderModules are involved, either directly or wrapped in LightningModule etc. - input_pos_maxp1 = ( - prompt_size if all(m.__class__.__name__ != "ThunderModule" for m in target_model.modules()) else None - ) - next_token( - draft_model, - input_pos, - prompt.view(1, -1), - input_pos_maxp1=input_pos_maxp1, + sample_kwargs = dict( temperature=temperature, top_k=top_k, top_p=top_p, ) - token = next_token( + _process_prompt(draft_model, prompt, prompt_chunksize, **sample_kwargs) + token = _process_prompt( target_model, - input_pos, - prompt.view(1, -1), - input_pos_maxp1=input_pos_maxp1, - temperature=temperature, - top_k=top_k, - top_p=top_p, + prompt, + prompt_chunksize, + **sample_kwargs, ) - # Update position trackers after prompt - input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64) - input_pos_maxp1 += 1 + input_pos = prompt_size # Step 2: Main generation loop. tokens = [] total_generated, total_accepted = 0, 0 # Track acceptance statistics while input_pos < max_returned_tokens - 1: # Calculate speculative tokens to generate - _speculative_k = min(speculative_k, (max_returned_tokens - input_pos - 1).item()) + _speculative_k = min(speculative_k, max_returned_tokens - input_pos - 1) # Get new tokens via speculative decoding new_tokens = speculative_decoding( @@ -267,11 +338,8 @@ def generate( target_model=target_model, token=token, input_pos=input_pos, - input_pos_maxp1=input_pos_maxp1, speculative_k=_speculative_k, - temperature=temperature, - top_k=top_k, - top_p=top_p, + **sample_kwargs, ) # Update statistics @@ -291,8 +359,9 @@ def generate( break # Update positions for next iteration - input_pos.add_(accepted_tokens_len) - input_pos_maxp1 += accepted_tokens_len + # UUPS! input_pos must increase by _speculative_k, otherwise KV + # caches do not work! + input_pos += accepted_tokens_len token = new_tokens[-1].unsqueeze(0) # Finalize generated sequence @@ -311,7 +380,7 @@ def setup_model(config: Config, max_returned_tokens: int, fabric: L.Fabric) -> G # set the max_seq_length to limit the memory usage to what we need model.max_seq_length = max_returned_tokens # enable the kv cache - model.set_kv_cache(batch_size=1) + model.set_kv_caches(batch_size=1) model.eval() return fabric.setup_module(model) @@ -334,6 +403,7 @@ def main( sys_prompt: Optional[str] = None, num_samples: int = 1, max_new_tokens: int = 50, + prompt_chunksize: int = 16, speculative_k: int = 3, top_k: Optional[int] = 50, top_p: float = 1.0, @@ -347,12 +417,16 @@ def main( Generates text samples based on pre-trained models and a tokenizer. 
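In essence, each sample is produced as follows (a hedged sketch mirroring `setup_model` and the `generate` call in this function; Fabric setup, checkpoint loading and tokenization are omitted, and the float returned alongside the tokens is the acceptance statistic tracked in the generation loop above):

```python
import torch

from litgpt.generate.speculative_decoding import generate
from litgpt.model import GPT


@torch.inference_mode()
def speculative_sample(
    draft_model: GPT,
    target_model: GPT,
    encoded_prompt: torch.Tensor,
    max_returned_tokens: int,
    speculative_k: int = 3,
):
    for model in (draft_model, target_model):
        model.max_seq_length = max_returned_tokens  # limit memory to what we need
        model.set_kv_caches(batch_size=1)           # dense caches, as required above
        model.eval()
    tokens, acceptance = generate(
        draft_model=draft_model,
        target_model=target_model,
        prompt=encoded_prompt,
        max_returned_tokens=max_returned_tokens,
        prompt_chunksize=16,
        temperature=0.8,
        speculative_k=speculative_k,
    )
    return tokens, acceptance
```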
Args: - draft_model: Smaller/faster model used for initial token predictions - target_model: Larger/more accurate model used to verify draft predictions + draft_model_checkpoint_dir: Smaller/faster model used for initial token predictions + target_model_checkpoint_dir: Larger/more accurate model used to verify draft predictions prompt: The prompt string to use for generating the samples. sys_prompt: The system prompt to use for generating the samples. num_samples: The number of text samples to generate. max_new_tokens: The number of generation steps to take. + prompt_chunksize: If the prompt is longer than the KV cache length, + prompts are processed in chunks of this size in the prefill phase. + The larger, the faster the prompt is processed, but a large chunk + size may lead to suboptimal cache decisions. speculative_k: Number of tokens to speculatively generate at each step top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. @@ -428,14 +502,6 @@ def main( target_model = setup_model(target_config, max_returned_tokens, fabric) fabric.print(f"Time to instantiate models: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) - # Setup compilation if needed - if compile: - torch._dynamo.config.automatic_dynamic_shapes = True - torch._inductor.config.triton.unique_kernel_names = True - torch._inductor.config.coordinate_descent_tuning = True - global next_token - next_token = torch.compile(next_token, mode="reduce-overhead") - # Load model weights t0 = time.perf_counter() load_checkpoint(fabric, draft_model, draft_checkpoint_path) @@ -451,10 +517,11 @@ def main( target_model, encoded, max_returned_tokens, + prompt_chunksize, temperature=temperature, top_k=top_k, top_p=top_p, - stop_tokens=([tokenizer.eos_id] if tokenizer.eos_id is not None else []), + stop_tokens=([tokenizer.eos_id] if tokenizer.eos_id is not None else [],), speculative_k=speculative_k, ) t = time.perf_counter() - t0 diff --git a/litgpt/generate/tp.py b/litgpt/generate/tp.py index 16bd5ac878..532ef2963b 100644 --- a/litgpt/generate/tp.py +++ b/litgpt/generate/tp.py @@ -107,6 +107,7 @@ def main( sys_prompt: Optional[str] = None, num_samples: int = 1, max_new_tokens: int = 50, + prompt_chunksize: int = 16, top_k: Optional[int] = 50, top_p: float = 1.0, temperature: float = 0.8, @@ -124,6 +125,11 @@ def main( sys_prompt: The system prompt to use for generating the samples. num_samples: The number of text samples to generate. max_new_tokens: The number of generation steps to take. + prompt_chunksize: If even the shortest prompt is longer than the KV + cache, prompts are processed in chunks of this size in the + prefill phase. Once the shortest has been processed to the + end, we proceed with chunk size 1. + Defaults to 1, but larger values are recommended for long prompts. top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. 
In top-p sampling, the next token is sampled from the highest probability tokens @@ -197,6 +203,7 @@ def main( # still, use init_tensor for the precision with fabric.init_tensor(), torch.device("meta"): model = GPT(config) + model.set_kv_caches(batch_size=1) fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) # sequentially do: load the checkpoint on CPU -> quantize -> apply tp -> move to device @@ -224,7 +231,7 @@ def main( # the rope cache which is on meta device model.cos, model.sin = model.rope_cache() # enable the kv cache - model.set_kv_cache(batch_size=1) + model.set_kv_caches(batch_size=1) model.eval() t0 = time.perf_counter() @@ -236,21 +243,26 @@ def main( torch._dynamo.config.automatic_dynamic_shapes = True torch._inductor.config.triton.unique_kernel_names = True torch._inductor.config.coordinate_descent_tuning = True - generate_base.next_token = torch.compile(generate_base.next_token, mode="reduce-overhead") L.seed_everything(1234) for i in range(num_samples): t0 = time.perf_counter() y = generate_base.generate( - model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id + model=model, + prompt=encoded, + max_returned_tokens=max_returned_tokens, + prompt_chunksize=prompt_chunksize, + temperature=temperature, + top_k=top_k, + top_p=top_p, + eos_id=tokenizer.eos_id, ) t = time.perf_counter() - t0 - for block in model.transformer.h: - block.attn.kv_cache.reset_parameters() fabric.print(tokenizer.decode(y)) tokens_generated = y.size(0) - prompt_length fabric.print( f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr ) + model.clear_kv_caches() if fabric.device.type == "cuda": fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr) diff --git a/litgpt/kvcache/__init__.py b/litgpt/kvcache/__init__.py new file mode 100644 index 0000000000..4d8c403227 --- /dev/null +++ b/litgpt/kvcache/__init__.py @@ -0,0 +1,14 @@ +from litgpt.kvcache.base import ( + DefaultKVCache, + KVCache, + KVCacheParams, +) +from litgpt.kvcache.baselines import DenseKVCache, LastRecentlyInsertedKVCache + +__all__ = [ + "DefaultKVCache", + "DenseKVCache", + "KVCache", + "KVCacheParams", + "LastRecentlyInsertedKVCache", +] diff --git a/litgpt/kvcache/base.py b/litgpt/kvcache/base.py new file mode 100644 index 0000000000..7de9599da9 --- /dev/null +++ b/litgpt/kvcache/base.py @@ -0,0 +1,485 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
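The package `__init__` above exports the two baseline caches whose implementations follow further below. A hedged construction sketch (the model name, cache length and dtype are illustrative; in the generation scripts these caches are normally created via `model.set_kv_caches` rather than by hand):

```python
import torch

from litgpt.config import Config
from litgpt.kvcache import DenseKVCache, LastRecentlyInsertedKVCache

config = Config.from_name("pythia-160m")  # illustrative model name
common = dict(config=config, batch_size=1, block_idx=0, dtype=torch.bfloat16)

# Dense cache: one slot per position, cache length defaults to config.block_size.
dense = DenseKVCache(**common)
# Sliding-window baseline: keeps only the 1024 most recently inserted entries.
recent = LastRecentlyInsertedKVCache(cache_length=1024, **common)
```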
+ +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch.nn.attention import SDPBackend + +from litgpt.attention import ( + DefaultKeysAndValues, + KeysAndValues, + MultiHeadSelfAttention, +) +from litgpt.config import Config + + +@dataclass(frozen=True) +class KVCacheParams: + batch_size: int + n_query_groups: int + cache_length: int + head_size: int + n_head: int + device: Optional[torch.device] + dtype: Optional[torch.dtype] + + @staticmethod + def from_config( + config: Config, + batch_size: int, + cache_length: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + head_size: Optional[int] = None, + ) -> "KVCacheParams": + if head_size is None: + head_size = config.n_embd // config.n_head + return KVCacheParams( + batch_size=batch_size, + n_query_groups=config.n_query_groups, + cache_length=cache_length, + head_size=head_size, + n_head=config.n_head, + device=device, + dtype=dtype, + ) + + +class KVCache(torch.nn.Module): + """ + Base class for key-value (KV) caches. + + Buffers have shapes + `(batch_size, config.n_query_groups, cache_length, head_size)`, where + `head_size` is a parameter. Caching can be used for + batch size `1 <= eff_batch_size <= batch_size`, which is determined in + prefill calls (`input_pos=0`) of :meth:`forward`. + + Note: In general, query and key tensors need to be position-encoded + (e.g., RoPE). + + """ + + def __init__( + self, + config: Config, + batch_size: int, + cache_length: int, + block_idx: int, + dtype: Optional[torch.dtype] = None, + head_size: Optional[int] = None, + ): + """ + Note that `batch_size` is the maximum batch size the cache can be used + with. The effective batch size is determined when calling + :meth:`forward` with `input_pos=0`, and can change with any such prefill + call. If this is smaller than `batch_size`, then in general only parts + of the buffers are used. + + Args: + config: Model config + batch_size: Inference batch size (maximum) + cache_length: Number of slots in cache + block_idx: Index of model block (or layer). Multi-head attention + needs to know this. + dtype: Data type for buffers + head_size: Size of final dimension of buffers. Defaults to head + size of model + """ + super().__init__() + if cache_length <= 0: + raise ValueError("cache_length must be positive") + self.batch_size = batch_size + self._n_query_groups = config.n_query_groups + self._cache_length = cache_length + if head_size is None: + head_size = config.head_size + self.head_size = head_size + self.n_head = config.n_head + self._dtype = dtype + self.block_idx = block_idx + # TODO: Remove once HuggingFace bug is fixed + # https://github.com/huggingface/transformers/issues/35233 + # https://github.com/huggingface/transformers/pull/35901 + self._work_around_hf_bug = config.rope_n_elem == 1 + + @property + def device(self) -> torch.device: + """ + Returns: + Device the KV cache buffers are kept on + + """ + raise NotImplementedError + + @property + def dtype(self) -> torch.dtype: + return self._dtype + + @property + def cache_length(self) -> Optional[int]: + return self._cache_length + + @property + def n_query_groups(self) -> int: + return self._n_query_groups + + @property + def next_token_pos(self) -> Optional[int]: + """ + Returns: + Input position for next token to be generated, or `None` if cache + has not been initialized yet (call of :meth:`prefill`). 
+ """ + raise NotImplementedError() + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + token_idx: torch.Tensor, + input_pos: int, + ) -> torch.Tensor: + """ + Given query, key, value tensors, this method extends the KV cache with + `key`, `value`, then computes multi-head self attention. There are two + cases: + + * Prefill (`input_pos == 0`): Starts a generation loop by passing key + and value tensors. The KV cache is reset. The length must be + `num <= max_prefill_length`. The effective batch size must be + `eff_batch_size <= batch_size`. This batch size is then fixed for + subsequent calls of :meth:`forward`. + Different to update (`input_pos > 0`), additional information (such + as attention weights) are not obtained here. This is because the + typical prefill size is much larger than `num` in update, and device + memory is much more of a concern. + * Update (`input_pos > 0`): Continues a generation loop (or processing + of large prompt). The length must be `num <= max_tokens_forward`. + + If the cache makes eviction decisions based on scores which require + attention weights, scores for the next :meth:`forward` call need to + be computed here. + + If a sequence is generated token by token, updates always use `num=1`. + The case `num > 1` arises if large prompts are to be ingested with more + than `max_prefill_length` tokens. Note that if the cache makes eviction + decisions by scoring in :meth:`update`, then large `num` may lead to + worse decisions. On the other hand, ingesting prompts with larger `num` + is faster. + + Args: + query: New queries, + `(eff_batch_size, n_query_groups, num, head_size)`. Here, + `num <= max_tokens_forward` if `input_pos > 0`, and + `num <= max_prefill_length` if `input_pos == 0`. Must be + position encoded. + key: New keys, `(eff_batch_size, n_query_groups, num, head_size)`. + Must be position encoded. + value: New values, `(eff_batch_size, n_query_groups, num, head_size)` + token_idx: Token indices of input sequence, `(eff_batch_size, num)`. + Some KV caches make use of this information. + input_pos: Token position of the new chunk in the full input + sequence. + + Returns: + Multi-head self-attention outputs before final linear map, + `(eff_batch_size, num, n_head * head_size)` + + """ + raise NotImplementedError() + + def get_keys_values(self) -> Optional[KeysAndValues]: + """ + Returns: + :class:`KeysAndValues` object, providing access to currently stored + keys and values tensors. If the cache is empty or has not been + initialized, `None` is returned. + + """ + raise NotImplementedError() + + @property + def max_tokens_forward(self) -> int: + """ + Note that this limit may change during the course of the generation + for certain caches. + + Returns: + Maximum number of token positions which can be treated in + :meth:`forward` with `input_pos > 0`. Depends on cache, but is + `<= cache_length` + + """ + raise NotImplementedError() + + @property + def max_prefill_length(self) -> Optional[int]: + """ + Returns: + Maximum sequence length for `key`, `value` tensors passed to + :meth:`forward` if `input_pos == 0`. If there is no such maximum + length, `None` is returned. 
+ + """ + raise NotImplementedError() + + def get_params(self) -> KVCacheParams: + return KVCacheParams( + batch_size=self.batch_size, + n_query_groups=self.n_query_groups, + cache_length=self.cache_length, + head_size=self.head_size, + n_head=self.n_head, + device=self.device, + dtype=self.dtype, + ) + + def token_positions(self) -> torch.Tensor: + """ + Returns: + Token positions in slots of the cache, shape + `(eff_batch_size, n_query_groups, T)`.where `T <= cache_length` + is the current cache length. + """ + raise NotImplementedError() + + def size_estimate(self) -> Tuple[int, Dict[str, int]]: + """ + This is an estimate of the main buffers (which should all be allocated + up front), it does not cover temporary storage used in the methods + (make sure these are small compared to the main buffers). Also, real + memory usage may be larger due to alignment issues. + + Returns: + num_bits_total, bits_by_part (unit is bit) + + """ + raise NotImplementedError() + + @classmethod + def size_estimate_apriori(cls, params: KVCacheParams, **kwargs) -> Tuple[int, Dict[str, int]]: + """ + Same semantics as :meth:`size_estimate`, but can be called without a + cache being created. Results may not be exactly the same, but should + be very close. + + Args: + params: KV cache parameters + **kwargs: Extra arguments (optional) + + Returns: + num_bits_total, bits_by_part (unit is bit) + + """ + raise NotImplementedError() + + def reset_parameters(self) -> None: + pass + + +class DefaultKVCache(KVCache): + """ + Default implementation of :class:`KVCache`, which implements :meth:`forward` + using scaled dot product attention. Most KV caches will inherit from this + class. + + """ + + def __init__( + self, + config: Config, + batch_size: int, + cache_length: int, + block_idx: int, + dtype: Optional[torch.dtype] = None, + head_size: Optional[int] = None, + mha: Optional[MultiHeadSelfAttention] = None, + sdpa_kernels: Optional[Union[SDPBackend, List[SDPBackend]]] = None, + use_eager_sdpa_always: bool = False, + ): + super().__init__( + config=config, + batch_size=batch_size, + cache_length=cache_length, + block_idx=block_idx, + dtype=dtype, + head_size=head_size, + ) + if mha is None: + self.mha = MultiHeadSelfAttention( + config, + sdpa_kernels=sdpa_kernels, + use_eager_sdpa_always=use_eager_sdpa_always, + ) + else: + self.mha = mha + + @property + def eff_batch_size(self) -> Optional[int]: + raise NotImplementedError() + + def _forward_check_args( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + token_idx: torch.Tensor, + input_pos: int, + ): + for_prefill = input_pos == 0 + if query.ndim != 4: + raise ValueError("query, key, value must be 4D tensors") + eff_batch_size, _, num, _ = query.shape + if for_prefill: + if not (1 <= eff_batch_size <= self.batch_size): + raise ValueError(f"query.shape[0] = {eff_batch_size}, must be in [1, {self.batch_size}]") + if self.max_prefill_length is not None and not (1 <= num <= self.max_prefill_length): + raise ValueError(f"query.shape[2] = {num}, must be in [1, {self.max_prefill_length}]") + else: + if eff_batch_size != self.eff_batch_size: + raise ValueError(f"query.shape[0] = {eff_batch_size} != eff_batch_size = {self.eff_batch_size}") + if not (1 <= num <= self.max_tokens_forward): + raise ValueError(f"query.shape[2] = {num}, must be in [1, {self.max_tokens_forward}]") + q_shape = (eff_batch_size, self.n_head, num, self.head_size) + if query.shape != q_shape: + raise ValueError(f"query.shape = {query.shape}, must be {q_shape}") + k_shape = 
(eff_batch_size, self.n_query_groups, num, self.head_size) + if key.shape != k_shape: + raise ValueError(f"key.shape = {key.shape}, must be {k_shape}") + if value.shape != k_shape: + raise ValueError(f"value.shape = {value.shape}, must be {k_shape}") + t_shape = (eff_batch_size, num) + if token_idx.shape != t_shape: + raise ValueError(f"token_idx.shape = {token_idx.shape}, must be {t_shape}") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + token_idx: torch.Tensor, + input_pos: int, + ) -> torch.Tensor: + self._forward_check_args(query, key, value, token_idx, input_pos) + for_prefill = input_pos == 0 + num = query.shape[2] + self.mha.set_seq_length(input_pos + num, device=query.device) + + # Call :meth:`_forward` or :meth:`_prefill`, depending on `for_prefill` + if for_prefill: + self._prefill(key, value, token_idx) + # In this case, `k_and_v` can vend both keys and values at the same + # time. + k_and_v = DefaultKeysAndValues(key, value) + else: + # Extend KV cache and retrieve key, value tensors to be used. + # Instead of asking for the key and value tensors as such, + # `k_and_v` allows access to them. Since they are never needed at + # the same time, this can save memory. + k_and_v = self._forward(key, value, token_idx) + + # Multi-head self-attention main computation + return_attn_weights = (not for_prefill) and self.update_requires_attn_weights() + y, scores = self.mha( + query=query, + k_and_v=k_and_v, + block_idx=self.block_idx, + input_pos=input_pos, + return_attn_weights=return_attn_weights, + token_positions=None if for_prefill else self.token_positions(), + ) + if scores is not None and return_attn_weights: + # Pass attention weights to KV cache + self._update(attn_weights=scores) + + return y + + def _forward( + self, + key: torch.Tensor, + value: torch.Tensor, + token_idx: torch.Tensor, + ) -> KeysAndValues: + """ + Implements part of :meth:`forward` if `input_pos > 0`. Namely, `key` + and `value` are written into the cache, possibly evicting slots. Then, + an object is returned which provides read access to the full keys and + values buffers. + + Args: + key: New keys, `(eff_batch_size, n_query_groups, num, head_size)`, + where `1 <= num <= max_tokens_forward` + value: New values, `(eff_batch_size, n_query_groups, num, head_size)` + token_idx: Token indices of input sequence, `(eff_batch_size, num)`. + Some KV caches make use of this information. + + Returns: + key_cached, value_cached, `(eff_batch_size, n_query_groups, T, + head_size)`, where `T <= cache_length` is the current cache + length + + """ + raise NotImplementedError() + + def _update(self, *args, **kwargs): + """ + Method called in :meth:`forward` if `input_pos > 0`, passing extra + information depending on the subclass. In general, this method updates + internal scores and takes a decision which slot is evicted upon the + next :meth:`forward` call, if the cache is full. + + One important example are KV caches based on the Heavy Hitter Oracle + (H2O) proposal. These require the attention weights from the current + MLA computation to be passed, and :meth:`update_requires_attn_weights` + has to return `True`. + + Note: The extra information typically scales with `num`, the number of + tokens :meth:`forward` was called for. 
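For illustration, a hedged sketch of such an H2O-style subclass (illustrative only, not part of this PR; `_scores` is a placeholder buffer and real implementations of `_forward` and `_prefill` are still required):

```python
import torch

from litgpt.kvcache import DefaultKVCache


class H2OSketchKVCache(DefaultKVCache):
    """Illustrative sketch: score cache slots by accumulated attention mass."""

    def update_requires_attn_weights(self) -> bool:
        # Ask forward() to pass the attention weights of each update step.
        return True

    def _update(self, *, attn_weights: torch.Tensor) -> None:
        # attn_weights: (eff_batch_size, n_query_groups, num, T). Sum over the
        # query positions and add to running per-slot scores; _forward() would
        # evict the lowest-scoring slot once the cache is full.
        self._scores = self._scores + attn_weights.sum(dim=2)
```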
+ + Args: + *args: Depends on subclass + **kwargs: Depends on subclass + + """ + raise NotImplementedError() + + def update_requires_attn_weights(self) -> bool: + """ + Attention weights are required for KV caches following the Heavy + Hitter Oracle (H2O) proposal. + + Returns: + If `True`, :meth:`update` requires argument `attn_weights`, which + passes current attention weights as + `(eff_batch_size, n_query_groups, num, T)` tensor, where + `T <= cache_length` is the current cache length, and `num` is the + number of tokens in the last recent :meth:`forward` call. + + """ + return False + + def _prefill( + self, + key: torch.Tensor, + value: torch.Tensor, + token_idx: torch.Tensor, + ): + """ + Implements :meth:`forward` for `input_pos=0`. + Starts a generation loop by passing key and value tensors coming from + a prefill with embeddings coming from the prompts. The length must be + `T <= max_prefill_length`. The effective batch size must be + `eff_batch_size <= batch_size`. This batch size is then fixed for + subsequent calls of :meth:`forward` and :meth:`update`. + + Args: + key: Prefill keys, `(eff_batch_size, n_query_groups, T, head_size)` + value: Prefill values, `(eff_batch_size, n_query_groups, T, head_size)` + token_idx: Token indices of input sequence, `(eff_batch_size, T)`. + Some KV caches make use of this information. + + """ + raise NotImplementedError() diff --git a/litgpt/kvcache/baselines.py b/litgpt/kvcache/baselines.py new file mode 100644 index 0000000000..98133c665c --- /dev/null +++ b/litgpt/kvcache/baselines.py @@ -0,0 +1,383 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +from typing import Dict, Optional, Tuple + +import torch + +from litgpt.attention import DefaultKeysAndValues, KeysAndValues +from litgpt.config import Config +from litgpt.kvcache import DefaultKVCache, KVCacheParams +from litgpt.kvcache.utils import bits_for_torch_dtype, bitsize_of + + +class DenseKVCache(DefaultKVCache): + """ + Key-value cache for dense attention. Key and value tensors for all + past tokens are maintained. The cache length is the maximum sequence + length. This cache requires a lot of memory, it can only be used for + moderate cache lengths. + + Note: If the cache is full, :meth:`forward` raises an exception. The cache + buffers are allocated up front and are not enlarged later on. + + """ + + def __init__( + self, + config: Config, + batch_size: int, + block_idx: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: Optional[int] = None, + head_size: Optional[int] = None, + **base_kwargs, + ): + """ + Args: + config: Model config + batch_size: Inference batch size + device: Device for buffers + dtype: Data type for buffers + max_sequence_length: Cache length. If not given, we use + `config.block_size` + head_size: Size of final dimension of buffers. 
Defaults to head + size of model + + """ + if max_sequence_length is None: + max_sequence_length = config.block_size + super().__init__( + config=config, + batch_size=batch_size, + cache_length=max_sequence_length, + block_idx=block_idx, + dtype=dtype, + head_size=head_size, + **base_kwargs, + ) + shape = (batch_size, self.n_query_groups, max_sequence_length, self.head_size) + self.register_buffer("v", torch.zeros(shape, device=device, dtype=dtype), persistent=False) + # TODO: Remove once HF bug fixed + if self._work_around_hf_bug: + shape = shape[:-1] + (self.head_size + 1,) + self.register_buffer("k", torch.zeros(shape, device=device, dtype=dtype), persistent=False) + self.next_position = None + self._eff_batch_size = None + + @property + def device(self) -> torch.device: + return self.k.device + + @property + def eff_batch_size(self) -> Optional[int]: + return self._eff_batch_size + + @property + def next_token_pos(self) -> Optional[int]: + return self.next_position + + @property + def max_tokens_forward(self) -> int: + return self.cache_length + + @property + def max_prefill_length(self) -> Optional[int]: + return self.cache_length + + @property + def current_length(self) -> int: + return self.next_position + + def get_keys_values(self) -> Optional[KeysAndValues]: + if self.eff_batch_size is None or self.next_position is None: + return None + else: + return DefaultKeysAndValues( + self.k[: self.eff_batch_size, :, : self.next_position, :], + self.v[: self.eff_batch_size, :, : self.next_position, :], + ) + + def _forward( + self, + key: torch.Tensor, + value: torch.Tensor, + token_idx: torch.Tensor, + ) -> KeysAndValues: + if self.next_position is None: + raise IndexError("Cache needs to be initialized with 'prefill' before being used") + num = key.shape[2] + if not 1 <= num <= self.max_tokens_forward: + raise ValueError(f"key.shape[2] = {num}, must be in [1, {self.max_tokens_forward}]") + np = self.next_position + if np + num > self.cache_length: + raise IndexError(f"Cache has at most {self.cache_length - np} free slots, cannot add {num} entries") + shape = (self.eff_batch_size, self.n_query_groups, num, self.head_size) + # TODO: Remove once HF bug fixed + if self._work_around_hf_bug: + assert value.shape == shape + shape = shape[:-1] + (self.head_size + 1,) + assert key.shape == shape + elif key.shape != shape or value.shape != shape: + raise ValueError( + f"Shapes of key, value must be {shape}, but key.shape = {key.shape}, value.shape = {value.shape}" + ) + if key.dtype != value.dtype: + raise ValueError(f"key.dtype = {key.dtype} != {value.dtype} = value.dtype") + # Move the buffer to the activation dtype for when AMP is used + # TODO: Is this needed? Other KV caches do not support changing + # `dtype` after creation. 
+ if key.dtype != self.dtype: + self._dtype = key.dtype + self.k = self.k.to(self.dtype) + self.v = self.v.to(self.dtype) + # Append new content to cache + self.k[: self.eff_batch_size, :, np : (np + num), :] = key + self.v[: self.eff_batch_size, :, np : (np + num), :] = value + self.next_position += num + return self.get_keys_values() + + def _update(self, *args, **kwargs): + pass + + def _prefill( + self, + key: torch.Tensor, + value: torch.Tensor, + token_idx: torch.Tensor, + ): + if key.dim() != 4: + raise ValueError("key must have 4 dimensions") + init_length = key.shape[2] + if init_length > self.cache_length: + raise ValueError(f"key.shape[2] = {init_length}, must be at most {self.cache_length}") + eff_batch_size = key.shape[0] + if eff_batch_size > self.batch_size: + raise ValueError(f"key.shape[0] = {eff_batch_size} must be at most batch_size = {self.batch_size}") + shape = (eff_batch_size, self.n_query_groups, init_length, self.head_size) + # TODO: Remove once HF bug fixed + if self._work_around_hf_bug: + assert value.shape == shape + shape = shape[:-1] + (self.head_size + 1,) + assert key.shape == shape + elif key.shape != shape or value.shape != shape: + raise ValueError( + f"Shapes of key, value must be {shape}, but key.shape = {key.shape}, value.shape = {value.shape}" + ) + # Initialize cache content + self.k = self.k.to(key.dtype) + self.v = self.v.to(value.dtype) + self.k[:eff_batch_size, :, :init_length, :] = key + self.v[:eff_batch_size, :, :init_length, :] = value + self.next_position = init_length + self._eff_batch_size = eff_batch_size + + def resize(self, new_length: int): + """ + Shortens the cache content to length `current_length`, removing the + most recently inserted content. Note that this method is currently + supported only for specific KV caches; the cost for supporting it + generally would be high. + + Args: + new_length: New length, must be <= current length + + """ + if not (0 <= new_length <= self.next_position): + raise ValueError(f"current_length = {new_length}, must be in [0, {self.next_position}]") + self.next_position = new_length + + def token_positions(self) -> torch.Tensor: + return ( + torch.arange(self.next_position, device=self.device) + .reshape(1, 1, -1) + .expand(self.eff_batch_size, self.n_query_groups, -1) + ) + + def size_estimate(self) -> Tuple[int, Dict[str, int]]: + sz_buffs = bitsize_of(self.k) + bitsize_of(self.v) + return sz_buffs, dict(buffers=sz_buffs) + + @classmethod + def size_estimate_apriori(cls, params: KVCacheParams, **kwargs) -> Tuple[int, Dict[str, int]]: + cache_length = params.cache_length + dtype = params.dtype + if dtype is None: + raise ValueError("params.dtype must be provided") + numel = params.batch_size * params.n_query_groups * cache_length * params.head_size + sz_buffs = 2 * numel * bits_for_torch_dtype(dtype) + return sz_buffs, dict(buffers=sz_buffs) + + +class LastRecentlyInsertedKVCache(DefaultKVCache): + """ + Baseline key-value cache which stores the last recently inserted + `cache_length` key, value tensors. 
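As a concrete illustration of this sliding-window behaviour, a hedged sketch using the helpers from `litgpt/kvcache/testing.py` added later in this PR (query and key/value shapes coincide here because `n_head == n_query_groups`):

```python
import torch

from litgpt.kvcache import KVCacheParams
from litgpt.kvcache.testing import create_kv_cache, random_keys_values

params = KVCacheParams(
    batch_size=1, n_query_groups=4, cache_length=8, head_size=16,
    n_head=4, device=torch.device("cpu"), dtype=torch.float32,
)
cache = create_kv_cache("lastrec-default", params)

# Prefill with 12 tokens: only the 8 most recently inserted positions survive.
keys, values = random_keys_values(params, num=12)
query, _ = random_keys_values(params, num=12)  # same shape, since n_head == n_query_groups
token_idx = torch.zeros(1, 12, dtype=torch.int64)
cache(query, keys, values, token_idx=token_idx, input_pos=0)
print(cache.token_positions()[0, 0])  # expected: positions 4..11, the oldest 4 were dropped
```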
+ """ + + def __init__( + self, + config: Config, + batch_size: int, + cache_length: int, + block_idx: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + head_size: Optional[int] = None, + **base_kwargs, + ): + super().__init__( + config=config, + batch_size=batch_size, + cache_length=cache_length, + block_idx=block_idx, + dtype=dtype, + head_size=head_size, + **base_kwargs, + ) + shape = (batch_size, self.n_query_groups, cache_length, self.head_size) + self.register_buffer("v", torch.zeros(shape, device=device, dtype=dtype), persistent=False) + # TODO: Remove once HF bug fixed + if self._work_around_hf_bug: + shape = shape[:-1] + (self.head_size + 1,) + self.register_buffer("k", torch.zeros(shape, device=device, dtype=dtype), persistent=False) + self.register_buffer("token_pos", torch.zeros(cache_length, device=device, dtype=torch.int), persistent=False) + self.next_position = None + self._eff_batch_size = None + self.current_length = None + self._next_token_pos = None + + @property + def device(self) -> torch.device: + return self.k.device + + @property + def eff_batch_size(self) -> Optional[int]: + return self._eff_batch_size + + @property + def next_token_pos(self) -> Optional[int]: + return self._next_token_pos + + @property + def max_tokens_forward(self) -> int: + return self.cache_length + + @property + def max_prefill_length(self) -> Optional[int]: + return None + + def get_keys_values(self) -> Optional[KeysAndValues]: + if self.eff_batch_size is None or self.current_length is None: + return None + else: + return DefaultKeysAndValues( + self.k[: self.eff_batch_size, :, : self.current_length, :], + self.v[: self.eff_batch_size, :, : self.current_length, :], + ) + + def _forward( + self, + key: torch.Tensor, + value: torch.Tensor, + token_idx: torch.Tensor, + ) -> KeysAndValues: + if self.next_position is None: + raise IndexError("Cache needs to be initialized with 'prefill' before being used") + if key.ndim != 4: + raise ValueError(f"key must be a 4D tensor, but has shape {key.shape}") + num = key.shape[2] + if not 1 <= num <= self.max_tokens_forward: + raise ValueError(f"key.shape[2] = {num}, must be in [1, {self.max_tokens_forward}]") + shape = (self.eff_batch_size, self.n_query_groups, num, self.head_size) + # TODO: Remove once HF bug fixed + if self._work_around_hf_bug: + assert value.shape == shape + shape = shape[:-1] + (self.head_size + 1,) + assert key.shape == shape + elif key.shape != shape or value.shape != shape: + raise ValueError( + f"Shapes of key, value must be {shape}, but key.shape = {key.shape}, value.shape = {value.shape}" + ) + # Move the buffer to the activation dtype for when AMP is used + self.k = self.k.to(key.dtype) + self.v = self.v.to(value.dtype) + # Append new content to cache + np = self.next_position + num1 = min(num, self.cache_length - np) + self.k[: self.eff_batch_size, :, np : (np + num1), :] = key[:, :, :num1, :] + self.v[: self.eff_batch_size, :, np : (np + num1), :] = value[:, :, :num1, :] + ntp = self._next_token_pos + self.token_pos[np : (np + num1)] = torch.arange(ntp, ntp + num1, device=self.device, dtype=torch.int) + if num1 < num: + diff = num - num1 + self.k[: self.eff_batch_size, :, :diff, :] = key[:, :, num1:, :] + self.v[: self.eff_batch_size, :, :diff, :] = value[:, :, num1:, :] + self.token_pos[:diff] = torch.arange(ntp + num1, ntp + num, device=self.device, dtype=torch.int) + self.next_position = (np + num) % self.cache_length + self.current_length = min(self.current_length + num, 
self.cache_length) + self._next_token_pos += num + return self.get_keys_values() + + def _update(self, *args, **kwargs): + pass + + def _prefill( + self, + key: torch.Tensor, + value: torch.Tensor, + token_idx: torch.Tensor, + ): + if key.dim() != 4: + raise ValueError("key must have 4 dimensions") + init_length = key.shape[2] + eff_init_length = min(init_length, self.cache_length) + eff_batch_size = key.shape[0] + if eff_batch_size > self.batch_size: + raise ValueError(f"key.shape[0] = {eff_batch_size} must be at most batch_size = {self.batch_size}") + shape = (eff_batch_size, self.n_query_groups, init_length, self.head_size) + # TODO: Remove once HF bug fixed + if self._work_around_hf_bug: + assert value.shape == shape + shape = shape[:-1] + (self.head_size + 1,) + assert key.shape == shape + elif key.shape != shape or value.shape != shape: + raise ValueError( + f"Shapes of key, value must be {shape}, but key.shape = {key.shape}, value.shape = {value.shape}" + ) + # Initialize cache content + self.k = self.k.to(key.dtype) + self.v = self.v.to(value.dtype) + self.k[:eff_batch_size, :, :eff_init_length, :] = key[:, :, -eff_init_length:, :] + self.v[:eff_batch_size, :, :eff_init_length, :] = value[:, :, -eff_init_length:, :] + self.token_pos[:eff_init_length] = torch.arange( + init_length - eff_init_length, + init_length, + dtype=self.token_pos.dtype, + device=self.token_pos.device, + ) + self.current_length = eff_init_length + self._next_token_pos = init_length + self.next_position = eff_init_length % self.cache_length + self._eff_batch_size = eff_batch_size + + def token_positions(self) -> torch.Tensor: + return ( + self.token_pos[: self.current_length].reshape(1, 1, -1).expand(self.eff_batch_size, self.n_query_groups, -1) + ) + + def size_estimate(self) -> Tuple[int, Dict[str, int]]: + sz_buffs = bitsize_of(self.k) + bitsize_of(self.v) + sz_pos = bitsize_of(self.token_pos) + return sz_buffs + sz_pos, dict(buffers=sz_buffs, token_pos=sz_pos) + + @classmethod + def size_estimate_apriori(cls, params: KVCacheParams, **kwargs) -> Tuple[int, Dict[str, int]]: + cache_length = params.cache_length + dtype = params.dtype + if dtype is None: + raise ValueError("params.dtype must be provided") + numel = params.batch_size * params.n_query_groups * cache_length * params.head_size + k_and_v = 2 * numel * bits_for_torch_dtype(dtype) + tk_p = cache_length * bits_for_torch_dtype(torch.int) + return k_and_v + tk_p, dict(buffers=k_and_v, token_pos=tk_p) diff --git a/litgpt/kvcache/testing.py b/litgpt/kvcache/testing.py new file mode 100644 index 0000000000..9dc0989ec3 --- /dev/null +++ b/litgpt/kvcache/testing.py @@ -0,0 +1,77 @@ +from typing import Tuple + +import torch + +from litgpt.config import Config +from litgpt.kvcache.base import KVCache, KVCacheParams +from litgpt.kvcache.baselines import DenseKVCache, LastRecentlyInsertedKVCache + +KV_CACHE_NAMES = ( + "dense-default", + "lastrec-default", +) + + +def create_kv_cache( + name: str, + params: KVCacheParams, + block_idx: int = 0, +) -> KVCache: + config = Config( + n_embd=params.n_head * params.head_size, + n_head=params.n_head, + n_query_groups=params.n_query_groups, + ) + from_config_kwargs = dict( + config=config, + batch_size=params.batch_size, + block_idx=block_idx, + device=params.device, + dtype=params.dtype, + ) + + result = None + if name == "dense-default": + result = DenseKVCache(**from_config_kwargs) + elif name == "lastrec-default": + result = LastRecentlyInsertedKVCache(**from_config_kwargs, cache_length=params.cache_length) + + if 
result is None: + raise ValueError(f"name = {name} not supported") + return result + + +def tensor_is_simple(x: torch.Tensor) -> bool: + assert x.ndim > 1 + x = x.view(-1, x.shape[-1]) + other = x[0].unsqueeze(0).expand(*x.shape) + return x.equal(other) + + +def random_tensor( + params: KVCacheParams, + num: int, +) -> torch.Tensor: + shape = (params.batch_size, params.n_query_groups, num, params.head_size) + return torch.randn(*shape, device=params.device, dtype=params.dtype) + + +def random_keys_values( + params: KVCacheParams, + num: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + keys = random_tensor(params, num) + values = random_tensor(params, num) + return keys, values + + +def random_attn_weights( + params: KVCacheParams, + num: int, +) -> torch.Tensor: + attn_weights = torch.randn( + (params.batch_size, params.n_head, num), + device=params.device, + dtype=params.dtype, + ) + return torch.nn.functional.softmax(attn_weights, dim=-1) diff --git a/litgpt/kvcache/utils.py b/litgpt/kvcache/utils.py new file mode 100644 index 0000000000..d618cd8202 --- /dev/null +++ b/litgpt/kvcache/utils.py @@ -0,0 +1,17 @@ +import torch + + +def bits_for_torch_dtype(dtype: torch.dtype) -> int: + """ + Args: + dtype: Torch data type + + Returns: + Number of bits used to represent one number of this type. + + """ + return torch.tensor([], dtype=dtype).element_size() * 8 + + +def bitsize_of(x: torch.Tensor) -> int: + return x.numel() * x.element_size() * 8 diff --git a/litgpt/lora.py b/litgpt/lora.py index 6739b5b040..fd1d21e120 100644 --- a/litgpt/lora.py +++ b/litgpt/lora.py @@ -53,7 +53,9 @@ from typing_extensions import Self import litgpt +from litgpt.attention import MultiHeadSelfAttention from litgpt.config import Config as BaseConfig +from litgpt.kvcache.base import KVCache from litgpt.model import GPT as BaseModel from litgpt.model import Block as BaseBlock from litgpt.model import CausalSelfAttention as BaseCausalSelfAttention @@ -477,7 +479,7 @@ def mlp_class(self) -> Type: class GPT(BaseModel): # Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here. - def __init__(self, config: Config) -> None: + def __init__(self, config: Config, **mha_kwargs) -> None: nn.Module.__init__(self) assert config.padded_vocab_size is not None self.config = config @@ -496,8 +498,11 @@ def __init__(self, config: Config) -> None: ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), ) ) - self.mask_cache: Optional[torch.Tensor] = None + self.mha = MultiHeadSelfAttention(config, **mha_kwargs) self.max_seq_length = self.config.block_size + self._start_of_layer_hook = config.start_of_layer_hook + # Have dense KV caches been created by `set_kv_caches`? 
+ self._default_kv_cache = False @classmethod def from_name(cls, name: str, **kwargs: Any) -> Self: @@ -517,15 +522,25 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa class Block(BaseBlock): - def __init__(self, config: Config, block_idx: int) -> None: - super().__init__(config, block_idx) - self.attn = CausalSelfAttention(config, block_idx) + def __init__( + self, + config: Config, + block_idx: int, + kv_cache: Optional[KVCache] = None, + ) -> None: + super().__init__(config, block_idx, kv_cache) + self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache) self.mlp = config.mlp_class(config) class CausalSelfAttention(BaseCausalSelfAttention): - def __init__(self, config: Config, block_idx: int) -> None: - super().__init__(config, block_idx) + def __init__( + self, + config: Config, + block_idx: int, + kv_cache: Optional[KVCache] = None, + ) -> None: + super().__init__(config, block_idx, kv_cache) # key, query, value projections for all heads, but in a batch shape = (config.n_head + 2 * config.n_query_groups) * config.head_size self.qkv = LoRAQKVLinear( @@ -549,6 +564,11 @@ def __init__(self, config: Config, block_idx: int) -> None: use_r=config.lora_projection, ) + @property + def device(self) -> Optional[torch.device]: + w = self.qkv.linear.weight + return None if w is None else w.device + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base and/or legacy checkpoints.""" mapping = { diff --git a/litgpt/model.py b/litgpt/model.py index c3c1833db9..b43af45f8a 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -6,7 +6,6 @@ https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model. """ -import math from functools import partial from typing import Any, List, Optional, Tuple, Union @@ -15,17 +14,37 @@ import torch.nn.functional as F from typing_extensions import Self -from litgpt.config import Config +from litgpt.attention import ( + DefaultKeysAndValues, + MultiHeadSelfAttention, + do_softcapping, +) +from litgpt.config import Config, StartOfLayerHook +from litgpt.kvcache import ( + DenseKVCache, + KVCache, + KVCacheParams, +) from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble +from litgpt.utils import batched_index_select class GPT(nn.Module): - def __init__(self, config: Config) -> None: + def __init__(self, config: Config, **mha_kwargs) -> None: + """ + Args: + config: Configuration parameters + + """ super().__init__() assert config.padded_vocab_size is not None self.config = config - self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias) + self.lm_head = nn.Linear( + config.n_embd, + config.padded_vocab_size, + bias=config.lm_head_bias, + ) self.transformer = nn.ModuleDict( dict( wte=nn.Embedding(config.padded_vocab_size, config.n_embd), @@ -33,8 +52,11 @@ def __init__(self, config: Config) -> None: ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), ) ) - self.mask_cache: Optional[torch.Tensor] = None - self.max_seq_length = self.config.block_size + self.mha = MultiHeadSelfAttention(config, **mha_kwargs) + self.max_seq_length = config.block_size + self._start_of_layer_hook = config.start_of_layer_hook + # Have dense KV caches been created by `set_kv_caches`? 
+ self._default_kv_cache = False @property def max_seq_length(self) -> int: @@ -43,8 +65,13 @@ def max_seq_length(self) -> int: @max_seq_length.setter def max_seq_length(self, value: int) -> None: """ - When doing inference, the sequences used might be shorter than the model's context length. - This allows setting a smaller number to avoid allocating unused memory + When doing inference, the sequences used might be shorter than the + model's context length. This allows setting a smaller number to avoid + allocating unused memory. + + If KV caches are of type `DenseKVCache`, and they are too small to hold + `value` entries, a warning message is printed. + """ if value > self.config.block_size: raise ValueError( @@ -52,24 +79,174 @@ def max_seq_length(self, value: int) -> None: " This is likely because the input text exceeds the supported context length of this model." ) self._max_seq_length = value + # RoPE cache: + # `cos`, `sin` of shape `(max_seq_length, config.rope_n_elem)` + # More precisely, the RoPE cache is recomputed only if + # `max_seq_length` increases. + # Note: The RoPE cache is independent of KV caches, since positional + # encoding is done (on query and key vectors) before the KV cache + # gets involved (and the KV cache stores encoded key tensors). + if not hasattr(self, "cos") or self.cos.size(0) < value: + self.reset_caches() + # KV caches + # We do not change them here, but output a warning if default caches are + # too small + for l_ix, block in enumerate(self.transformer.h): + attn = block.attn + kv_cache = attn.kv_cache + if kv_cache is not None and isinstance(kv_cache, DenseKVCache) and kv_cache.cache_length < value: + print( + f"KV cache for layer {l_ix} too small: Call 'set_kv_caches(batch_size={kv_cache.batch_size}, max_seq_length={value}) before inference" + ) + break + # Multi-head attention + self.mha.set_seq_length(value, device=self.cos.device) + + def reset_caches(self): if not hasattr(self, "cos"): # first call cos, sin = self.rope_cache() self.register_buffer("cos", cos, persistent=False) self.register_buffer("sin", sin, persistent=False) - # override - elif value != self.cos.size(0): + else: self.cos, self.sin = self.rope_cache(device=self.cos.device) - # the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know - # if the kv cache is expected - if self.mask_cache is not None and self.mask_cache.shape[-1] < value: - print( - f"Warning: KV cache has length {self.mask_cache.shape[-1]} < {value} = max_seq_length. Call 'set_kv_cache' before doing any forwards!" - ) + + def are_kv_caches_assigned(self) -> bool: + status = [block.attn.kv_cache is not None for block in self.transformer.h] + result = any(status) + if result and not all(status): + raise IndexError("Some layers have KV caches assigned, but not all") + return result + + def assign_kv_caches(self, kv_caches: List[KVCache]): + """ + Assigns specific KV caches to the multi-head attention blocks + of each layer. This can only be done if no caches have been + assigned or created (see :meth:`set_kv_caches`) before. + + KV caches are required for inference (i.e., calling :meth:`forward` with + `input_pos` argument). If no KV caches are assigned, inference calls + fail. 
+ + Args: + kv_caches: KV caches, one for each layer of the model + + """ + if self.are_kv_caches_assigned(): + raise ValueError("Model has KV caches assigned already") + if len(kv_caches) != self.config.n_layer: + raise ValueError(f"kv_caches must have one entry per layer, so {self.config.n_layer} entries") + batch_size = kv_caches[0].batch_size + dtype = kv_caches[0].dtype + for cache, block in zip(kv_caches, self.transformer.h): + self._check_kv_cache(self.config, cache, batch_size, dtype) + device = block.attn.device + if device is not None: + block.attn.kv_cache = cache.to(device=device) + else: + block.attn.kv_cache = cache + + def set_kv_caches( + self, + batch_size: int, + dtype: Optional[torch.dtype] = None, + max_seq_length: Optional[int] = None, + ): + """ + This method can be called only if KV caches have not been assigned + with :meth:`assign_kv_caches`. It creates default (dense) KV caches + for every layer. These may require a lot of memory. If this is an + issue, consider :meth:`assign_kv_caches` with KV caches of restricted + size. + + KV caches are required for inference (i.e., calling :meth:`forward` with + `input_pos` argument). If no KV caches are assigned, inference calls + fail. + + Args: + batch_size: Inference batch size + dtype: Data type for buffers + max_seq_length: Cache length. If not given, we use + `self.max_seq_length` + + """ + if self.are_kv_caches_assigned() and not self._default_kv_cache: + raise ValueError("Model has KV caches assigned already") + if max_seq_length is None: + max_seq_length = self.max_seq_length + for block in self.transformer.h: + attn = block.attn + device = attn.device + kv_cache = attn.kv_cache + if ( + kv_cache is None + or kv_cache.batch_size != batch_size + or kv_cache.cache_length != max_seq_length + or kv_cache.device != device + or kv_cache.dtype != dtype + ): + if kv_cache is not None: + device = kv_cache.device if device is None else device + dtype = kv_cache.dtype if dtype is None else dtype + attn.create_default_kv_cache( + batch_size=batch_size, + device=device, + dtype=dtype, + max_sequence_length=max_seq_length, + ) + self._default_kv_cache = True def reset_parameters(self) -> None: # Trigger resetting the rope-cache self.cos, self.sin = self.rope_cache(device=self.cos.device) + self.mha.set_seq_length(self.max_seq_length, device=self.cos.device) + + def set_start_of_layer_hook( + self, + hook: Optional[StartOfLayerHook], + ): + """ + Sets a function `hook(x, block_idx, input_pos)`, which is called + in :meth:`forward` at the start of each layer. Here, `x` is the + layer input, `block_idx` the number of the layer, and `input_pos` + the position in the sequence. The hook is called with the output + of the final layer (input of head model), where + `block_idx=self.config.n_layer`. + + The default start of layer hook is `self.config.start_of_layer_hook`. + This is overwritten here. 
+ + Args: + hook: Hook function to be set, or `None` to remove hook + + """ + self._start_of_layer_hook = hook + + @staticmethod + def _check_kv_cache( + config: Config, + kv_cache: KVCache, + batch_size: int, + dtype: torch.dtype, + ): + params = kv_cache.get_params() + if config.n_query_groups != params.n_query_groups: + raise ValueError( + f"config and kv_cache not compatible: config.n_query_groups = {config.n_query_groups} != {params.n_query_groups} = kv_cache.n_query_groups" + ) + if config.n_head != params.n_head: + raise ValueError( + f"config and kv_cache not compatible: config.n_head = {config.n_head} != {params.n_head} = kv_cache.n_head" + ) + head_size = config.n_embd // config.n_head + if head_size != params.head_size: + raise ValueError( + f"config and kv_cache not compatible: config.head_size = {head_size} != {params.head_size} = kv_cache.head_size" + ) + if batch_size != params.batch_size: + raise ValueError(f"kv_cache.batch_size = {params.batch_size}, must be {batch_size}") + if dtype != params.dtype: + raise ValueError(f"kv_cache.dtype = {params.dtype}, must be {dtype}") def _init_weights(self, module: nn.Module) -> None: """Meant to be used with `gpt.apply(gpt._init_weights)`.""" @@ -83,95 +260,133 @@ def _init_weights(self, module: nn.Module) -> None: def forward( self, idx: torch.Tensor, - input_pos: Optional[torch.Tensor] = None, - input_pos_maxp1: Optional[int] = None, + input_pos: Optional[int] = None, lm_head_chunk_size: int = 0, + skip_lm_head: bool = False, ) -> Union[torch.Tensor, List[torch.Tensor]]: """ - If `input_pos` is provided, the KV cache uses K and V vectors for - positions smaller than entries in `input_pos`. For efficiency, pass - `input_pos_maxp1` as `max(input_pos) + 1` if already available from - your forward algorithm. This slices the KV cache buffers and speeds - up multi-head attention. - - Without `input_pos_maxp1`, the computation uses the full KV cache - (`max_seq_length`) with masking applied. Note that inferring - `input_pos_maxp1` from `input_pos` causes graph breaks and prevents - compilation. + There are two different contexts in which this method is called: + + - Training: `input_pos` not given. KV cache is not needed. + - Inference, `input_pos` is given. There are two cases: `input_pos=0` + (prefill) and `input_pos > 0` (generation). For prefill, KV caches + must have been assigned (:meth:`assign_kv_caches` or + :meth:`set_kv_caches`). We must have + `T <= model.kv_cache_max_prefill_length()`. + - For generation, KV caches must have been assigned + (:meth:`assign_kv_caches` or :meth:`set_kv_caches`). We check that + `input_pos == kv_cache.next_token_pos`. Note that `T > 1` is + permitted here as well. + + Note: If this method is called with `input_pos=0` (prefill) after + generation calls, a new inference sequence is started. The batch + size for the new sequence can be different. + + Token generation (`input_pos > 0`) and `T > 1`: + + This situation is non-standard, since `idx` needs to provide tokens at + positions `input_pos:(input_pos + T)`, whereas the logits are for + generating tokens at `(input_pos + 1):(input_pos + T + 1)`, so only the + last position is needed to generate a new token. Use cases: + - Updating KV caches sequentially if prompt size is larger than max + prefill length of cache + - Speculative decoding. Here, `idx` comes from the cheaper proposal + model, and the logits are needed for the accept/reject probabilities. Args: idx: Token indices of input sequences, shape `(B, T)`, where `B` is batch size. 
- input_pos: Optional. Positions of input tokens. The default is - `arange(T)`. Can have shape `(T,)` or `(B, T)` (batched index). - input_pos_maxp1: Optional. See above. + input_pos: See above. Defaults to `None` lm_head_chunk_size: Optional. If `lm_head_chunk_size > 0`, the final `lm_head` computation is done in chunks of this size. + skip_lm_head: If `True`, we do not apply the final LM head + `self.lm_head`. Returns: Logit outputs, shape `(B, T, config.padded_vocab_size)`. If `lm_head_chunk_size > 0`, this is a list of chunks of shape `(B, lm_head_chunk_size, config.padded_vocab_size)`, the final entry can be shorter. + If `skip_lm_head` is `True`, we return the final layer outputs, + shape `(B, T, config.n_embd)`. """ + if idx.ndim == 1: + idx = idx.unsqueeze(0) + elif idx.ndim != 2: + raise ValueError(f"idx must be 1D or 2D tensor, but idx.shape = {idx.shape}") T = idx.size(1) if self.max_seq_length < T: raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.") - - if input_pos is not None: # use the kv cache - if input_pos.dim() > 2: - # otherwise, things go wrong in `apply_rope` - raise ValueError(f"input_pos must have 1 or 2 dimensions, input_pos.shape = {input_pos.shape}") - if input_pos.shape[-1] != T: - raise ValueError(f"input_pos.shape[-1] = {input_pos.shape[-1]} != {T} = idx.shape[1], must be the same") - cos = batched_index_select(self.cos, 0, input_pos) - sin = batched_index_select(self.sin, 0, input_pos) - if input_pos.dim() == 1: - cos = cos.unsqueeze(0) - sin = sin.unsqueeze(0) - if self.mask_cache is None: - raise TypeError("You need to call `gpt.set_kv_cache()`") - mask = batched_index_select(self.mask_cache, 2, input_pos) - if mask.dim() > 4: - # the mask cache has a batch dim of 1 in addition to the one - # we get if input_pos has a batch dimension - mask = mask.view(*(mask.shape[0:1] + mask.shape[2:])) - if input_pos_maxp1 is not None: - # Shorten final dimension so it just covers all `input_pos` entries - if input_pos_maxp1 > self.max_seq_length: - raise ValueError(f"Positions in 'input_pos' must be in [0,{self.max_seq_length})") - mask = mask[..., :input_pos_maxp1] + for_prefill = False + if input_pos is not None: + # Few tokens generation. This needs a KV cache. If none is assigned, + # the call fails + if not self.are_kv_caches_assigned(): + raise ValueError( + "KV caches are not assigned. Assign KV caches with 'assign_kv_caches' or create default caches with 'set_kv_caches'" + ) + for_prefill = input_pos == 0 + if not for_prefill: + for block_idx, block in enumerate(self.transformer.h): + kv_cache = block.attn.kv_cache + if kv_cache.next_token_pos is None: + raise ValueError("Inference calls need to start with pre-fill, i.e. 
'input_pos=0'") + if kv_cache.next_token_pos != input_pos: + raise ValueError( + f"KV cache for layer {block_idx}: input_pos = {input_pos} != {kv_cache.next_token_pos} = kv_cache.next_token_pos" + ) + if kv_cache.max_tokens_forward < T: + raise ValueError( + f"KV cache for layer {block_idx}: T = {T}, must be <= max_tokens_forward = {kv_cache.max_tokens_forward}" + ) + + if self.config.rope_n_elem > 0: + input_pos_array = torch.arange(input_pos, input_pos + T, device=self.cos.device, dtype=torch.int64) + cos = batched_index_select(self.cos, 0, input_pos_array).unsqueeze(0) + sin = batched_index_select(self.sin, 0, input_pos_array).unsqueeze(0) + else: + cos = sin = None else: - # unsqueeze to have a batch dimension + # Unsqueeze to have a batch dimension cos = self.cos[:T].unsqueeze(0) sin = self.sin[:T].unsqueeze(0) - # `cos`, `sin` have shape (1, T, config.rope_n_elem) - mask = None # defaults to causal mask - input_pos_maxp1 = None + # `cos`, `sin` have shape `(1, T, config.rope_n_elem)`, or shape + # `(1, T, config.rope_n_elem, 2)` x = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd) if self.config.scale_embeddings: x = x * torch.tensor(self.config.n_embd**0.5, dtype=x.dtype) + hook = self._start_of_layer_hook for block_idx, block in enumerate(self.transformer.h): + if for_prefill: + # Complain if batch size of cache is too small + eff_batch_size = x.shape[0] + attn = block.attn + if attn.kv_cache.batch_size < eff_batch_size: + raise ValueError( + f"Batch size {eff_batch_size} is too large for KV cache layer {block_idx} (batch size {attn.kv_cache.batch_size}). Use 'assign_kv_caches' or `set_kv_caches'" + ) + if hook is not None: + # Call start of layer hook, passing detached layer input + hook(x.detach(), block_idx, input_pos) if self.config.rope_indices is not None: - x = block( - x, - cos[..., self.config.rope_indices[block_idx]], - sin[..., self.config.rope_indices[block_idx]], - mask, - input_pos, - input_pos_maxp1, - ) + # Select global (0) or local (1) variant + _cos = cos[..., self.config.rope_indices[block_idx]] + _sin = sin[..., self.config.rope_indices[block_idx]] else: - x = block(x, cos, sin, mask, input_pos, input_pos_maxp1) + _cos = cos + _sin = sin + x = block(x, _cos, _sin, idx, self.mha, input_pos) + + if hook is not None: + # Hook is also called for the input to the head block + hook(x.detach(), self.config.n_layer, input_pos) x = self.transformer.ln_f(x) - clamp_head = ( - partial(do_softcapping, thresh=self.config.final_logit_softcapping) - if self.config.final_logit_softcapping is not None - else nn.Identity() - ) + if skip_lm_head: + return x + clamp_head = partial(do_softcapping, thresh=self.config.final_logit_softcapping) if lm_head_chunk_size > 0: # chunk the lm head logits to reduce the peak memory used by autograd return [clamp_head(self.lm_head(x_i)) for x_i in x.split(lm_head_chunk_size, dim=1)] @@ -182,10 +397,23 @@ def forward( def from_name(cls, name: str, **kwargs: Any) -> Self: return cls(Config.from_name(name, **kwargs)) - def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tensor, torch.Tensor]: + def rope_cache( + self, + device: Optional[torch.device] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Recomputes the RoPE cache, consisting of tensors `cos`, `sin`. + + Args: + device: Device for RoPE cache tensors + + Returns: + `(cos, sin)`, each of shape `(max_seq_length, config.rope_n_elem)` + or of shape `(max_seq_length, config.rope_n_elem, 2)`. 
+ + """ if self.config.rope_adjustments is None: extra_config = None - else: adjusted_params_required = ["factor", "low_freq_factor", "high_freq_factor", "original_max_seq_len"] params_present = [param in self.config.rope_adjustments for param in adjusted_params_required] @@ -220,42 +448,62 @@ def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tenso rope_local_base_freq=self.config.rope_local_base_freq, ) - def set_kv_cache( - self, - batch_size: int, - max_seq_length: Optional[int] = None, - rope_cache_length: Optional[int] = None, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ) -> None: - if rope_cache_length is None: - if len(self.cos.shape) == 2: - rope_cache_length = self.cos.size(-1) - else: - rope_cache_length = self.cos[..., 0].size(-1) + def clear_kv_caches(self) -> None: + """ + Note that KV cache objects are removed only if they have not been + assigned with :meth:`assign_kv_caches`. - if max_seq_length is None: - max_seq_length = self.max_seq_length + """ + if self._default_kv_cache: + for block in self.transformer.h: + block.attn.kv_cache = None + self._default_kv_cache = False - # initialize the kv cache for all blocks - for block in self.transformer.h: - block.attn.kv_cache = block.attn.build_kv_cache( - batch_size, - max_seq_length, - rope_cache_length, - device, - dtype, - ) + def get_kv_cache_params(self, layer_idx: int) -> Optional[KVCacheParams]: + """ + Args: + layer_idx: Layer for which KV cache params are requested - if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length: - # passing `attn_mask` to SDPA disables the flash implementation. since we only need the mask - # for the kv-cache support (only during inference), we only create it in that situation - self.mask_cache = build_mask_cache(max_seq_length, device) + Returns: + Parameters for KV caches (see above), or `None` if KV caches are + not assigned. - def clear_kv_cache(self) -> None: - self.mask_cache = None - for block in self.transformer.h: - block.attn.kv_cache = None + """ + if not (0 <= layer_idx < self.config.n_layer): + raise IndexError(f"layer_idx={layer_idx}, must be in [0, {self.config.n_layer})") + kv_cache = self.transformer.h[layer_idx].attn.kv_cache + return None if kv_cache is None else kv_cache.get_params() + + def kv_cache_max_tokens_forward(self) -> Optional[int]: + """ + Returns: + Smallest `max_tokens_forward` over all KV caches, or `None` if KV + caches are not assigned. + + """ + caches = [layer.attn.kv_cache for layer in self.transformer.h] + if any(cache is None for cache in caches): + return None + else: + return min(kvc.max_tokens_forward for kvc in caches) + + def kv_cache_max_prefill_length(self) -> Optional[int]: + """ + Returns: + Smallest `max_prefill_length` over all KV caches, or `None` if KV + caches are not assigned or if `max_prefill_length is None` for all + KV caches. 
+ + """ + caches = [layer.attn.kv_cache for layer in self.transformer.h] + if any(cache is None for cache in caches): + return None + else: + mlps = [kvc.max_prefill_length for kvc in caches] + if all(mlp is None for mlp in mlps): + return None + else: + return min(mlp for mlp in mlps if mlp is not None) class Block(nn.Module): @@ -263,6 +511,7 @@ def __init__( self, config: Config, block_idx: int, + kv_cache: Optional[KVCache] = None, ) -> None: super().__init__() if not config.parallel_residual and config.shared_attention_norm: @@ -272,7 +521,7 @@ def __init__( ) self.norm_1 = nn.Identity() if not config.norm_1 else config.norm_class(config.n_embd, eps=config.norm_eps) - self.attn = CausalSelfAttention(config, block_idx) + self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache) self.post_attention_norm = ( config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_attention_norm else nn.Identity() ) @@ -285,7 +534,6 @@ def __init__( self.post_mlp_norm = ( config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_mlp_norm else nn.Identity() ) - self.config = config def forward( @@ -293,9 +541,9 @@ def forward( x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, - mask: Optional[torch.Tensor] = None, - input_pos: Optional[torch.Tensor] = None, - input_pos_maxp1: Optional[int] = None, + token_idx: torch.Tensor, + mha: MultiHeadSelfAttention, + input_pos: Optional[int] = None, ) -> torch.Tensor: """ Non-parallel residual Parallel residual @@ -319,7 +567,14 @@ def forward( """ x_normed = self.norm_1(x) - attention_output = self.attn(x_normed, cos, sin, mask, input_pos, input_pos_maxp1) + attention_output = self.attn( + x_normed, + cos=cos, + sin=sin, + token_idx=token_idx, + mha=mha, + input_pos=input_pos, + ) attention_output = self.post_attention_norm(attention_output) if self.config.parallel_residual: @@ -334,7 +589,12 @@ def forward( class CausalSelfAttention(nn.Module): - def __init__(self, config: Config, block_idx: int) -> None: + def __init__( + self, + config: Config, + block_idx: int, + kv_cache: Optional[KVCache] = None, + ) -> None: super().__init__() # key, query and value projections for all heads, but in a batch self.qkv = nn.Linear( @@ -344,11 +604,8 @@ def __init__(self, config: Config, block_idx: int) -> None: ) # output projection self.proj = nn.Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) - # disabled by default - self.kv_cache: Optional[KVCache] = None - self.apply_sliding_window_attention = False - if config.sliding_window_size is not None and config.sliding_window_indices is not None: - self.apply_sliding_window_attention = config.sliding_window_indices[block_idx] + # KV cache (needed for inference) + self.kv_cache = kv_cache if config.norm_qk: norm_q_size = config.n_head * config.head_size if config.norm_qk_type == "olmo2" else config.head_size @@ -363,15 +620,32 @@ def __init__(self, config: Config, block_idx: int) -> None: self.config = config self.block_idx = block_idx + @property + def device(self) -> Optional[torch.device]: + w = self.qkv.weight + return None if w is None else w.device + def forward( self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, - mask: Optional[torch.Tensor] = None, - input_pos: Optional[torch.Tensor] = None, - input_pos_maxp1: Optional[int] = None, + token_idx: torch.Tensor, + mha: MultiHeadSelfAttention, + input_pos: Optional[int] = None, ) -> torch.Tensor: + """ + Args: + x: Input tensor + cos: RoPE parameters + sin: RoPE parameters + token_idx: Token indexes 
corresponding to `x` + mha: Multi-head self-attention code + input_pos: See :meth:`GPT.forward` + + Returns: + Output tensor + """ # Notation: # - B | batch size # - T | time-step (sequence length) @@ -402,6 +676,23 @@ def forward( n_query_groups = self.config.n_query_groups rope_n_elem = self.config.rope_n_elem B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + if input_pos is not None: + for_prefill = input_pos == 0 + if self.kv_cache is None: + raise ValueError( + "KV caches are not assigned. Assign KV caches with 'assign_kv_caches' or create default caches with 'set_kv_caches'" + ) + if not for_prefill: + if self.kv_cache.next_token_pos is None: + raise ValueError("Inference calls need to start with pre-fill, i.e. 'input_pos=0'") + if self.kv_cache.next_token_pos != input_pos: + raise ValueError( + f"KV cache: input_pos = {input_pos} != {self.kv_cache.next_token_pos} = kv_cache.next_token_pos" + ) + if self.kv_cache.max_tokens_forward < T: + raise ValueError( + f"KV cache: T = {T}, must be <= max_tokens_forward = {self.kv_cache.max_tokens_forward}" + ) # Perform a single multiplication operation using a combined QKV matrix to calculate `query`, `key`, and `value` # instead of individually multiplying the input `x` with the respective weight matrices. @@ -424,118 +715,75 @@ def forward( # The original GQA paper is followed here and the term query groups is used. # alternative notation: Query groups are also referred to as KV groups. q = q.view(B, T, n_head, head_size) # (B, T, nh_q, hs) - k = k.view(B, T, n_query_groups, head_size) # (B, T, n_query_groups, hs) - v = v.view(B, T, n_query_groups, head_size) # (B, T, n_query_groups, hs) + k = k.view(B, T, n_query_groups, head_size) # (B, T, nh_k, hs) + v = v.view(B, T, n_query_groups, head_size) # (B, T, nh_k, hs) # The tensors `query`, `key`, and `value` are now accurately structured: within each batch element (B), there are - # multiple heads (nh), and within each head, there is a sequence of elements (T), each represented by a vector + # multiple heads (nh_q), and within each head, there is a sequence of elements (T), each represented by a vector # of size `hs`. + # Note that `nh_k` can be smaller than `nh_q` (but the latter must be a + # multiple of the former). This works with the + # `scaled_dot_product_attention` implementations below. q = q.transpose(1, 2) # (B, nh_q, T, hs) k = k.transpose(1, 2) # (B, nh_k, T, hs) - v = v.transpose(1, 2) # (B, nh_v, T, hs) + v = v.transpose(1, 2) # (B, nh_k, T, hs) if self.config.norm_qk and self.config.norm_qk_type == "default": q = self.norm_q(q) k = self.norm_k(k) # Unlike standard positional embeddings rotary embeddings must be applied at every layer. 
- q_roped = apply_rope(q[..., :rope_n_elem], cos, sin) - k_roped = apply_rope(k[..., :rope_n_elem], cos, sin) - q = torch.cat((q_roped, q[..., rope_n_elem:]), dim=-1) # (B, nh_q, T, hs) - k = torch.cat((k_roped, k[..., rope_n_elem:]), dim=-1) # (B, nh_k, T, hs) + if rope_n_elem > 0: + q_roped = apply_rope(q[..., :rope_n_elem], cos, sin) + k_roped = apply_rope(k[..., :rope_n_elem], cos, sin) + q = torch.cat((q_roped, q[..., rope_n_elem:]), dim=-1) # (B, nh_q, T, hs) + k = torch.cat((k_roped, k[..., rope_n_elem:]), dim=-1) # (B, nh_k, T, hs) + + # Inner part of multi-head self-attention computation + if input_pos is None: + # Default causal self-attention + y, _ = mha( + query=q, + k_and_v=DefaultKeysAndValues(k, v), + block_idx=self.block_idx, + ) + else: + # Defer this to KV cache + y = self.kv_cache( + query=q, + key=k, + value=v, + token_idx=token_idx, + input_pos=input_pos, + ) - # Apply kv-cache during inference. - if input_pos is not None: - if not isinstance(self.kv_cache, KVCache): - raise TypeError("You need to call `gpt.set_kv_cache()`") - k, v = self.kv_cache(input_pos, k, v) - if input_pos_maxp1 is not None: - # Subselect along sequence dimension - k = k[..., :input_pos_maxp1, :] - v = v[..., :input_pos_maxp1, :] - # k, v: (B, nh_k, input_pos_maxp1, hs) - # If input_pos_maxp1 is None -> max_seq_length - - # Grouped queries: balance the number of heads across all three matrices. - # NOTE: flash attention requires it in training mode. - # Multi-query: this step can be skipped since there is only 1 head, allowing us to use broadcasting. - if n_query_groups != n_head and (input_pos is None or n_query_groups != 1): - q_per_kv = n_head // n_query_groups - k = k.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) - v = v.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) - - if self.apply_sliding_window_attention: - """ - Global Window Sliding window Sliding window - attention mask + bias = attention mask - ┌────────────────────────┐ ┌───────────────────────┐ ┌─────────────────────────┐ - │ True False False False │ │ True True True True │ │ True False False False │ - │ True True False False │ │ True True True True │ │ True True False False │ - │ True True True False │ │ False True True True │ │ False True True False │ - │ True True True True │ │ False False True True │ │ False False True True │ - └────────────────────────┘ └───────────────────────┘ └─────────────────────────┘ - """ - if mask is None: - mask = torch.ones(T, T, dtype=q.dtype, device=q.device).triu(diagonal=1) - mask.masked_fill_(mask.bool(), float("-inf")) - mask = mask.view(1, 1, *mask.shape) - sliding_window_bias = torch.ones_like(mask).tril(diagonal=-self.config.sliding_window_size) - sliding_window_bias.masked_fill_(sliding_window_bias.bool(), float("-inf")) - mask += sliding_window_bias - - # Efficient attention using Flash Attention CUDA kernels. - # NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled. - # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs) - y = self.scaled_dot_product_attention(q, k, v, mask) - - # Re-assemble all head outputs side by side. - y = y.reshape(B, T, head_size * n_head) - - # Output projection. 
+ # Output projection + y = self._transform_output(y, query=q, mha=mha) return self.proj(y) # (B, T, C) - def scaled_dot_product_attention( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None + def _transform_output( + self, + y: torch.Tensor, + query: torch.Tensor, + mha: MultiHeadSelfAttention, ) -> torch.Tensor: - scale = 1.0 / math.sqrt(self.config.attention_scores_scalar or self.config.head_size) - - # with softcapping we cannot use SDPA - if self.config.attention_logit_softcapping is not None: - scores = q @ k.mT * scale - scores = do_softcapping(scores, self.config.attention_logit_softcapping) - if mask is None: - mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1) - mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min) - scores = scores + mask - scores = F.softmax(scores, dim=-1, dtype=torch.float).to(dtype=q.dtype) - y = scores @ v - else: - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None - ) - return y.transpose(1, 2) + return y - def build_kv_cache( + def create_default_kv_cache( self, batch_size: int, - max_seq_length: int, - rope_cache_length: Optional[int] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - ) -> "KVCache": - v_shape = (batch_size, self.config.n_query_groups, max_seq_length, self.config.head_size) - if rope_cache_length is None: - if self.config.rotary_percentage != 1.0: - raise TypeError("Please pass the `rope_cache_length=gpt.cos.size(-1)` value") - k_shape = v_shape - else: - k_shape = ( - batch_size, - self.config.n_query_groups, - max_seq_length, - rope_cache_length + self.config.head_size - self.config.rope_n_elem, - ) - return KVCache(k_shape, v_shape, device=device, dtype=dtype) + max_sequence_length: Optional[int] = None, + ): + self.kv_cache = DenseKVCache( + config=self.config, + batch_size=batch_size, + block_idx=self.block_idx, + device=device, + dtype=dtype, + max_sequence_length=max_sequence_length, + ) def _load_from_state_dict(self, state_dict: dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with legacy checkpoints.""" @@ -633,11 +881,17 @@ def build_rope_cache( device (torch.device, optional): Device for tensor allocations. base (int, optional): Base for computing inverse frequencies. condense_ratio (int, optional): Ratio to condense the position indices. - extra_config (dict, optional): Configuration parameters for frequency adjustments (used by Llama 3.1 and 3.2) + extra_config (dict, optional): Configuration parameters for + frequency adjustments (used by Llama 3.1 and 3.2) + rope_local_base_freq: If given, this is an alternative value for + `base`. In this case, the returned tensors have an extra dimension. Returns: Tuple[torch.Tensor, torch.Tensor]: Cosine and sine caches for RoPE. - Shapes are `(seq_len, n_elem)`. + Shapes are `(seq_len, n_elem)` if `rope_local_base_freq` is not + given, otherwise `(seq_len, n_elem, 2)`, so that `[..., 0]` is for + `base`, and `[..., 1]` for `rope_local_base_freq`. + """ # Compute the inverse frequencies theta @@ -668,11 +922,13 @@ def build_rope_cache( idx_theta = torch.outer(seq_idx, theta).repeat(1, 2) # If `n_elem` is odd, the final dimension of `idx_theta` has size # `n_elem + 1`, so need to cut something off. + # Due to a current bug in Hugging Face, in the case `n_elem == 1`, we leave # `idx_theta`, `cos`, `sin` as is. Things work out in `apply_rope` due to # broadcasting. 
If we shorten `idx_theta`, unit tests comparing to # Hugging Face fail. # https://github.com/huggingface/transformers/issues/35233 + # TODO: Remove `> 1` once HF bug is fixed! if idx_theta.shape[-1] > n_elem > 1: idx_theta = idx_theta[..., :n_elem] @@ -682,83 +938,14 @@ def build_rope_cache( local_theta = 1.0 / (rope_local_base_freq ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)) local_idx_theta = torch.outer(seq_idx, local_theta) local_idx_theta = local_idx_theta.repeat(1, 2) + # TODO: Remove `> 1` once HF bug is fixed! if local_idx_theta.shape[-1] > n_elem > 1: local_idx_theta = local_idx_theta[..., :n_elem] - idx_theta = torch.stack((idx_theta, local_idx_theta), dim=-1) return torch.cos(idx_theta), torch.sin(idx_theta) -def batched_index_select(t, dim, idx): - """index_select for batched index and unbatched t""" - if idx.dim() == 1: - return torch.index_select(t, dim, idx) - - *batch_shape, idx_size = idx.shape - res = torch.index_select(t, dim, idx.reshape(-1)) # flat index - # split out single batch idx - res = res.view(*t.shape[:dim], -1, idx_size, *t.shape[dim + 1 :]) - if dim > 0: - # move batch dim to front, this is np.rollaxis(res, dim, 0) for tensors - dims = [dim] + list(range(res.dim())) - del dims[dim + 1] - res = res.permute(dims) - # unflatten batch dims - res = res.view(*batch_shape, *res.shape[1:]) - return res - - -def batched_index_copy_(t, dim, idx, val): - """Index copy for batched t, idx, val""" - - if t.device.type == "mps": - # Normalize negative dimensions - if dim < 0: - dim = t.dim() + dim - if idx.dim() == 1: - idx_shape = [1] * val.dim() - idx_shape[dim] = -1 - idx_expanded = idx.view(*idx_shape) - idx_expanded = idx_expanded.expand_as(val) - t.scatter_(dim, idx_expanded, val) - return t - - elif idx.dim() == 2: - assert dim != 0, "Cannot index the batch dimension" - batch_size = idx.size(0) - idx_size = idx.size(1) - assert batch_size == t.size(0) == val.size(0) - - idx_shape = [batch_size] + [1] * (val.dim() - 1) - idx_shape[dim] = idx_size - idx_expanded = idx.view(*idx_shape) - idx_expanded = idx_expanded.expand_as(val) - - t.scatter_(dim, idx_expanded, val) - return t - else: - raise NotImplementedError(f"idx.dim() == {idx.dim()} not supported") - - else: - if idx.dim() == 1: - return t.index_copy_(dim, idx, val) - - assert idx.dim() == 2, f"multiple batch dims not yet {idx.shape=}" - assert dim != 0, f"cannot index batch dim {dim=}" - batch_size, idx_size = idx.shape - assert batch_size == t.size(0) - assert batch_size == val.size(0) - - # if we can view the batch and indexed dimensions together, we could - # do index trickery. This is, sadly, not the case for kvcache so we - # fall back to for loop - for i in range(batch_size): - unbatched_dim = dim if dim < 0 else dim - 1 - t[i].index_copy_(unbatched_dim, idx[i], val[i]) - return t - - def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: """ Applies RoPE transform to `x`. 
Note that `cos`, `sin` need to have a batch @@ -772,7 +959,7 @@ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.T Returns: Encoded tensor, `(B, ..., T, head_size)` """ - if cos.dim() != 3: + if cos.ndim != 3: raise ValueError(f"cos must be three-dimensional, but shape is {cos.shape}") if cos.shape != sin.shape: raise ValueError(f"cos, sin must have same shape, but cos.shape={cos.shape}, sin.shape={sin.shape}") @@ -780,7 +967,7 @@ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.T x1 = x[..., :head_size_half] # (B, ..., T, head_size/2) x2 = x[..., head_size_half:] # (B, ..., T, head_size/2) rotated = torch.cat((-x2, x1), dim=-1) # (B, ..., T, head_size) - dims_diff = x.dim() - cos.dim() + dims_diff = x.ndim - cos.ndim if dims_diff > 0: # Ensure that shapes of `x`, `cos`, `sin` align new_shape = cos.shape[0:1] + (1,) * dims_diff + cos.shape[1:] @@ -791,64 +978,6 @@ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.T return roped.to(dtype=x.dtype) -def do_softcapping(x: torch.Tensor, thresh: float) -> torch.Tensor: - return torch.tanh(x / thresh) * thresh - - -class KVCache(nn.Module): - """ - Buffers `k`, `v` have shape - `(batch_size, n_query_groups, max_seq_length, head_size)`. - """ - - def __init__( - self, - k_shape: Tuple[int, int, int, int], - v_shape: Tuple[int, int, int, int], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ) -> None: - super().__init__() - self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False) - self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False) - - def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Writes new values `k` and `v` into the cache at the positions specified - by `input_pos` along the sequence dimension (`max_seq_length`). The batch - size of `k` and `v` (`bs`) must be smaller or equal to `KVCache` batch - size. Returns the full buffers, adjusted to the batch size `bs`. - - Args: - input_pos: Position index, `(bs, T)` or `(T,)` - k: New values, `(bs, n_query_groups, T, head_size)` - v: New values, `(bs, n_query_groups, T, head_size)` - - Returns: - k_full, v_full, `(bs, n_query_groups, max_seq_length, head_size)` - - """ - # move the buffer to the activation dtype for when AMP is used - if self.k.dtype != k.dtype: - self.k = self.k.to(k.dtype) - if self.v.dtype != v.dtype: - self.v = self.v.to(v.dtype) - # update the cache - bs = k.size(0) - k = batched_index_copy_(self.k[:bs, ...], -2, input_pos, k) - v = batched_index_copy_(self.v[:bs, ...], -2, input_pos, v) - return k, v - - def reset_parameters(self) -> None: - torch.nn.init.zeros_(self.k) - torch.nn.init.zeros_(self.v) - - -def build_mask_cache(max_seq_length: int, device: Optional[torch.device] = None) -> torch.Tensor: - ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool) - return torch.tril(ones).unsqueeze(0).unsqueeze(0) - - class RMSNorm(torch.nn.Module): """Root Mean Square Layer Normalization. 
diff --git a/litgpt/utils.py b/litgpt/utils.py index af97fa2f11..ba76d4d310 100644 --- a/litgpt/utils.py +++ b/litgpt/utils.py @@ -873,3 +873,22 @@ def kill_process_tree(pid: int): parent.kill() except psutil.NoSuchProcess: pass # Process already exited + + +def batched_index_select(t: torch.Tensor, dim: int, idx: torch.Tensor) -> torch.Tensor: + """index_select for batched index and unbatched t""" + if idx.ndim == 1: + return torch.index_select(t, dim, idx) + + *batch_shape, idx_size = idx.shape + res = torch.index_select(t, dim, idx.reshape(-1)) # flat index + # split out single batch idx + res = res.view(*t.shape[:dim], -1, idx_size, *t.shape[dim + 1 :]) + if dim > 0: + # move batch dim to front, this is np.rollaxis(res, dim, 0) for tensors + dims = [dim] + list(range(res.ndim)) + del dims[dim + 1] + res = res.permute(dims) + # unflatten batch dims + res = res.view(*batch_shape, *res.shape[1:]) + return res diff --git a/tests/generate/test_adapter.py b/tests/generate/test_adapter.py index 782d3d435c..35ecd64c17 100644 --- a/tests/generate/test_adapter.py +++ b/tests/generate/test_adapter.py @@ -47,10 +47,17 @@ def test_main(fake_checkpoint_dir, monkeypatch, version, tensor_like): assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value) - assert ( - generate_mock.mock_calls - == [call(ANY, tensor_like, 101, temperature=2.0, top_k=2, top_p=0.9, eos_id=ANY)] * num_samples + expected_call = call( + model=ANY, + prompt=tensor_like, + prompt_chunksize=16, + max_returned_tokens=101, + temperature=2.0, + top_k=2, + top_p=0.9, + eos_id=ANY, ) + assert generate_mock.mock_calls == [expected_call] * num_samples expected_output = "foo bar baz\n" * num_samples # Allow for the config to be printed before the expected repeated strings. 
diff --git a/tests/generate/test_main.py b/tests/generate/test_main.py index fd430318b0..1bc2817146 100644 --- a/tests/generate/test_main.py +++ b/tests/generate/test_main.py @@ -15,7 +15,7 @@ import litgpt.generate.base as generate from litgpt import GPT, Config -from litgpt.generate.base import sample +from litgpt.generate.base import batched_sample skip_in_ci_on_macos = pytest.mark.skipif( sys.platform == "darwin" and os.getenv("GITHUB_ACTIONS") == "true", @@ -23,10 +23,7 @@ ) -@pytest.mark.parametrize( - "max_seq_length", (pytest.param(10, marks=pytest.mark.xfail(raises=NotImplementedError, strict=True)), 20 + 5) -) -def test_generate(max_seq_length): +def test_generate(): import lightning as L L.seed_everything(1234) @@ -34,25 +31,41 @@ def test_generate(max_seq_length): T = 5 input_idx = torch.arange(0, T) - config = Config(block_size=128, vocab_size=16, n_layer=1, n_head=4, n_embd=8) + config = Config( + block_size=128, + vocab_size=16, + n_layer=1, + n_head=4, + n_embd=8, + ) model = GPT(config) - model.max_seq_length = max_seq_length - model.set_kv_cache(batch_size=1) max_new_tokens = 20 + model.max_seq_length = T + max_new_tokens + model.set_kv_caches(batch_size=1) multinomial_results = [] def multinomial(*args, **kwargs): - out = torch.multinomial(*args, **kwargs, num_samples=1) + if args: + probs = args[0] + else: + probs = kwargs.get("probs") + out = torch.multinomial(probs, num_samples=1) multinomial_results.append(out) return out with mock.patch("litgpt.generate.base.multinomial_num_samples_1", multinomial): - out = generate.generate(model, input_idx, T + max_new_tokens, top_k=1) + out = generate.generate( + model=model, + prompt=input_idx, + max_returned_tokens=T + max_new_tokens, + top_k=1, + ) assert out.size(0) == T + max_new_tokens, (out.size(0), T + max_new_tokens) multinomial_results = torch.hstack(multinomial_results) - expected = torch.cat((input_idx, multinomial_results)) + print(f"input_idx {input_idx.shape}, multinomial_results: {multinomial_results.shape}") + expected = torch.cat((input_idx, multinomial_results.squeeze(0))) assert out.shape == expected.shape, (out.shape, expected.shape) torch.testing.assert_close(out, expected) @@ -60,11 +73,18 @@ def multinomial(*args, **kwargs): @skip_in_ci_on_macos def test_main(fake_checkpoint_dir, monkeypatch, tensor_like): config_path = fake_checkpoint_dir / "model_config.yaml" - config = {"block_size": 128, "vocab_size": 50, "n_layer": 2, "n_head": 4, "n_embd": 8, "rotary_percentage": 1} + config = { + "block_size": 128, + "vocab_size": 50, + "n_layer": 2, + "n_head": 4, + "n_embd": 8, + "rotary_percentage": 1, + } config_path.write_text(yaml.dump(config)) module_mock = Mock() - module_mock.config.block_size = 128 + module_mock.config.block_size = config["block_size"] load_mock = Mock() load_mock.return_value = load_mock monkeypatch.setattr(generate, "load_checkpoint", load_mock) @@ -73,21 +93,42 @@ def test_main(fake_checkpoint_dir, monkeypatch, tensor_like): tokenizer_mock.return_value.decode.return_value = "foo bar baz" monkeypatch.setattr(generate, "Tokenizer", tokenizer_mock) generate_mock = Mock() - generate_mock.return_value = torch.tensor([3, 2, 1]) + # fmt: off + generate_mock.return_value = torch.tensor([ + 1, 2, 3, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, + ]) + # fmt: on + len_return_value = generate_mock.return_value.numel() monkeypatch.setattr(generate, "generate", generate_mock) num_samples = 2 
out, err = StringIO(), StringIO() + sample_kwargs = dict( + temperature=2.0, + top_k=2, + top_p=0.9, + ) with redirect_stdout(out), redirect_stderr(err): - generate.main(temperature=2.0, top_k=2, top_p=0.9, num_samples=num_samples, checkpoint_dir=fake_checkpoint_dir) + generate.main( + **sample_kwargs, + num_samples=num_samples, + checkpoint_dir=fake_checkpoint_dir, + ) assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples - assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value) - assert ( - generate_mock.mock_calls - == [call(ANY, tensor_like, 53, temperature=2.0, top_k=2, top_p=0.9, eos_id=tokenizer_mock.return_value.eos_id)] - * num_samples + assert torch.allclose( + tokenizer_mock.return_value.decode.call_args[0][0].to(torch.device("cpu")), generate_mock.return_value + ) + expected_call = call( + model=ANY, + prompt=tensor_like, + prompt_chunksize=16, + max_returned_tokens=len_return_value, + **sample_kwargs, + eos_id=tokenizer_mock.return_value.eos_id, ) + assert generate_mock.mock_calls == [expected_call] * num_samples expected_output = "foo bar baz\n" * num_samples # Allow for the config to be printed before the expected repeated strings. pattern = rf".*^{re.escape(expected_output.strip())}$.*" @@ -119,25 +160,44 @@ def test_sample(temperature): ], dtype=torch.float32, ) - token = sample(logits, temperature=temperature, top_p=0.8) + # Note: Both `sample` and `batched_sample` create only 1 sample, not 3. + # It is like passing `logits[:, 1-:, :]` + token = batched_sample(logits, kwargs=dict(temperature=temperature, top_p=0.8)) - assert token.shape == (1,) + assert token.shape == (2, 1) # sample is batch size 1 only for now - this should be [0, 1] once batched generation is supported - assert token.tolist() == [0] + assert token[0, -1].item() == 0 def test_generate_different_results_with_different_top_p(): - config = Config(block_size=128, vocab_size=16, n_layer=1, n_head=4, n_embd=8) + config = Config( + block_size=128, + vocab_size=16, + n_layer=1, + n_head=4, + n_embd=8, + rotary_percentage=1, + ) model = GPT(config) model.max_seq_length = 50 - model.set_kv_cache(batch_size=1) + model.set_kv_caches(batch_size=1) torch.manual_seed(123) input_idx = torch.randint(10, size=(1,)) torch.manual_seed(123) - output1 = generate.generate(model, input_idx, 20, top_p=1.0) + output1 = generate.generate( + model=model, + prompt=input_idx, + max_returned_tokens=20, + top_p=1.0, + ) torch.manual_seed(123) - output2 = generate.generate(model, input_idx, 20, top_p=0.1) + output2 = generate.generate( + model=model, + prompt=input_idx, + max_returned_tokens=20, + top_p=0.1, + ) assert not torch.equal(output1, output2) diff --git a/tests/generate/test_sequentially.py b/tests/generate/test_sequentially.py index 37175fa489..ab35b6e114 100644 --- a/tests/generate/test_sequentially.py +++ b/tests/generate/test_sequentially.py @@ -1,7 +1,6 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import itertools -import math import subprocess import sys from dataclasses import asdict @@ -14,7 +13,12 @@ from lightning import Fabric from litgpt import Config -from litgpt.generate.sequentially import layer_to_device, replace_device, sequential +from litgpt.generate.sequentially import ( + chunk_sizes, + layer_to_device, + replace_device, + sequential, +) from litgpt.model import GPT, Block from litgpt.scripts.download import download_from_hub from litgpt.utils import _RunIf @@ -28,8 +32,8 @@ (6, 1, {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0}), (6, 2, {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1}), (6, 3, {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2}), - (6, 4, {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2}), - (6, 5, {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2}), + (6, 4, {0: 0, 1: 1, 2: 2, 3: 2, 4: 3, 5: 3}), + (6, 5, {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 4}), (6, 6, {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5}), ], ) @@ -37,8 +41,11 @@ def test_layer_to_device(n_layer, devices, expected): with torch.device("meta"): model = GPT.from_name("pythia-14m", n_layer=n_layer) - max_layers_per_device = math.ceil(n_layer / devices) - actual = layer_to_device(model, Block, chunk_size=max_layers_per_device) + actual = layer_to_device( + model, + Block, + chunk_sizes=chunk_sizes(n_layer, devices), + ) expected = {f"transformer.h.{i}": v for i, v in expected.items()} assert actual == expected @@ -51,13 +58,6 @@ def test_sequential_layer_to_device_mapping_not_possible(): with pytest.raises(ValueError, match="number of layers in the model must be larger than the number of devices"): sequential(model, root=torch.device("cpu"), max_seq_length=128, devices=2) - # Last device would get 0 layers - config = Config(n_layer=6) - with torch.device("meta"): - model = GPT(config) - with pytest.raises(RuntimeError, match="Not able to distribute the 6 layers across 4 devices"): - sequential(model, root=torch.device("cpu"), max_seq_length=128, devices=4) - def path_to_device(model): return {k: str(v.device) for k, v in itertools.chain(model.named_parameters(), model.named_buffers())} @@ -111,6 +111,7 @@ def _test_model_1device(accelerator): fabric = Fabric(accelerator=accelerator, devices=1) with torch.device("meta"): model = GPT.from_name("pythia-14m", n_layer=2) + model.set_kv_caches(1) model = sequential(model, fabric.device, 15, 1) device_str = str(fabric.device) @@ -250,18 +251,6 @@ def test_model_forward_hooks(): "transformer.ln_f.bias": "cuda:0", "cos": "cuda:0", "sin": "cuda:0", - "transformer.h.0.attn.kv_cache.k": "cuda:0", - "transformer.h.0.attn.kv_cache.v": "cuda:0", - "transformer.h.1.attn.kv_cache.k": "cuda:0", - "transformer.h.1.attn.kv_cache.v": "cuda:0", - "transformer.h.2.attn.kv_cache.k": "cuda:0", - "transformer.h.2.attn.kv_cache.v": "cuda:0", - "transformer.h.3.attn.kv_cache.k": "cuda:1", - "transformer.h.3.attn.kv_cache.v": "cuda:1", - "transformer.h.4.attn.kv_cache.k": "cuda:1", - "transformer.h.4.attn.kv_cache.v": "cuda:1", - "transformer.h.5.attn.kv_cache.k": "cuda:1", - "transformer.h.5.attn.kv_cache.v": "cuda:1", } assert hooks == { "transformer.h.3": [("forward_pre_hook", "move_block_input", (torch.device(type="cuda", index=1),), {})], diff --git a/tests/kvcache/test_base.py b/tests/kvcache/test_base.py new file mode 100644 index 0000000000..297343d9a4 --- /dev/null +++ b/tests/kvcache/test_base.py @@ -0,0 +1,67 @@ +import random + +import torch + +from litgpt.kvcache.base import KVCacheParams +from litgpt.kvcache.testing import ( + create_kv_cache, + random_keys_values, + random_tensor, + tensor_is_simple, +) + + +def test_most_recent(): 
+ seed = 31415927 + random.seed(seed) + torch.random.manual_seed(seed) + vocab_size = 128 + + params = KVCacheParams( + batch_size=3, + n_query_groups=4, + cache_length=32, + head_size=8, + n_head=4, + device=torch.device("cpu"), + dtype=torch.bfloat16, + ) + cache_length = params.cache_length + kv_cache = create_kv_cache("lastrec-default", params) + num_insert = random.randint(cache_length, 3 * cache_length) + max_prefill_length = kv_cache.max_prefill_length + num_prefill = random.randint(num_insert // 3, int(num_insert * 0.75)) + if max_prefill_length is not None and num_prefill > max_prefill_length: + num_prefill = max_prefill_length + + keys, values = random_keys_values(params, num=num_insert) + queries = random_tensor(params, num=num_insert) + token_idx = torch.randint( + low=0, + high=vocab_size, + size=(params.batch_size, num_insert), + ) + kv_cache( + query=queries[:, :, :num_prefill, :], + key=keys[:, :, :num_prefill, :], + value=values[:, :, :num_prefill, :], + token_idx=token_idx[:, :num_prefill], + input_pos=0, + ) + for pos in range(num_prefill, num_insert): + kv_cache( + query=queries[:, :, pos : (pos + 1), :], + key=keys[:, :, pos : (pos + 1), :], + value=values[:, :, pos : (pos + 1), :], + token_idx=token_idx[:, pos : (pos + 1)], + input_pos=pos, + ) + + current_length = min(cache_length, num_insert) + assert kv_cache.current_length == current_length + token_positions = kv_cache.token_positions().to(dtype=torch.int64) + assert token_positions.shape == (params.batch_size, params.n_query_groups, current_length) + assert tensor_is_simple(token_positions) + positions = token_positions[0, 0, :].tolist() + assert len(set(positions)) == current_length + assert all(num_insert - current_length <= x < num_insert for x in positions) diff --git a/tests/kvcache/test_generic.py b/tests/kvcache/test_generic.py new file mode 100644 index 0000000000..fa7393cc2b --- /dev/null +++ b/tests/kvcache/test_generic.py @@ -0,0 +1,152 @@ +import random + +import pytest +import torch + +from litgpt.kvcache.base import KVCacheParams +from litgpt.kvcache.testing import ( + KV_CACHE_NAMES, + create_kv_cache, + random_keys_values, + random_tensor, + tensor_is_simple, +) + + +@pytest.mark.parametrize("name", KV_CACHE_NAMES) +def test_store_retrieve(name): + seed = 31415927 + random.seed(seed) + torch.random.manual_seed(seed) + vocab_size = 128 + + params = KVCacheParams( + batch_size=3, + n_query_groups=4, + cache_length=32, + head_size=8, + n_head=4, + device=torch.device("cpu"), + dtype=torch.bfloat16, + ) + cache_length = params.cache_length + kv_cache = create_kv_cache(name, params) + if name.startswith("dense"): + num_insert = random.randint(cache_length // 2, cache_length) + else: + num_insert = random.randint(cache_length, 3 * cache_length) + max_prefill_length = kv_cache.max_prefill_length + num_prefill = random.randint(num_insert // 3, int(num_insert * 0.75)) + if max_prefill_length is not None and num_prefill > max_prefill_length: + num_prefill = max_prefill_length + + keys, values = random_keys_values(params, num=num_insert) + queries = random_tensor(params, num=num_insert) + token_idx = torch.randint( + low=0, + high=vocab_size, + size=(params.batch_size, num_insert), + ) + kv_cache( + query=queries[:, :, :num_prefill, :], + key=keys[:, :, :num_prefill, :], + value=values[:, :, :num_prefill, :], + token_idx=token_idx[:, :num_prefill], + input_pos=0, + ) + for pos in range(num_prefill, num_insert): + kv_cache( + query=queries[:, :, pos : (pos + 1), :], + key=keys[:, :, pos : (pos + 1), :], + 
value=values[:, :, pos : (pos + 1), :], + token_idx=token_idx[:, pos : (pos + 1)], + input_pos=pos, + ) + + current_length = min(cache_length, num_insert) + assert kv_cache.current_length == current_length + token_positions = kv_cache.token_positions().to(dtype=torch.int64) + assert token_positions.shape == (params.batch_size, params.n_query_groups, current_length) + assert tensor_is_simple(token_positions) + # Positions for every (b, h) must be different + for b, h in zip(range(params.batch_size), range(params.n_query_groups)): + token_pos = token_positions[b, h, :].tolist() + assert all(0 <= x < num_insert for x in token_pos) + err_msg = f"num_insert = {num_insert}, b = {b}, h = {h}, current_length = {current_length}, num_prefill = {num_prefill}" + assert len(set(token_pos)) == current_length, err_msg + # Test cache content slice by slice + keys_and_values = kv_cache.get_keys_values() + for pos in range(current_length): + index = token_positions[:, :, pos][:, :, None, None].expand(-1, -1, 1, params.head_size) + # `index[i, j, 0, k] = next_position[i, j]` + k_expected = keys.gather(-2, index).squeeze(-2) + v_expected = values.gather(-2, index).squeeze(-2) + torch.testing.assert_close(k_expected, keys_and_values.keys()[:, :, pos, :]) + torch.testing.assert_close(v_expected, keys_and_values.values()[:, :, pos, :]) + + +@pytest.mark.parametrize("name", KV_CACHE_NAMES) +def test_prefill(name): + seed = 31415927 + random.seed(seed) + torch.random.manual_seed(seed) + vocab_size = 128 + num_compares = 3 + + params = KVCacheParams( + batch_size=2, + n_query_groups=2, + cache_length=32, + head_size=64, + n_head=2, + device=torch.device("cpu"), + dtype=torch.bfloat16, + ) + cache_length = params.cache_length + kv_cache = create_kv_cache(name, params) + + keys, values = random_keys_values(params, num=cache_length) + queries = random_tensor(params, num=cache_length) + token_idx = torch.randint( + low=0, + high=vocab_size, + size=(params.batch_size, cache_length), + ) + keys_cached = [] + values_cached = [] + max_prefill_length = kv_cache.max_prefill_length + for _ in range(num_compares): + num_prefill = random.randint(cache_length // 8, cache_length) + if max_prefill_length is not None and num_prefill > max_prefill_length: + num_prefill = max_prefill_length + kv_cache( + query=queries[:, :, :num_prefill, :], + key=keys[:, :, :num_prefill, :], + value=values[:, :, :num_prefill, :], + token_idx=token_idx[:, :num_prefill], + input_pos=0, + ) + for pos in range(num_prefill, cache_length): + kv_cache( + query=queries[:, :, pos : (pos + 1), :], + key=keys[:, :, pos : (pos + 1), :], + value=values[:, :, pos : (pos + 1), :], + token_idx=token_idx[:, pos : (pos + 1)], + input_pos=pos, + ) + keys_and_values = kv_cache.get_keys_values() + if keys_and_values is not None: + keys_cached.append(keys_and_values.keys().clone()) + values_cached.append(keys_and_values.values().clone()) + else: + keys_cached.append(None) + values_cached.append(None) + + num_none = 0 + for k, v in zip(keys_cached[1:], values_cached[1:]): + if k is not None: + torch.testing.assert_close(k, keys_cached[0]) + torch.testing.assert_close(v, values_cached[0]) + else: + num_none += 1 + assert num_none < num_compares - 1 diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 33f628eda2..f72e13b38e 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -131,9 +131,8 @@ def test_adapter_compile(): assert explanation.graph_break_count == 0 model = GPT(model.config) - model.set_kv_cache(2) - input_pos = 
torch.arange(model.config.block_size) - explanation = torch._dynamo.explain(model)(x, input_pos) + model.set_kv_caches(2) + explanation = torch._dynamo.explain(model)(x, input_pos=0) assert isinstance(explanation, debugging.ExplainOutput) assert explanation.graph_count == 1 assert explanation.graph_break_count == 0 diff --git a/tests/test_adapter_v2.py b/tests/test_adapter_v2.py index 1e9837fc53..12c76a6d88 100644 --- a/tests/test_adapter_v2.py +++ b/tests/test_adapter_v2.py @@ -160,9 +160,8 @@ def test_adapter_v2_compile(): assert explanation.graph_break_count == 0 model = AdapterV2GPT(model.config) - model.set_kv_cache(2) - input_pos = torch.arange(model.config.block_size) - explanation = torch._dynamo.explain(model)(x, input_pos) + model.set_kv_caches(2) + explanation = torch._dynamo.explain(model)(x, input_pos=0) assert isinstance(explanation, debugging.ExplainOutput) assert explanation.graph_count == 1 assert explanation.graph_break_count == 0 diff --git a/tests/test_attention.py b/tests/test_attention.py new file mode 100644 index 0000000000..0361614e7a --- /dev/null +++ b/tests/test_attention.py @@ -0,0 +1,582 @@ +import math +import random +from typing import Optional, Tuple + +import pytest +import torch +from torch.nn import functional as F + +from litgpt.attention import ( + DefaultKeysAndValues, + MultiHeadSelfAttention, + do_softcapping, + scaled_dot_product_attention, +) +from litgpt.attention_utils import ( + build_mask_cache, + build_mask_slice, +) +from litgpt.config import Config +from litgpt.kvcache import KVCache +from litgpt.model import ( + GPT, + CausalSelfAttention, + apply_rope, + build_rope_cache, +) +from litgpt.utils import batched_index_select + + +@pytest.mark.parametrize( + ("n_head", "n_query_groups"), + ( + (2, 1), + (4, 1), + (8, 4), + (12, 4), + (24, 8), + (9, 3), + ), +) +@torch.inference_mode() +def test_scaled_dot_product_attention(n_head, n_query_groups): + seed = 31415927 + random.seed(seed) + torch.random.manual_seed(seed) + num_repeats = 32 + dtype = torch.bfloat16 + mask_kwargs = dict(dtype=dtype, device=torch.device("cpu")) + assert_kwargs = dict(atol=0.0005, rtol=0.05) + + for repeat in range(num_repeats): + head_size = 2 ** random.randint(3, 6) + batch_size = random.randint(1, 5) + len_key = random.randint(16, 128) + mask = None + if repeat % 2 == 0: + len_query = len_key + mask = build_mask_cache( + max_seq_length=len_key, + sliding_window_size=None, + **mask_kwargs, + ) + elif repeat % 4 == 1: + len_query = random.randint(1, len_key // 2) + mask = build_mask_slice( + input_pos=len_key - len_query, + num=len_query, + token_positions=torch.arange( + 0, + len_key, + dtype=torch.int64, + ) + .view(1, 1, -1) + .expand(batch_size, n_query_groups, -1), + n_head=n_head, + **mask_kwargs, + ) + else: + len_query = 1 + shape = (batch_size, n_head, len_query, head_size) + query = torch.randn(shape, dtype=dtype) + shape = (batch_size, n_query_groups, len_key, head_size) + key = torch.randn(shape, dtype=dtype) + value = torch.randn(shape, dtype=dtype) + k_and_v = DefaultKeysAndValues(key, value) + scale = 1.0 / math.sqrt(head_size) + + result, scores = scaled_dot_product_attention( + query, + k_and_v, + scale=scale, + mask=mask, + ) + q_per_kv = n_head // n_query_groups + key_bc = key.repeat_interleave(q_per_kv, dim=1) + value_bc = value.repeat_interleave(q_per_kv, dim=1) + k_and_v_bc = DefaultKeysAndValues(key_bc, value_bc) + result_cmp, scores_cmp = scaled_dot_product_attention( + query, + k_and_v_bc, + scale=scale, + mask=mask, + ) + msg = ( + 
f"bs={batch_size}, hs={head_size}, nh_q={n_head}, nh_k={n_query_groups}, len_q={len_query}, len_k={len_key}" + ) + torch.testing.assert_close(result, result_cmp, **assert_kwargs), msg + torch.testing.assert_close(scores, scores_cmp, **assert_kwargs), msg + + +@pytest.mark.parametrize( + ("sliding_window_size", "batch_size", "n_query_groups"), + ( + (None, 1, 1), + (None, 4, 16), + (4, 1, 1), + (4, 2, 32), + (128, 1, 1), + (128, 4, 16), + ), +) +@torch.inference_mode() +def test_build_mask_slice( + sliding_window_size: Optional[int], + batch_size: int, + n_query_groups: int, +): + seed = 31415927 + random.seed(seed) + torch.random.manual_seed(seed) + num_repeats = 30 + dtype = torch.bfloat16 + device = torch.device("cpu") + + for _ in range(num_repeats): + seq_len = random.randint(16, 256) + full_mask = build_mask_cache(seq_len, sliding_window_size, device, dtype) + input_pos = random.randint(1, seq_len - 1) + num = random.randint(1, min(16, seq_len - input_pos)) + cache_length = random.randint(8, seq_len - 4) + token_positions = torch.zeros( + (batch_size, n_query_groups, cache_length), + dtype=torch.int64, + device=device, + ) + for bs in range(batch_size): + for nq in range(n_query_groups): + token_positions[bs, nq, :] = torch.randperm( + seq_len, + device=device, + )[:cache_length] + mask = build_mask_slice( + input_pos=input_pos, + num=num, + token_positions=token_positions, + n_head=n_query_groups, + dtype=dtype, + device=device, + sliding_window_size=sliding_window_size, + ) + mask_cmp = batched_index_select( + full_mask[input_pos : (input_pos + num), :], + dim=1, + idx=token_positions, + ) + torch.testing.assert_close(mask, mask_cmp) + + +@pytest.mark.parametrize( + "dtype", + [torch.float32, torch.float16, torch.bfloat16], +) +def test_mask_sliding_window(dtype): + """ + Compares `mask` used in MHA in training mode in old code (using + `mask_cache`) and new code, using a setup from + :func:`test_against_original_gemma_2` above. 
+ + """ + device = torch.device("cpu") + T = 20 + model_name = "gemma-2-27b" + config = Config.from_name( + model_name, + block_size=T, + sliding_window_size=T // 2, + n_layer=2, + n_head=16, + n_embd=32, + intermediate_size=86, + rotary_percentage=1.0, + ) + # Determine mask used in forward call for length `T` input (old code) + # neg_infty = float("-inf") + neg_infty = torch.finfo(dtype).min + old_mask = torch.ones(T, T, dtype=dtype, device=device).triu(diagonal=1) + old_mask.masked_fill_(old_mask.bool(), neg_infty) + old_mask = old_mask.view(1, 1, *old_mask.shape) + sliding_window_bias = torch.ones_like(old_mask).tril(diagonal=-config.sliding_window_size) + sliding_window_bias.masked_fill_(sliding_window_bias.bool(), neg_infty) + old_mask += sliding_window_bias + # Determine mask as in new code + new_mask = build_mask_cache( + max_seq_length=T, + sliding_window_size=config.sliding_window_size, + dtype=dtype, + device=device, + ).view(1, 1, T, T) + torch.testing.assert_close(old_mask, new_mask) + + +# Old code before `attention.py` was factored out +class CausalSelfAttention_OLD(torch.nn.Module): + def __init__(self, config: Config, block_idx: int) -> None: + super().__init__() + # key, query and value projections for all heads, but in a batch + self.qkv = torch.nn.Linear( + config.n_embd, + (config.n_head + 2 * config.n_query_groups) * config.head_size, # support for grouped/multi queries + bias=config.bias or config.attn_bias, + ) + # output projection + self.proj = torch.nn.Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) + # disabled by default + self.kv_cache: Optional[KVCache] = None + self.apply_sliding_window_attention = False + if config.sliding_window_size is not None and config.sliding_window_indices is not None: + self.apply_sliding_window_attention = config.sliding_window_indices[block_idx] + + if config.norm_qk: + self.norm_q = config.norm_class(config.head_size, eps=config.norm_eps) + self.norm_k = config.norm_class(config.head_size, eps=config.norm_eps) + else: + self.norm_q = self.norm_k = None + + self.config = config + self.block_idx = block_idx + + def forward( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + input_pos_maxp1: Optional[int] = None, + ) -> torch.Tensor: + # Notation: + # - B | batch size + # - T | time-step (sequence length) + # - C | model's embeddings size (n_embd) + # - C* | attentions's embeddings size + # - nh_(q,k,v) | number of heads for query, key and value + # - hs | head size + head_size = self.config.head_size + n_head = self.config.n_head + n_query_groups = self.config.n_query_groups + rope_n_elem = self.config.rope_n_elem + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # Perform a single multiplication operation using a combined QKV matrix to calculate `query`, `key`, and `value` + # instead of individually multiplying the input `x` with the respective weight matrices. + qkv = self.qkv(x) # (B, T, 3xC*) + + # Define query, key and value sizes. + # If grouped/multi query is enabled, these sizes are not equal (see the diagram in `lit_gpt/config.py::Config`). + query_size = n_head * head_size + key_size = value_size = n_query_groups * head_size + # Split qkv into query, key and value matrices. 
+ q, k, v = qkv.split((query_size, key_size, value_size), dim=-1) # 3x(B, T, C*) + + # To place the num_heads (nh) dimension right after the batch (B) dimension, the first step is to decouple the + # embedding size (C) into num_heads (nh) and head_size (hs). + q = q.view(B, T, n_head, head_size) # (B, T, nh_q, hs) + k = k.view(B, T, n_query_groups, head_size) # (B, T, nh_k, hs) + v = v.view(B, T, n_query_groups, head_size) # (B, T, nh_v, hs) + + # The tensors `query`, `key`, and `value` are now accurately structured: within each batch element (B), there are + # multiple heads (nh), and within each head, there is a sequence of elements (T), each represented by a vector + # of size `hs`. + q = q.transpose(1, 2) # (B, nh_q, T, hs) + k = k.transpose(1, 2) # (B, nh_k, T, hs) + v = v.transpose(1, 2) # (B, nh_v, T, hs) + + if self.config.norm_qk: + q = self.norm_q(q) + k = self.norm_k(k) + + # Unlike standard positional embeddings rotary embeddings must be applied at every layer. + q_roped = apply_rope(q[..., :rope_n_elem], cos, sin) + k_roped = apply_rope(k[..., :rope_n_elem], cos, sin) + q = torch.cat((q_roped, q[..., rope_n_elem:]), dim=-1) # (B, nh_q, T, hs) + k = torch.cat((k_roped, k[..., rope_n_elem:]), dim=-1) # (B, nh_k, T, hs) + + # Apply kv-cache during inference. + if input_pos is not None: + if not isinstance(self.kv_cache, KVCache): + raise TypeError("You need to call `gpt.set_kv_caches()`") + k, v = self.kv_cache(input_pos, k, v) + if input_pos_maxp1 is not None: + # Subselect along sequence dimension + k = k[..., :input_pos_maxp1, :] + v = v[..., :input_pos_maxp1, :] + # k, v: (B, nh_k, input_pos_maxp1, hs) + # If input_pos_maxp1 is None -> max_seq_length + + # Grouped queries: balance the number of heads across all three matrices. + # NOTE: flash attention requires it in training mode. + # Multi-query: this step can be skipped since there is only 1 head, allowing us to use broadcasting. + if n_query_groups != n_head and (input_pos is None or n_query_groups != 1): + q_per_kv = n_head // n_query_groups + k = k.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) + v = v.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) + + if self.apply_sliding_window_attention: + """ + Global Window Sliding window Sliding window + attention mask + bias = attention mask + ┌────────────────────────┐ ┌───────────────────────┐ ┌─────────────────────────┐ + │ True False False False │ │ True True True True │ │ True False False False │ + │ True True False False │ │ True True True True │ │ True True False False │ + │ True True True False │ │ False True True True │ │ False True True False │ + │ True True True True │ │ False False True True │ │ False False True True │ + └────────────────────────┘ └───────────────────────┘ └─────────────────────────┘ + """ + minus_infty = torch.finfo(q.dtype).min + # minus_infty = float("-inf") + if mask is None: + mask = torch.ones(T, T, dtype=q.dtype, device=q.device).triu(diagonal=1) + mask.masked_fill_(mask.bool(), minus_infty) + mask = mask.view(1, 1, *mask.shape) + sliding_window_bias = torch.ones_like(mask).tril(diagonal=-self.config.sliding_window_size) + sliding_window_bias.masked_fill_(sliding_window_bias.bool(), minus_infty) + mask += sliding_window_bias + + # Efficient attention using Flash Attention CUDA kernels. + # NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled. 
+ # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs) + y = self.scaled_dot_product_attention(q, k, v, mask) + + # Re-assemble all head outputs side by side. + y = y.reshape(B, T, head_size * n_head) + + # Output projection. + return self.proj(y) # (B, T, C) + + # Note: All internal computations are done in `float32`. This is also done + # in `F.scaled_dot_product_attention`. + def scaled_dot_product_attention( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + scale = 1.0 / math.sqrt(self.config.attention_scores_scalar or self.config.head_size) + + # with softcapping we cannot use SDPA + if self.config.attention_logit_softcapping is not None: + dtype = torch.float32 + scores = q.to(dtype) @ k.mT.to(dtype) * scale + scores = do_softcapping(scores, self.config.attention_logit_softcapping) + if mask is None: + q_len = q.shape[2] + mask = torch.ones( + q_len, + q_len, + dtype=dtype, + device=q.device, + ).triu(diagonal=1) + mask.masked_fill_(mask.bool(), torch.finfo(dtype).min) + mask = mask.view(1, 1, *mask.shape) + scores = scores + mask + scores = F.softmax(scores, dim=-1) + y = (scores @ v.to(dtype)).to(q.dtype) + else: + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None + ) + return y.transpose(1, 2) + + +def rope_cache_OLD( + config: Config, + device: Optional[torch.device] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + if config.rope_adjustments is None: + extra_config = None + + else: + adjusted_params_required = ["factor", "low_freq_factor", "high_freq_factor", "original_max_seq_len"] + params_present = [param in config.rope_adjustments for param in adjusted_params_required] + num_params_present = sum(params_present) + + if num_params_present == 0: + extra_config = None # uses standard RoPE + elif num_params_present == 4: + # These parameters should always be used together so that we don't interfere with standard rope + extra_config = {name: config.rope_adjustments[name] for name in adjusted_params_required} + elif "factor" in config.rope_adjustments: + # linear RoPE + adjusted_params_required = ["factor"] + extra_config = {name: config.rope_adjustments[name] for name in adjusted_params_required} + else: + # Some but not all parameters are specified; raise an error + missing_params = [param for param, present in zip(adjusted_params_required, params_present) if not present] + raise ValueError( + f"The following adjusted RoPE parameters are missing in rope_adjustments: {', '.join(missing_params)}. " + "All adjusted RoPE parameters must be specified together." + ) + + return build_rope_cache( + seq_len=config.block_size, + n_elem=config.rope_n_elem, + device=device, + condense_ratio=config.rope_condense_ratio, + base=config.rope_base, + extra_config=extra_config, + rope_local_base_freq=config.rope_local_base_freq, + ) + + +@pytest.mark.parametrize( + "model_name", + ["gemma-2-27b", "gemma-3-27b-it"], +) +@pytest.mark.parametrize( + "dtype", + [torch.float32, torch.float16, torch.bfloat16], +) +def test_multi_head_attention_for_gemma(model_name, dtype): + """ + Compares multi-head attention in old and current code, using a + setup from :func:`test_against_original_gemma_2` above. 
+ + """ + num_repeats = 20 + T = 20 + batch_size = 4 + is_gemma_3 = model_name.startswith("gemma-3") + config = Config.from_name( + model_name, + block_size=T, + sliding_window_size=T // 2, + n_layer=2, + n_query_groups=16, + n_head=16, + n_embd=32, + intermediate_size=86, + rotary_percentage=1.0, + rope_indices=[0, 1] if is_gemma_3 else None, + ) + + # Obtain RoPE parameters and compare + model_new = GPT(config).to(dtype=dtype) + model_new.max_seq_length = T + cos_new = model_new.cos.unsqueeze(0) + sin_new = model_new.sin.unsqueeze(0) + cos_old, sin_old = rope_cache_OLD(config) + cos_old = cos_old.unsqueeze(0).to(dtype=dtype) + sin_old = sin_old.unsqueeze(0).to(dtype=dtype) + torch.testing.assert_close(cos_new, cos_old) + torch.testing.assert_close(sin_new, sin_old) + + mha = MultiHeadSelfAttention(config) + shape = (batch_size, T, config.n_embd) + for rep in range(num_repeats): + block_idx = rep % 2 + attn_new = CausalSelfAttention( + config, + block_idx=block_idx, + ).to(dtype=dtype) + attn_old = CausalSelfAttention_OLD( + config, + block_idx=block_idx, + ).to(dtype=dtype) + # Ensure they have the same weights + attn_old.load_state_dict(attn_new.state_dict()) + inputs = torch.randn(shape, dtype=dtype) + token_idx = torch.randint( + 0, + config.padded_vocab_size, + (batch_size, T), + dtype=torch.int64, + ) + if is_gemma_3: + _cos = cos_new[..., config.rope_indices[block_idx]] + _sin = sin_new[..., config.rope_indices[block_idx]] + else: + _cos = cos_new + _sin = sin_new + outputs_new = attn_new( + x=inputs, + cos=_cos, + sin=_sin, + token_idx=token_idx, + mha=mha, + ) + if is_gemma_3: + _cos = cos_old[..., config.rope_indices[block_idx]] + _sin = sin_old[..., config.rope_indices[block_idx]] + else: + _cos = cos_old + _sin = sin_old + outputs_old = attn_old( + x=inputs, + cos=_cos, + sin=_sin, + mask=None, + ) + torch.testing.assert_close(outputs_new, outputs_old) + + +def _get_token_positions( + start: int, + end: int, + batch_size: int, + n_query_groups: int, + device: torch.device, +) -> torch.Tensor: + return ( + torch.arange(start, end, dtype=torch.int64, device=device) + .view( + 1, + 1, + -1, + ) + .expand(batch_size, n_query_groups, -1) + ) + + +@pytest.mark.parametrize( + "seq_len, sliding_window_size", + [ + (128, None), + (21, None), + (128, 16), + (21, 12), + ], +) +def test_build_mask(seq_len, sliding_window_size): + seed = 31415927 + random.seed(seed) + torch.random.manual_seed(seed) + num_repeats = 4 + batch_size = 2 + n_query_groups = 4 + kwargs = dict(device=torch.device("cpu"), dtype=torch.float32) + tp_kwargs = dict( + batch_size=batch_size, + n_query_groups=n_query_groups, + device=torch.device("cpu"), + ) + + mask_full = build_mask_cache( + max_seq_length=seq_len, + sliding_window_size=sliding_window_size, + **kwargs, + )[None, None, :, :].expand(batch_size, n_query_groups, -1, -1) + token_positions = _get_token_positions(0, seq_len, **tp_kwargs) + for _ in range(num_repeats): + mask_parts = [] + num_prefill = random.randint(1, seq_len - 1) + mask_parts.append( + build_mask_slice( + input_pos=0, + num=num_prefill, + token_positions=token_positions, + n_head=n_query_groups, + **kwargs, + sliding_window_size=sliding_window_size, + ) + ) + for pos in range(num_prefill, seq_len): + mask_parts.append( + build_mask_slice( + input_pos=pos, + num=1, + token_positions=token_positions, + n_head=n_query_groups, + **kwargs, + sliding_window_size=sliding_window_size, + ) + ) + mask_comp = torch.cat(mask_parts, dim=2) + torch.testing.assert_close(mask_full, mask_comp) diff --git 
a/tests/test_batch.py b/tests/test_batch.py index 32eb1c2f3a..c9fb349081 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -31,7 +31,10 @@ def create_llm(tmp_path, batch_size, max_seq_length, device) -> tuple[LLM, GPT]: init="random", ) model: GPT = llm.model - model.set_kv_cache(batch_size=batch_size, max_seq_length=max_seq_length, device=device) + model.set_kv_caches( + batch_size=batch_size, + max_seq_length=max_seq_length, + ) return llm, model @@ -41,8 +44,9 @@ def test_batched_equivalence(tmp_path): model_name = "microsoft/phi-2" download_from_hub(repo_id=model_name, tokenizer_only=True, checkpoint_dir=tmp_path) - device = "cuda:0" + device = torch.device("cuda:0") batch_size = 3 + max_seq_length = 50 sample_kwargs = {"top_k": 1} llm: LLM = LLM.load( @@ -51,7 +55,7 @@ def test_batched_equivalence(tmp_path): init="random", ) model: GPT = llm.model - model.set_kv_cache(batch_size=1, max_seq_length=50, device=device) + model.set_kv_caches(batch_size=1, max_seq_length=50) input_pos_1 = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64, device=device) input_pos_2 = torch.tensor([10], dtype=torch.int64, device=device) @@ -65,8 +69,18 @@ def test_batched_equivalence(tmp_path): batch_x1 = torch.stack([x] * batch_size, dim=0) # Single token generation baseline - tok_1 = next_token(model, input_pos_1, x.unsqueeze(0), **sample_kwargs) - tok_2 = next_token(model, input_pos_2, tok_1.unsqueeze(0), **sample_kwargs) + tok_1 = next_token( + model=model, + x=x.unsqueeze(0), + input_pos=0, + **sample_kwargs, + ) + tok_2 = next_token( + model=model, + x=tok_1.unsqueeze(0), + input_pos=x.shape[0], + **sample_kwargs, + ) assert tok_1.ndim == 1 assert tok_2.ndim == 1 @@ -74,11 +88,24 @@ def test_batched_equivalence(tmp_path): assert tok_2.size(0) == 1 # Switch to batched generation - model.clear_kv_cache() - model.set_kv_cache(batch_size=batch_size, max_seq_length=50, device="cuda:0") + model.clear_kv_caches() + model.set_kv_caches( + batch_size=batch_size, + max_seq_length=max_seq_length, + ) - toks_1: torch.Tensor = batched_next_token(model, input_pos_1, batch_x1, sample_kwargs) - toks_2: torch.Tensor = batched_next_token(model, input_pos_2, toks_1, sample_kwargs) + toks_1: torch.Tensor = batched_next_token( + model=model, + x=batch_x1, + input_pos=0, + kwargs=sample_kwargs, + ) + toks_2: torch.Tensor = batched_next_token( + model=model, + x=toks_1, + input_pos=x.shape[0], + kwargs=sample_kwargs, + ) assert toks_1.ndim == 2 assert toks_2.ndim == 2 @@ -97,27 +124,21 @@ def test_simple_batch(): config = litgpt.Config.from_name("microsoft/phi-2", padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=256) with torch.device("cuda"): m = litgpt.GPT(config).requires_grad_(False).eval() - x0 = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 7]]) - input_pos0 = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 2]]) + m.max_seq_length = 10 + # Note: This KV cache can be used throughout, also for batch size 1 + # It is reset whenever `input_pos=0` (prefill) + m.set_kv_caches(2) + x0 = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) x1 = torch.tensor([[1], [2]]) - input_pos1 = torch.tensor([[4], [3]]) - with torch.device("cuda"): - m.set_kv_cache(2) - outs0 = m(x0, input_pos0) - outs1 = m(x1, input_pos1) - - with torch.device("cuda"): - m.set_kv_cache(1) - - outs0_ref0 = m(x0[:1], input_pos0[0]) - outs1_ref0 = m(x1[:1], input_pos1[0]) + outs0 = m(x0, input_pos=0) + outs1 = m(x1, input_pos=4) - with torch.device("cuda"): - m.set_kv_cache(1) + outs0_ref0 = m(x0[:1], input_pos=0) + outs1_ref0 = m(x1[:1], input_pos=4) - 
outs0_ref1 = m(x0[1:], input_pos0[1]) - outs1_ref1 = m(x1[1:], input_pos1[1]) + outs0_ref1 = m(x0[1:], input_pos=0) + outs1_ref1 = m(x1[1:], input_pos=4) outs0_ref = torch.cat([outs0_ref0, outs0_ref1]) outs1_ref = torch.cat([outs1_ref0, outs1_ref1]) @@ -133,7 +154,7 @@ def test_simple_batch(): def test_batch_generate(tmp_path): torch.use_deterministic_algorithms(True) - device = "cuda:0" + device = torch.device("cuda:0") batch_size = 3 sample_kwargs = {"top_k": 1} llm, model = create_llm(tmp_path, batch_size, 50, device) @@ -151,12 +172,11 @@ def test_batch_generate(tmp_path): # Generate tokens tokens = [] for l in batched_generate_fn( - model, + model=model, prompts=batch_x, max_returned_tokens=50, sample_args=sample_kwargs, include_prompt=True, - include_eos=False, ): tokens.append([t.item() if t is not None else None for t in l]) @@ -216,13 +236,12 @@ def find_unique_stop(triplets): # Now we generate again, stopping early at the stop tokens. tokens = [] for l in batched_generate_fn( - model, + model=model, prompts=batch_x, max_returned_tokens=50, - stop_tokens=[(s,) for s in stops], + stop_tokens=tuple([s] for s in stops), sample_args=sample_kwargs, include_prompt=True, - include_eos=False, ): tokens.append([t.item() if t is not None else None for t in l]) @@ -257,7 +276,7 @@ def find_unique_stop(triplets): def test_batch_generate_equivalence(tmp_path): torch.use_deterministic_algorithms(True) - device = "cuda:0" + device = torch.device("cuda:0") batch_size = 3 sample_kwargs = {"top_k": 1} llm, model = create_llm(tmp_path, batch_size, 50, device) @@ -276,12 +295,11 @@ def test_batch_generate_equivalence(tmp_path): batch_tokens = [] for l in batched_generate_fn( - model, + model=model, prompts=batch_x, max_returned_tokens=50, sample_args=sample_kwargs, include_prompt=False, - include_eos=False, ): batch_tokens.append([t.item() if t is not None else None for t in l]) @@ -292,7 +310,7 @@ def test_batch_generate_equivalence(tmp_path): tokens = [] for t in generate_fn( - model, + model=model, prompt=batch_x[0], max_returned_tokens=50, include_prompt=False, diff --git a/tests/test_chat.py b/tests/test_chat.py index 3bfe49780d..fa9280e53c 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -46,6 +46,9 @@ def test_generate(monkeypatch, generated, stop_tokens, expected): model = MagicMock() model.config.block_size = 100 model.max_seq_length = 100 + # Mock methods called during generation + monkeypatch.setattr(model, "kv_cache_max_prefill_length", lambda: 80) + monkeypatch.setattr(model, "kv_cache_max_tokens_forward", lambda: 20) it = iter(generated) def multinomial(*_, **__): diff --git a/tests/test_generate_speculatively.py b/tests/test_generate_speculatively.py index 383688c43d..6392675ae4 100644 --- a/tests/test_generate_speculatively.py +++ b/tests/test_generate_speculatively.py @@ -30,10 +30,14 @@ def forward(self, idx, **kwargs): target_model = TargetModel() token = torch.tensor([-1]) - input_pos = torch.tensor([0]) sample_kwargs = dict(top_k=None, top_p=0.0, temperature=0.0) # to make sampling consistent output = generate.speculative_decoding( - draft_model, target_model, token, input_pos, input_pos, speculative_k=3, **sample_kwargs + draft_model=draft_model, + target_model=target_model, + token=token, + input_pos=0, + speculative_k=3, + **sample_kwargs, ) # target model never accepts draft model's output, thus the output of the `speculative_decoding` @@ -56,10 +60,14 @@ def forward(self, idx, **kwargs): target_model = TargetModel() token = torch.tensor([-1]) - input_pos = 
torch.tensor([0]) sample_kwargs = dict(top_k=None, top_p=0.0, temperature=0.0) # to make sampling consistent output = generate.speculative_decoding( - draft_model, target_model, token, input_pos, input_pos, speculative_k=3, **sample_kwargs + draft_model=draft_model, + target_model=target_model, + token=token, + input_pos=0, + speculative_k=3, + **sample_kwargs, ) # target model always accepts draft model's output, thus the output of the `speculative_decoding` @@ -89,10 +97,14 @@ def forward(self, idx, **kwargs): target_model = TargetModel() token = torch.tensor([-1]) - input_pos = torch.tensor([0]) sample_kwargs = dict(top_k=None, top_p=0.0, temperature=0.0) # to make sampling consistent output = generate.speculative_decoding( - draft_model, target_model, token, input_pos, input_pos, speculative_k=3, **sample_kwargs + draft_model=draft_model, + target_model=target_model, + token=token, + input_pos=0, + speculative_k=3, + **sample_kwargs, ) # target model accepts only 2 out of 3 draft model's output, thus the output of the `speculative_decoding` @@ -114,11 +126,16 @@ def test_generate(max_seq_length, speculative_k): target_model = GPT(Config(vocab_size=16, block_size=128, n_layer=2, n_head=8, n_embd=16)) for model in (draft_model, target_model): model.max_seq_length = max_seq_length - model.set_kv_cache(batch_size=1) + model.set_kv_caches(batch_size=1) # generate tokens out, acceptance_rate = generate.generate( - draft_model, target_model, input_idx, T + max_new_tokens, top_k=1, speculative_k=speculative_k + draft_model=draft_model, + target_model=target_model, + prompt=input_idx, + max_returned_tokens=T + max_new_tokens, + top_k=1, + speculative_k=speculative_k, ) # validate @@ -185,6 +202,7 @@ def test_main(fake_checkpoint_dir, monkeypatch, tensor_like): ANY, tensor_like, 53, + prompt_chunksize=16, temperature=2.0, top_k=2, top_p=0.9, diff --git a/tests/test_lora.py b/tests/test_lora.py index c7a31f9609..0f1d10ff7a 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -520,9 +520,8 @@ def test_lora_compile(): assert explanation.graph_break_count == 0 model = LoRAGPT(model.config) - model.set_kv_cache(2) - input_pos = torch.arange(model.config.block_size) - explanation = torch._dynamo.explain(model)(x, input_pos) + model.set_kv_caches(2) + explanation = torch._dynamo.explain(model)(x, input_pos=0) assert isinstance(explanation, debugging.ExplainOutput) assert explanation.graph_count == 1 assert explanation.graph_break_count == 0 diff --git a/tests/test_model.py b/tests/test_model.py index da931c7297..79117bcb96 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -2,7 +2,6 @@ from copy import deepcopy from functools import partial -from unittest import mock import pytest import torch @@ -36,7 +35,8 @@ import litgpt.config as config_module from litgpt import GPT, Config -from litgpt.model import CausalSelfAttention, batched_index_copy_ +from litgpt.attention import DefaultKeysAndValues +from litgpt.model import CausalSelfAttention from litgpt.scripts.convert_hf_checkpoint import ( copy_weights_falcon, copy_weights_gemma_2, @@ -769,7 +769,13 @@ def test_against_original_gemma(model_name, device, dtype): torch.set_default_dtype(dtype) T = 5 - ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86) + ours_config = Config.from_name( + model_name, + n_layer=2, + n_head=16, + n_embd=32, + intermediate_size=86, + ) theirs_config = GemmaConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, @@ -834,6 +840,7 @@ def 
test_against_original_gemma_2(model_name, device, dtype): n_head=16, n_embd=32, intermediate_size=86, + rotary_percentage=1.0, # Gemma2 does not have this ) theirs_config = Gemma2Config( vocab_size=ours_config.padded_vocab_size, @@ -868,7 +875,6 @@ def test_against_original_gemma_2(model_name, device, dtype): # test end to end x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0) - assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5) @@ -1398,57 +1404,21 @@ def test_model_compile(): assert explanation.graph_break_count == 0 model = GPT(model.config) - model.set_kv_cache(2) - input_pos = torch.arange(model.config.block_size) - explanation = torch._dynamo.explain(model)(x, input_pos) + model.set_kv_caches(2) + explanation = torch._dynamo.explain(model)(x) assert isinstance(explanation, debugging.ExplainOutput) assert explanation.graph_count == 1 assert explanation.graph_break_count == 0 -@torch.inference_mode() -@pytest.mark.parametrize( - "max_seq_length", (25, pytest.param(23, marks=pytest.mark.xfail(raises=IndexError, strict=True))) -) -@pytest.mark.flaky(reruns=5) -def test_kv_cache(max_seq_length): - config = Config(block_size=25, padded_vocab_size=5, n_layer=2, n_head=2, n_embd=8) - model = GPT(config) - idx = torch.randint(0, model.config.padded_vocab_size, (1, 5)) - max_new_tokens = 20 - model.max_seq_length = max_seq_length - model.set_kv_cache(1) - - def generate(logits): - logits = logits[:, -1:] - probs = torch.nn.functional.softmax(logits, dim=-1) - return torch.argmax(probs).unsqueeze(0).unsqueeze(0) - - x_no_cache = idx - x_cache = idx - input_pos = torch.arange(0, 5) - for _ in range(max_new_tokens): - logits_no_cache = model(x_no_cache[:, -max_seq_length:]) - out_no_cache = generate(logits_no_cache) - - logits_cache = model(x_cache, input_pos) - out_cache = generate(logits_cache) - - torch.testing.assert_close(out_no_cache, out_cache, rtol=0, atol=0) - - x_no_cache = torch.cat((x_no_cache, out_no_cache), dim=1) - x_cache = out_cache - input_pos = input_pos[-1:] + 1 - - @torch.inference_mode() def test_model_kv_cache_amp(): config = Config.from_name("pythia-14m", n_layer=2) model = GPT(config) - encoded = torch.arange(45) - model.set_kv_cache(batch_size=1) + encoded = torch.arange(45).view(1, -1) + model.set_kv_caches(batch_size=1) with torch.autocast("cpu", torch.bfloat16): - output = model(encoded.unsqueeze(0), encoded) + output = model(encoded, input_pos=0) assert output.dtype is torch.bfloat16 @@ -1466,28 +1436,49 @@ def test_sdpa_choice(config): pytest.skip("Gemma 2 doesn't support SDPA") torch.set_default_dtype(torch.float16) + config["n_layer"] = 1 + config = config_module.Config(**config) + enable_gqa = config.n_query_groups < config.n_head - def assert_sdpa_backend(original_fn, q, k, v, mask): + def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores): # SDPAParams gained an additional argument in PyTorch 2.5 args = [] + assert k_and_v.both_in_parallel() + # This is also done in `MultiHeadSelfAttention.scaled_dot_product_attention` + if mask is None and enable_gqa: + # Some efficient kernels have not implemented + # `enabla_gqa=True`. It is better to extend keys, values in + # this case. 
+ key = k_and_v.keys() + value = k_and_v.values() + q_per_kv = config.n_head // config.n_query_groups + key = key.repeat_interleave(q_per_kv, dim=1) + value = value.repeat_interleave(q_per_kv, dim=1) + assert query.shape[1] == key.shape[1] + _k_and_v = DefaultKeysAndValues(key, value) + _enable_gqa = False + else: + _enable_gqa = enable_gqa + _k_and_v = k_and_v + if hasattr(SDPAParams, "enable_gqa"): - args.append(False) - params = SDPAParams(q, k, v, mask, 0.0, True, *args) + args.append(_enable_gqa) + params = SDPAParams(query, _k_and_v.keys(), _k_and_v.values(), mask, 0.0, True, *args) if expected is SDPBackend.FLASH_ATTENTION: assert flash_sdp_enabled(), "flash_sdp_enabled() is False" if config.sliding_window_size is None: assert can_use_flash_attention(params, True), "can_use_flash_attention(params, True) is False" elif expected is SDPBackend.EFFICIENT_ATTENTION: assert mem_efficient_sdp_enabled(), "mem_efficient_sdp_enabled() is False" - assert can_use_efficient_attention(params, True), "can_use_efficient_attention(params, True) is False" + if (not enable_gqa) or mask is None: + # At present, `SDPBackend.EFFICIENT_ATTENTION` does not support + # `enabla_gqa=True` and a mask specified + assert can_use_efficient_attention(params, True), "can_use_efficient_attention(params, True) is False" elif expected is SDPBackend.MATH: assert math_sdp_enabled(), "math_sdp_enabled() is False" else: raise NotImplementedError - return original_fn(q, k, v, mask) - - config["n_layer"] = 1 - config = config_module.Config(**config) + return original_fn(query, k_and_v, mask, return_scores) try: with torch.device("cuda"): @@ -1497,8 +1488,10 @@ def assert_sdpa_backend(original_fn, q, k, v, mask): # best effort, if the GPU can load it pytest.xfail() - for h in model.transformer.h: - h.attn.scaled_dot_product_attention = partial(assert_sdpa_backend, h.attn.scaled_dot_product_attention) + model.mha.scaled_dot_product_attention = partial( + assert_sdpa_backend, + model.mha.scaled_dot_product_attention, + ) if SUPPORTS_FLASH_ATTENTION: expected = SDPBackend.FLASH_ATTENTION @@ -1515,53 +1508,77 @@ def assert_sdpa_backend(original_fn, q, k, v, mask): @torch.inference_mode() def test_sdpa_choice_kv_cache(config): torch.set_default_dtype(torch.float16) + config["n_layer"] = 1 + config = config_module.Config(**config) + enable_gqa = config.n_query_groups < config.n_head - def assert_sdpa_backend(original_fn, q, k, v, mask): + def assert_sdpa_backend(original_fn, query, k_and_v, mask, return_scores): # SDPAParams gained an additional argument in PyTorch 2.5 args = [] + assert k_and_v.both_in_parallel() + # This is also done in `MultiHeadSelfAttention.scaled_dot_product_attention` + if mask is None and enable_gqa: + # Some efficient kernels have not implemented + # `enabla_gqa=True`. It is better to extend keys, values in + # this case. 
+ key = k_and_v.keys() + value = k_and_v.values() + q_per_kv = config.n_head // config.n_query_groups + key = key.repeat_interleave(q_per_kv, dim=1) + value = value.repeat_interleave(q_per_kv, dim=1) + assert query.shape[1] == key.shape[1] + _k_and_v = DefaultKeysAndValues(key, value) + _enable_gqa = False + else: + _enable_gqa = enable_gqa + _k_and_v = k_and_v + if hasattr(SDPAParams, "enable_gqa"): - args.append(False) - params = SDPAParams(q, k, v, mask, 0.0, True, *args) + args.append(_enable_gqa) + params = SDPAParams(query, _k_and_v.keys(), _k_and_v.values(), mask, 0.0, True, *args) if expected is SDPBackend.FLASH_ATTENTION: - assert flash_sdp_enabled() - assert can_use_flash_attention(params, True) + assert flash_sdp_enabled(), "flash_sdp_enabled() is False" + assert can_use_flash_attention(params, True), "can_use_flash_attention(params, True) is False" elif expected is SDPBackend.EFFICIENT_ATTENTION: - assert mem_efficient_sdp_enabled() - assert can_use_efficient_attention(params, True) + assert mem_efficient_sdp_enabled(), "mem_efficient_sdp_enabled() is False" + if (not enable_gqa) or mask is None: + # At present, `SDPBackend.EFFICIENT_ATTENTION` does not support + # `enabla_gqa=True` and a mask specified + assert can_use_efficient_attention(params, True), "can_use_efficient_attention(params, True) is False" elif expected is SDPBackend.MATH: - assert math_sdp_enabled() + assert math_sdp_enabled(), "math_sdp_enabled() is False" else: raise NotImplementedError - return original_fn(q, k, v, mask) - - config["n_layer"] = 1 - config = config_module.Config(**config) + return original_fn(query, k_and_v, mask, return_scores) try: with torch.device("cuda"): model = GPT(config) - model.max_seq_length = 1 - model.set_kv_cache(2) - x = torch.randint(0, 10, (2, 1), dtype=torch.int32) - input_pos = torch.tensor([0], dtype=torch.long) + model.max_seq_length = 16 + model.set_kv_caches(2) + x = torch.randint(0, 10, (2, 10), dtype=torch.int32) except torch.cuda.OutOfMemoryError: # best effort, if the GPU can load it pytest.xfail() - for h in model.transformer.h: - h.attn.scaled_dot_product_attention = partial(assert_sdpa_backend, h.attn.scaled_dot_product_attention) + for block in model.transformer.h: + kv_cache = block.attn.kv_cache + kv_cache.mha.scaled_dot_product_attention = partial( + assert_sdpa_backend, + kv_cache.mha.scaled_dot_product_attention, + ) if SUPPORTS_FLASH_ATTENTION: # flash attention does not support an attention mask expected = SDPBackend.MATH with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False): - model(x, input_pos) + model(x, input_pos=0) expected = ( SDPBackend.EFFICIENT_ATTENTION if config.head_size % 8 == 0 and config.n_query_groups != 1 else SDPBackend.MATH ) with torch.backends.cuda.sdp_kernel(enable_flash=False): - model(x, input_pos) + model(x, input_pos=0) @_RunIf(min_cuda_gpus=2, standalone=True) @@ -1592,68 +1609,6 @@ def test_reset_parameters_device(): assert model.cos.device.type == "cuda" -def test_batched_index_copy_modes(): - # Mock the torch.backends.mps.is_available() function to simulate MPS availability - with mock.patch("torch.backends.mps.is_available", return_value=True): - # Mock the device type to simulate the "mps" device - with mock.patch("torch.Tensor.device", new_callable=mock.PropertyMock) as mock_device: - mock_device.return_value = torch.device("mps") - - # Test case when idx.dim() == 1 - t_original_1 = torch.randn(3, 5) - dim_1 = 0 - idx_1 = torch.tensor([0, 2]) - val_1 = torch.randn(2, 5) - - t1_cpu = t_original_1.clone() - 
t1_mps = t_original_1.clone() - - # Perform the index copy on CPU - batched_index_copy_(t1_cpu, dim_1, idx_1, val_1) - - # Simulate the MPS index copy - idx_1_mps = idx_1 - val_1_mps = val_1 - batched_index_copy_(t1_mps, dim_1, idx_1_mps, val_1_mps) - assert torch.allclose(t1_cpu, t1_mps), "Mismatch with idx.dim() == 1 on mocked MPS" - - # Test case when idx.dim() == 2 - t_original_2 = torch.randn(2, 5, 4) - dim_2 = 1 - idx_2 = torch.tensor([[0, 2], [1, 3]]) - val_2 = torch.randn(2, 2, 4) - - t2_cpu = t_original_2.clone() - t2_mps = t_original_2.clone() - - # Perform the index copy on CPU - batched_index_copy_(t2_cpu, dim_2, idx_2, val_2) - - # Simulate the MPS index copy - idx_2_mps = idx_2 - val_2_mps = val_2 - batched_index_copy_(t2_mps, dim_2, idx_2_mps, val_2_mps) - assert torch.allclose(t2_cpu, t2_mps), "Mismatch with idx.dim() == 2 on mocked MPS" - - # Additional test with negative dimension - t_original_3 = torch.randn(2, 3, 4) - dim_3 = -2 - idx_3 = torch.tensor([[0, 1], [1, 2]]) - val_3 = torch.randn(2, 2, 4) - - t3_cpu = t_original_3.clone() - t3_mps = t_original_3.clone() - - # Perform the index copy on CPU - batched_index_copy_(t3_cpu, dim_3, idx_3, val_3) - - # Simulate the MPS index copy - idx_3_mps = idx_3 - val_3_mps = val_3 - batched_index_copy_(t3_mps, dim_3, idx_3_mps, val_3_mps) - assert torch.allclose(t3_cpu, t3_mps), "Mismatch with negative dimension on mocked MPS" - - def test_load_legacy_state_dict(): """Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers.""" config = Config( @@ -1691,7 +1646,7 @@ def test_kv_cache_buffer_shape(n_query_groups): ) model = GPT(config) model.max_seq_length = max_seq_length - model.set_kv_cache(batch_size) + model.set_kv_caches(batch_size) required_shape = (batch_size, n_query_groups, max_seq_length, config.head_size) for block in model.transformer.h: kv_cache = block.attn.kv_cache @@ -1716,22 +1671,3 @@ def test_rope_cos_sin_shapes_if_rope_n_elem_is_odd(rotary_percentage, final_dim) required_shape = (config.block_size, final_dim) assert model.cos.shape == required_shape assert model.sin.shape == required_shape - - -def test_forward_with_without_input_pos_maxp1(): - batch_size = 3 - config = Config( - block_size=25, - padded_vocab_size=5, - n_layer=2, - n_head=8, - n_embd=16, - ) - model = GPT(config) - model.set_kv_cache(batch_size) - idx = torch.randint(0, config.padded_vocab_size, (1, 10)) - input_pos = torch.arange(1, 11) - input_pos_maxp1 = 11 - logits_with_maxp1 = model(idx, input_pos, input_pos_maxp1=input_pos_maxp1) - logits_no_maxp1 = model(idx, input_pos) - torch.testing.assert_close(logits_with_maxp1, logits_no_maxp1)
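
Taken together, the test changes above converge on a small number of new calling conventions. The new tests in tests/kvcache all follow the same protocol: one prefill call with input_pos=0, then one call per generated token with input_pos advancing by one. Below is a condensed sketch of that protocol; it reuses the helper names the patch itself introduces in litgpt.kvcache.testing (so it only runs against this branch) and swaps the randomized sizes of the tests for fixed ones.

    import torch

    from litgpt.kvcache.base import KVCacheParams
    from litgpt.kvcache.testing import create_kv_cache, random_keys_values, random_tensor

    params = KVCacheParams(
        batch_size=1,
        n_query_groups=4,
        cache_length=16,
        head_size=8,
        n_head=4,
        device=torch.device("cpu"),
        dtype=torch.bfloat16,
    )
    kv_cache = create_kv_cache("lastrec-default", params)

    num_insert = 24  # more tokens than cache_length, so eviction has to happen
    keys, values = random_keys_values(params, num=num_insert)
    queries = random_tensor(params, num=num_insert)
    token_idx = torch.randint(0, 128, (params.batch_size, num_insert))

    # Prefill: a single call covering the prompt (kept below cache_length here;
    # the tests additionally clamp this to kv_cache.max_prefill_length).
    num_prefill = 8
    kv_cache(
        query=queries[:, :, :num_prefill, :],
        key=keys[:, :, :num_prefill, :],
        value=values[:, :, :num_prefill, :],
        token_idx=token_idx[:, :num_prefill],
        input_pos=0,
    )
    # Decode: one token per call, input_pos advancing by one.
    for pos in range(num_prefill, num_insert):
        kv_cache(
            query=queries[:, :, pos : pos + 1, :],
            key=keys[:, :, pos : pos + 1, :],
            value=values[:, :, pos : pos + 1, :],
            token_idx=token_idx[:, pos : pos + 1],
            input_pos=pos,
        )

    # A "last recent" cache keeps only the newest cache_length positions.
    assert kv_cache.current_length == min(params.cache_length, num_insert)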
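
test_scaled_dot_product_attention checks that grouped-query attention gives the same result whether the keys/values stay grouped or are broadcast to the full head count. The broadcast trick is plain PyTorch; the following self-contained sketch (no litgpt imports, shapes in the (B, nh, T, hs) layout used throughout the patch) compares torch's SDPA on the broadcast keys/values against an explicit masked-softmax reference.

    import math

    import torch
    import torch.nn.functional as F

    B, n_head, n_query_groups, T, hs = 2, 8, 4, 16, 32
    q = torch.randn(B, n_head, T, hs)
    k = torch.randn(B, n_query_groups, T, hs)
    v = torch.randn(B, n_query_groups, T, hs)
    scale = 1.0 / math.sqrt(hs)

    # Grouped-query attention: each KV head serves n_head // n_query_groups query
    # heads, so repeating keys/values along the head dim reduces it to standard MHA.
    q_per_kv = n_head // n_query_groups
    k_bc = k.repeat_interleave(q_per_kv, dim=1)  # (B, n_head, T, hs)
    v_bc = v.repeat_interleave(q_per_kv, dim=1)

    # Explicit reference: causally masked softmax over the scaled scores.
    causal = torch.ones(T, T, dtype=torch.bool).triu(diagonal=1)
    scores = (q @ k_bc.transpose(-2, -1)) * scale
    scores = scores.masked_fill(causal, float("-inf"))
    y_ref = scores.softmax(dim=-1) @ v_bc

    y_sdpa = F.scaled_dot_product_attention(q, k_bc, v_bc, scale=scale, is_causal=True)
    torch.testing.assert_close(y_ref, y_sdpa, atol=1e-4, rtol=1e-4)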
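
test_mask_sliding_window and the inline construction inside CausalSelfAttention_OLD.forward build the same additive bias: a causal mask plus a band that also hides keys sliding_window_size or more positions behind the query, with torch.finfo(dtype).min standing in for minus infinity. Stripped of the model around it, the construction is just:

    import torch

    T, W = 8, 4  # sequence length, sliding_window_size
    dtype = torch.float32
    neg_inf = torch.finfo(dtype).min

    # Causal part: hide future positions (strict upper triangle).
    mask = torch.ones(T, T, dtype=dtype).triu(diagonal=1)
    mask.masked_fill_(mask.bool(), neg_inf)
    # Sliding-window part: additionally hide keys W or more positions in the past.
    sliding = torch.ones(T, T, dtype=dtype).tril(diagonal=-W)
    sliding.masked_fill_(sliding.bool(), neg_inf)
    bias = (mask + sliding).view(1, 1, T, T)  # broadcastable over (B, n_head, T, T)

    # Row t of the bias is zero exactly for columns max(0, t - W + 1) .. t.
    print((bias[0, 0] == 0).to(torch.int64))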
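
At the model level, the pattern the updated tests rely on is: size the caches once with set_kv_caches, prefill with input_pos=0, then decode one token at a time with a plain integer input_pos (the per-sample input_pos tensors and input_pos_maxp1 no longer appear). A minimal sketch against the litgpt GPT class, assuming the interfaces introduced by this patch; the configuration values are arbitrary small ones in the style of the tests.

    import torch

    from litgpt import GPT, Config

    config = Config(block_size=32, padded_vocab_size=100, n_layer=2, n_head=4, n_embd=32)
    model = GPT(config).eval()
    model.max_seq_length = 32
    model.set_kv_caches(batch_size=2)  # replaces the old set_kv_cache(...)

    prompt = torch.randint(0, config.padded_vocab_size, (2, 6))
    with torch.inference_mode():
        logits = model(prompt, input_pos=0)             # prefill the caches
        nxt = logits[:, -1:].argmax(dim=-1)             # greedy next token, shape (2, 1)
        logits = model(nxt, input_pos=prompt.shape[1])  # decode one step
    model.clear_kv_caches()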