Skip to content

Refactoring of multi-head attention and support for KV caching #2061

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .azure/gpu-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ jobs:
CUDA_VERSION: "12.6.3"
TORCH_VERSION: "2.7.1"
CUDNN_FRONTEND_VERSION: "1.10.0"
CUBLAS_WORKSPACE_CONFIG: ":4096:8"
container:
# image: "pytorchlightning/pytorch_lightning:base-cuda-py$(PYTHON_VERSION)-torch$(TORCH_VERSION)-cuda$(CUDA_VERSION)"
# pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-cudnn-fe1.5.0-py3.10-pt_main-dev
Expand Down
4 changes: 2 additions & 2 deletions extensions/xla/finetune/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,9 @@ def validate(
encoded = tokenizer.encode(prompt, device=fabric.device)
with fabric.init_tensor():
# do not set `max_seq_length=max_returned_token` because memory is not a concern here
model.set_kv_cache(batch_size=1)
model.set_kv_caches(batch_size=1)
output = generate(model, encoded, max_returned_tokens=len(encoded) + eval_max_new_tokens, temperature=0.8)
model.clear_kv_cache()
model.clear_kv_caches()
output = tokenizer.decode(output)
rank_print(fabric, output)

Expand Down
2 changes: 1 addition & 1 deletion extensions/xla/generate/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def main(
# set the max_seq_length to limit the memory usage to what we need
model.max_seq_length = max_returned_tokens
# enable the kv cache
model.set_kv_cache(batch_size=1)
model.set_kv_caches(batch_size=1)

t0 = time.perf_counter()
y = generate(
Expand Down
2 changes: 1 addition & 1 deletion extensions/xla/generate/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def main(
for i in range(num_samples):
with fabric.init_tensor():
# enable the kv cache
model.set_kv_cache(batch_size=1)
model.set_kv_caches(batch_size=1)

t0 = time.perf_counter()
y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k)
Expand Down
110 changes: 71 additions & 39 deletions litgpt/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
import torch.nn as nn
from typing_extensions import Self

from litgpt.attention import DefaultKeysAndValues, MultiHeadSelfAttention
from litgpt.config import Config as BaseConfig
from litgpt.kvcache.base import KVCache
from litgpt.model import GPT as BaseModel
from litgpt.model import Block as BaseBlock
from litgpt.model import CausalSelfAttention as BaseCausalSelfAttention
Expand All @@ -29,21 +31,28 @@ class Config(BaseConfig):

class GPT(BaseModel):
# Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here.
def __init__(self, config: Config) -> None:
def __init__(self, config: Config, **mha_kwargs) -> None:
nn.Module.__init__(self)
assert config.padded_vocab_size is not None
self.config = config

self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
self.lm_head = nn.Linear(
config.n_embd,
config.padded_vocab_size,
bias=config.lm_head_bias,
)
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
h=nn.ModuleList(Block(config, block_idx) for block_idx in range(config.n_layer)),
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
)
)
self.mask_cache: Optional[torch.Tensor] = None
self.mha = MultiHeadSelfAttention(config, **mha_kwargs)
self.max_seq_length = self.config.block_size
self._start_of_layer_hook = config.start_of_layer_hook
# Have dense KV caches been created by `set_kv_caches`?
self._default_kv_cache = False

@classmethod
def from_name(cls, name: str, **kwargs: Any) -> Self:
Expand All @@ -57,56 +66,79 @@ def _init_weights(self, module: nn.Module) -> None:


class Block(BaseBlock):
def __init__(self, config: Config, block_idx: int) -> None:
super().__init__(config, block_idx)
self.attn = CausalSelfAttention(config, block_idx)
def __init__(
self,
config: Config,
block_idx: int,
kv_cache: Optional[KVCache] = None,
) -> None:
super().__init__(config, block_idx, kv_cache)
self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache)


class CausalSelfAttention(BaseCausalSelfAttention):
"""A modification of `litgpt.model.CausalSelfAttention` that adds the attention
over the adaption prompt."""

def __init__(self, config: Config, block_idx: int) -> None:
super().__init__(config, block_idx)
if block_idx >= config.adapter_start_layer:
def __init__(
self,
config: Config,
block_idx: int,
kv_cache: Optional[KVCache] = None,
) -> None:
super().__init__(
config=config,
block_idx=block_idx,
kv_cache=kv_cache,
)
self._extend_forward = block_idx >= config.adapter_start_layer
if self._extend_forward:
# adapter embedding layer
self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
# gate for adaption
self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
# kv cache for inference
self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

