Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions tests/v1/attention/test_attention_backends_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for mamba attention backend selectors."""

from types import SimpleNamespace

import pytest

from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.short_conv import ShortConv
from vllm.model_executor.models.minimax_text_01 import (
MiniMaxText01LinearAttention)
from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionBackend)


@pytest.mark.parametrize(
"layer_class, init_kwargs, expected_backend, expected_mamba_type", [
(
MambaMixer,
dict(
hidden_size=128,
ssm_state_size=16,
conv_kernel_size=4,
intermediate_size=256,
time_step_rank=8,
use_conv_bias=True,
use_bias=False,
use_rms_norm=True,
),
Mamba1AttentionBackend,
"mamba1",
),
(
MambaMixer2,
dict(
hidden_size=128,
ssm_state_size=16,
conv_kernel_size=4,
intermediate_size=256,
use_conv_bias=True,
use_bias=False,
n_groups=1,
num_heads=8,
head_dim=32,
),
Mamba2AttentionBackend,
"mamba2",
),
(
MiniMaxText01LinearAttention,
dict(
hidden_size=128,
hidden_inner_size=256,
num_heads=8,
head_dim=32,
max_position=2048,
block_size=64,
num_hidden_layer=12,
layer_idx=0,
linear_layer_idx=0,
),
LinearAttentionBackend,
"linear_attention",
),
(
ShortConv,
dict(
config=SimpleNamespace(conv_L_cache=32, conv_bias=True),
dim=128,
layer_idx=0,
),
ShortConvAttentionBackend,
"short_conv",
),
])
def test_mamba_layers_get_attn_backend(dist_init, layer_class, init_kwargs,
expected_backend, expected_mamba_type):
"""Test that Mamba-like layers return the correct attention backend."""
layer = layer_class(**init_kwargs)

backend_class = layer.get_attn_backend()
assert backend_class is expected_backend
assert layer.mamba_type == expected_mamba_type


@pytest.mark.parametrize("layer_class,expected_backend,expected_mamba_type", [
(MambaMixer, Mamba1AttentionBackend, "mamba1"),
(MambaMixer2, Mamba2AttentionBackend, "mamba2"),
(MiniMaxText01LinearAttention, LinearAttentionBackend, "linear_attention"),
(ShortConv, ShortConvAttentionBackend, "short_conv"),
])
def test_mamba_layers_have_unified_interface(layer_class, expected_backend,
expected_mamba_type):
"""Test that all Mamba layers have the unified get_attn_backend
interface."""
assert hasattr(layer_class, 'get_attn_backend'), (
f"{layer_class.__name__} should have get_attn_backend method")
assert hasattr(layer_class, 'mamba_type'), (
f"{layer_class.__name__} should have mamba_type property")
25 changes: 0 additions & 25 deletions tests/v1/attention/test_mamba_selectors.py

This file was deleted.

3 changes: 2 additions & 1 deletion vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
Expand Down Expand Up @@ -54,7 +55,7 @@ def check_xformers_availability():
return USE_XFORMERS_OPS


class Attention(nn.Module):
class Attention(nn.Module, AttentionLayerBase):
"""Attention layer.

This class takes query, key, and value tensors as input. The input tensors
Expand Down
23 changes: 23 additions & 0 deletions vllm/model_executor/layers/attention_layer_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Base class for attention-like layers."""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend


class AttentionLayerBase(ABC):
"""
Base class for attention-like layers (Attention, Mamba, etc.)
that support the v1 engine.

This provides a common interface for getting attention backends
from different layer types.
"""

@abstractmethod
def get_attn_backend(self) -> type["AttentionBackend"]:
"""Get the attention backend class for this layer."""
pass
15 changes: 13 additions & 2 deletions vllm/model_executor/layers/mamba/abstract.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from abc import abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING

import torch

from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase

class MambaBase(ABC):
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend


class MambaBase(AttentionLayerBase):
"""
Base class for Mamba-like layers which support the v1 engine.
Inherit from this class if you implement a custom layer.
Expand All @@ -32,3 +38,8 @@ def get_state_shape(self) -> Iterable[tuple[int, ...]]:
@abstractmethod
def mamba_type(self) -> str:
pass