def scaled_dot_product_attention(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
def _transform_output(
self,
y: torch.Tensor,
query: torch.Tensor,
mha: MultiHeadSelfAttention,
) -> torch.Tensor:
y = super().scaled_dot_product_attention(q, k, v, mask)
if self.block_idx < self.config.adapter_start_layer:
return y

aT = self.config.adapter_prompt_length
if self.adapter_kv_cache is not None:
# since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av
# are the same every call
ak, av = self.adapter_kv_cache
else:
prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)
aqkv = self.qkv(prefix)
q_per_kv = self.config.n_head // self.config.n_query_groups
aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
aqkv = aqkv.permute(0, 2, 3, 1, 4)
_, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
if self.config.n_query_groups != 1:
# for MHA this is a no-op
ak = ak.repeat_interleave(q_per_kv, dim=2)
av = av.repeat_interleave(q_per_kv, dim=2)
ak = ak.view(1, -1, aT, self.config.head_size) # (1, nh_ak, aT, hs)
av = av.view(1, -1, aT, self.config.head_size) # (1, nh_av, aT, hs)
self.adapter_kv_cache = (ak, av)

T = q.size(2)
amask = torch.ones(T, aT, dtype=torch.bool, device=q.device)
ay = super().scaled_dot_product_attention(q, ak, av, amask)
return y + self.gating_factor * ay
if self._extend_forward:
B, T, _ = y.shape
y = y.view(B, T, self.config.n_head, self.config.head_size)
aT = self.config.adapter_prompt_length
if self.adapter_kv_cache is not None:
# since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av
# are the same every call
ak, av = self.adapter_kv_cache
else:
prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)
aqkv = self.qkv(prefix)
q_per_kv = self.config.n_head // self.config.n_query_groups
aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
aqkv = aqkv.permute(0, 2, 3, 1, 4)
_, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
if self.config.n_query_groups != 1:
# for MHA this is a no-op
ak = ak.repeat_interleave(q_per_kv, dim=2)
av = av.repeat_interleave(q_per_kv, dim=2)
ak = ak.view(1, -1, aT, self.config.head_size) # (1, nh_ak, aT, hs)
av = av.view(1, -1, aT, self.config.head_size) # (1, nh_av, aT, hs)
self.adapter_kv_cache = (ak, av)

amask = torch.ones(T, aT, dtype=torch.bool, device=query.device)
a_k_and_v = DefaultKeysAndValues(keys=ak, values=av)
ay, _ = mha.scaled_dot_product_attention(
query=query,
k_and_v=a_k_and_v,
mask=amask,
)
y = (y + self.gating_factor * ay).view(B, T, -1)

return y

def reset_parameters(self) -> None:
if hasattr(self, "gating_factor"):
Expand Down
38 changes: 31 additions & 7 deletions litgpt/adapter_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from litgpt.adapter import GPT as BaseModel
from litgpt.adapter import CausalSelfAttention as BaseCausalSelfAttention
from litgpt.adapter import Config as BaseConfig
from litgpt.attention import MultiHeadSelfAttention
from litgpt.kvcache.base import KVCache
from litgpt.model import Block as BaseBlock
from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble
from litgpt.utils import map_old_state_dict_weights
Expand Down Expand Up @@ -69,16 +71,23 @@ def __init__(self, config: Config) -> None:
assert config.padded_vocab_size is not None
self.config = config

self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
self.lm_head = AdapterV2Linear(
config.n_embd,
config.padded_vocab_size,
bias=config.lm_head_bias,
)
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
h=nn.ModuleList(Block(config, block_idx) for block_idx in range(config.n_layer)),
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
)
)
self.mask_cache: Optional[torch.Tensor] = None
self.mha = MultiHeadSelfAttention(config)
self.max_seq_length = self.config.block_size
self._start_of_layer_hook = config.start_of_layer_hook
# Have dense KV caches been created by `set_kv_cache`?
self._default_kv_cache = False

@classmethod
def from_name(cls, name: str, **kwargs: Any) -> Self:
Expand All @@ -98,24 +107,39 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa


class Block(BaseBlock):
def __init__(self, config: Config, block_idx: int) -> None:
super().__init__(config, block_idx)
self.attn = CausalSelfAttention(config, block_idx)
def __init__(
self,
config: Config,
block_idx: int,
kv_cache: Optional[KVCache] = None,
) -> None:
super().__init__(config, block_idx, kv_cache)
self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache)
self.mlp = config.mlp_class(config)


class CausalSelfAttention(BaseCausalSelfAttention):
"""A modification of `litgpt.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class"""

# Copy&paste from :class:`model.CausalSelfAttention`
def __init__(self, config: Config, block_idx: int) -> None:
super().__init__(config, block_idx)
def __init__(
self,
config: Config,
block_idx: int,
kv_cache: Optional[KVCache] = None,
) -> None:
super().__init__(config, block_idx, kv_cache)
# key, query, value projections for all heads, but in a batch
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
self.qkv = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias)
# output projection
self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias)

@property
def device(self) -> Optional[torch.device]:
w = self.qkv.linear.weight
return None if w is None else w.device

def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base and/or legacy checkpoints."""
mapping = {
Expand Down
42 changes: 22 additions & 20 deletions litgpt/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def __init__(
checkpoint_dir: Path = None,
fabric: L.Fabric = None,
generate_strategy: Optional[Literal["sequential", "tensor_parallel"]] = None,
kv_cache_initialized: bool = False,
fixed_kv_cache_size: Union[int, Literal["max_model_supported"], None] = None,
) -> None:
super().__init__()
Expand All @@ -57,7 +56,6 @@ def __init__(
self.checkpoint_dir = checkpoint_dir
self.fabric = fabric
self.generate_strategy = generate_strategy
self.kv_cache_initialized = kv_cache_initialized
self.fixed_kv_cache_size = fixed_kv_cache_size
self.prev_generated_seq_length = 0

Expand All @@ -82,6 +80,9 @@ def state_dict(self, destination=None, prefix="", keep_vars=False):
def load_state_dict(self, state_dict, strict=True):
return self.model.load_state_dict(state_dict, strict=strict)

def kv_cache_initialized(self) -> bool:
return all(block.attn.kv_cache is not None for block in self.model.transformer.h)

def forward(
self,
input_ids: torch.Tensor,
Expand Down Expand Up @@ -249,7 +250,6 @@ def load(
checkpoint_dir=checkpoint_dir,
fabric=fabric,
generate_strategy=None,
kv_cache_initialized=False,
fixed_kv_cache_size=False,
)

Expand Down Expand Up @@ -367,7 +367,6 @@ def distribute(

print("Fabric launched", file=sys.stderr)

self.kv_cache_initialized = False
if generate_strategy is None:
with fabric.init_module(empty_init=(total_devices > 1)):
model = GPT(self.config)
Expand All @@ -383,8 +382,10 @@ def distribute(
kv_cache_size = model.max_seq_length
else:
kv_cache_size = fixed_kv_cache_size
model.set_kv_cache(batch_size=1, max_seq_length=kv_cache_size, device=fabric.device)
self.kv_cache_initialized = True
model.set_kv_caches(
batch_size=1,
max_seq_length=kv_cache_size,
)
self.fixed_kv_cache_size = fixed_kv_cache_size

elif generate_strategy in ("sequential", "tensor_parallel"):
Expand Down Expand Up @@ -437,7 +438,7 @@ def distribute(
# the rope cache which is on meta device
model.cos, model.sin = model.rope_cache()
# enable the kv cache
model.set_kv_cache(batch_size=1)
model.set_kv_caches(batch_size=1)
model.eval()
model = fabric.to_device(model)

Expand All @@ -448,8 +449,6 @@ def distribute(
if fabric.global_rank == 0:
pbar.close()

self.kv_cache_initialized = True

else:
raise ValueError(f"Unsupported generate_strategy: {generate_strategy}")

Expand Down Expand Up @@ -508,20 +507,23 @@ def generate(
prompt_length = input_ids.size(0)
max_returned_tokens = prompt_length + max_new_tokens

if not self.kv_cache_initialized:
if self.fabric is not None:
device = self.fabric.device
else:
device = self.preprocessor.device
self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=device)
self.kv_cache_initialized = True
if self.fabric is not None:
device = self.fabric.device
else:
device = self.preprocessor.device
if not self.kv_cache_initialized():
self.model.set_kv_caches(
batch_size=1,
max_seq_length=max_returned_tokens,
)

# Dynamically grow the kv cache size if necessary
if not self.fixed_kv_cache_size and self.prev_generated_seq_length < max_returned_tokens:
tmp_device = self.model.mask_cache.device
self.model.clear_kv_cache()
self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=tmp_device)

self.model.clear_kv_caches()
self.model.set_kv_caches(
batch_size=1,
max_seq_length=max_returned_tokens,
)
else:
for block in self.model.transformer.h:
block.attn.kv_cache.reset_parameters()
Expand Down
Loading
Loading