@abstractmethod
def get_attn_backend(self) -> type["AttentionBackend"]:
"""Get the attention backend class for this Mamba layer."""
pass
10 changes: 9 additions & 1 deletion vllm/model_executor/layers/mamba/mamba_mixer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import NamedTuple, Optional
from typing import TYPE_CHECKING, NamedTuple, Optional

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend

import torch
from torch import nn
Expand Down Expand Up @@ -404,6 +407,11 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
def mamba_type(self) -> str:
return "mamba1"

def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.mamba1_attn import (
Mamba1AttentionBackend)
return Mamba1AttentionBackend

def _time_proj_bias(self) -> Optional[torch.Tensor]:
if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
return self.dt_proj.bias.float()
Expand Down
10 changes: 9 additions & 1 deletion vllm/model_executor/layers/mamba/mamba_mixer2.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend

import torch
from torch import nn
Expand Down Expand Up @@ -758,6 +761,11 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
def mamba_type(self) -> str:
return "mamba2"

def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.mamba2_attn import (
Mamba2AttentionBackend)
return Mamba2AttentionBackend


def mamba_mixer2(
hidden_states: torch.Tensor,
Expand Down
10 changes: 9 additions & 1 deletion vllm/model_executor/layers/mamba/short_conv.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend

import torch

Expand Down Expand Up @@ -232,6 +235,11 @@ def get_state_shape(self) -> tuple[tuple[int, ...]]:
def mamba_type(self) -> str:
return "short_conv"

def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionBackend)
return ShortConvAttentionBackend


def short_conv(
hidden_states: torch.Tensor,
Expand Down
10 changes: 9 additions & 1 deletion vllm/model_executor/models/minimax_text_01.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import copy
import math
from collections.abc import Iterable
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend

import regex as re
import torch
Expand Down Expand Up @@ -339,6 +342,11 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
def mamba_type(self) -> str:
return "linear_attention"

def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.linear_attn import (
LinearAttentionBackend)
return LinearAttentionBackend

def get_state_dtype(self) -> tuple[torch.dtype]:
return MambaStateDtypeCalculator.linear_attention_state_dtype(
self.model_config.dtype,
Expand Down
22 changes: 0 additions & 22 deletions vllm/v1/attention/backends/mamba_selectors.py

This file was deleted.

26 changes: 8 additions & 18 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
from vllm.forward_context import (BatchDescriptor, DPMetadata,
set_forward_context)
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
Expand All @@ -55,7 +56,6 @@
GiB_bytes, LazyLoader, cdiv, check_use_alibi,
get_dtype_size, is_pin_memory_available, round_up,
supports_dynamo)
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
make_kv_sharing_fast_prefill_attention_metadata,
Expand Down Expand Up @@ -2752,11 +2752,13 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
"""
assert len(self.attn_groups) == 0, \
"Attention backends are already initialized"
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)

def get_attn_backends_for_layers(
layer_names: list[str]
) -> dict[type[AttentionBackend], list[str]]:
layers = get_layers_from_vllm_config(self.vllm_config,
AttentionLayerBase,
layer_names)
attn_backends = {}
attn_backend_layers = defaultdict(list)
# Dedupe based on full class name; this is a bit safer than using
Expand All @@ -2765,7 +2767,7 @@ def get_attn_backends_for_layers(
# they are cached correctly, there will be different objects per
# layer.
for layer_name in layer_names:
attn_backend = attn_layers[layer_name].get_attn_backend()
attn_backend = layers[layer_name].get_attn_backend()
key = attn_backend.full_cls_name()
attn_backends[key] = attn_backend
attn_backend_layers[key].append(layer_name)
Expand Down Expand Up @@ -2794,20 +2796,8 @@ def create_attn_groups(

for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
if isinstance(kv_cache_spec, AttentionSpec):
attn_backends = get_attn_backends_for_layers(
kv_cache_group_spec.layer_names)
# TODO(lucas): move `get_mamba_attn_backend` into the mamba
# layers like above
elif isinstance(kv_cache_spec, MambaSpec):
attn_backends = {
get_mamba_attn_backend(kv_cache_spec.mamba_type):
kv_cache_group_spec.layer_names
}
else:
raise ValueError(
f"Unknown KV cache spec type: {type(kv_cache_spec)}")

attn_backends = get_attn_backends_for_layers(
kv_cache_group_spec.layer_names)
self.attn_groups.append(
create_attn_groups(attn_backends, kv_cache_spec))

Expand Down