Merged

52 commits
ce630ea  WiP adding support for Mamba (tlrmchlsmth, Jul 8, 2024)
6c59b06  wip (tlrmchlsmth, Jul 9, 2024)
eb9bf34  WIP -- runs through. Generates tokens. Bad tokens. (tlrmchlsmth, Jul 10, 2024)
320f79b  Good output for mamba-370m (tlrmchlsmth, Jul 15, 2024)
5ab6622  wip (tlrmchlsmth, Jul 16, 2024)
71173a0  Merge branch 'upstream-main' into tms/add_mamba (tlrmchlsmth, Jul 16, 2024)
25b54d9  cleanup (tlrmchlsmth, Jul 16, 2024)
ebc12f1  Rename embedding block space manager (tlrmchlsmth, Jul 16, 2024)
ac60374  cleanup (tlrmchlsmth, Jul 16, 2024)
adb6713  remove file (tlrmchlsmth, Jul 16, 2024)
b733a84  format (tlrmchlsmth, Jul 16, 2024)
fb846ce  apply fix from #6214 (tlrmchlsmth, Jul 16, 2024)
09b1495  Merge branch 'upstream-main' into tms/add_mamba (tlrmchlsmth, Jul 16, 2024)
d8017cb  fixes from 6425 (tlrmchlsmth, Jul 16, 2024)
7ab2b9e  add an integration test (tlrmchlsmth, Jul 23, 2024)
c319a21  lint (tlrmchlsmth, Jul 23, 2024)
3374d8f  Merge branch 'upstream-main' into tms/add_mamba (tlrmchlsmth, Jul 31, 2024)
76022d3  fixup (tlrmchlsmth, Jul 31, 2024)
9ffc057  backend selector changes (tlrmchlsmth, Jul 31, 2024)
65d7e22  lint (tlrmchlsmth, Jul 31, 2024)
f14648e  Merge branch 'main' into tms/add_mamba (tlrmchlsmth, Aug 20, 2024)
e76a617  Factor out mamba cache from jamba.py, and fixes (tlrmchlsmth, Aug 20, 2024)
b9723fe  Fix mamba cache initialized bool. format and renames (tlrmchlsmth, Aug 21, 2024)
b2a8cd8  Refactor mamba to use the MambaCacheManager (tlrmchlsmth, Aug 21, 2024)
9ba8734  Merge branch 'upstream-main' into tms/add_mamba (tlrmchlsmth, Aug 28, 2024)
f87a8e2  fixes (tlrmchlsmth, Aug 29, 2024)
06b146e  Merge branch 'upstream-main' into tms/add_mamba (tlrmchlsmth, Aug 29, 2024)
8e16aca  Update to use kernels from #7651 (tlrmchlsmth, Aug 29, 2024)
120b761  some cruft (tlrmchlsmth, Aug 29, 2024)
698f666  Merge branch 'main' into tms/add_mamba (tlrmchlsmth, Sep 13, 2024)
a5bd7d2  Move test_mamba.py (for #7820) (tlrmchlsmth, Sep 13, 2024)
6546bd9  fixes (tlrmchlsmth, Sep 13, 2024)
f42af9b  Merge branch 'main' into tms/add_mamba (tlrmchlsmth, Sep 23, 2024)
85a8378  Review comments (tlrmchlsmth, Sep 24, 2024)
80e3c77  cache attention free (tlrmchlsmth, Sep 24, 2024)
184e808  fixup (tlrmchlsmth, Sep 24, 2024)
05d6aab  fixup (tlrmchlsmth, Sep 24, 2024)
4ebd4cc  missed two (tlrmchlsmth, Sep 24, 2024)
ca3788e  Remove is_attention_free from SchedulerConfig (tlrmchlsmth, Sep 24, 2024)
c67a650  default `is_attention_free` for unit tests (tlrmchlsmth, Sep 25, 2024)
9e2edf6  Fix attention selector tests (tlrmchlsmth, Sep 25, 2024)
f41b474  merge main, support chunked prefill, more tests (tlrmchlsmth, Sep 30, 2024)
7ef3c68  Merge branch 'main' into tms/add_mamba (tlrmchlsmth, Oct 10, 2024)
8729b43  Review comments (tlrmchlsmth, Oct 10, 2024)
5fb01c4  Merge branch 'main' into tms/add_mamba (tlrmchlsmth, Oct 10, 2024)
16d3f1d  format (tlrmchlsmth, Oct 10, 2024)
4b21a08  Fix supported_models.rst (tlrmchlsmth, Oct 10, 2024)
ec8ef04  jambafix (tlrmchlsmth, Oct 10, 2024)
49e1f3c  fix softfail on cpu tests (tlrmchlsmth, Oct 11, 2024)
e80b82a  Merge branch 'main' into tms/add_mamba (tlrmchlsmth, Oct 11, 2024)
609e9fb  fix for #9233 (tlrmchlsmth, Oct 11, 2024)
93129e5  format (tlrmchlsmth, Oct 11, 2024)
8 changes: 7 additions & 1 deletion .buildkite/run-cpu-test-ppc64le.sh
@@ -18,7 +18,13 @@ docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/hugg
 # Run basic model test
 docker exec cpu-test bash -c "
 pip install pytest matplotlib einops transformers_stream_generator
-pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_oot_registration.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
+pytest -v -s tests/models -m \"not vlm\" \
+  --ignore=tests/models/test_embedding.py \
+  --ignore=tests/models/test_oot_registration.py \
+  --ignore=tests/models/test_registry.py \
+  --ignore=tests/models/test_jamba.py \
+  --ignore=tests/models/test_mamba.py \
+  --ignore=tests/models/test_danube3_4b.py" # Mamba kernels and Danube3-4B on CPU is not supported

 # online inference
 docker exec cpu-test bash -c "

1 change: 1 addition & 0 deletions .buildkite/run-cpu-test.sh
@@ -27,6 +27,7 @@ docker exec cpu-test bash -c "
 pytest -v -s tests/models/decoder_only/language \
   --ignore=tests/models/test_fp8.py \
   --ignore=tests/models/decoder_only/language/test_jamba.py \
+  --ignore=tests/models/decoder_only/language/test_mamba.py \
   --ignore=tests/models/decoder_only/language/test_granitemoe.py \
   --ignore=tests/models/decoder_only/language/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported

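Both CI scripts simply add `--ignore=.../test_mamba.py` to the CPU test runs, because (per the in-script comment) the Mamba kernels are not supported on CPU. For running the suite outside Buildkite, a guard along these lines would have the same effect; this snippet is illustrative only and not part of this PR:

```python
# Illustrative pytest guard (not part of this PR): skip Mamba model tests on
# hosts without CUDA, mirroring the --ignore flags used in the CPU CI scripts.
import pytest
import torch

requires_cuda = pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="Mamba kernels are CUDA-only; skipping on CPU-only hosts")


@requires_cuda
def test_mamba_smoke():
    # Placeholder body; the real checks live in
    # tests/models/decoder_only/language/test_mamba.py.
    assert torch.cuda.is_available()
```
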
5 changes: 5 additions & 0 deletions docs/source/models/supported_models.rst
@@ -152,6 +152,11 @@ Text Generation
     - :code:`meta-llama/Meta-Llama-3.1-405B-Instruct`, :code:`meta-llama/Meta-Llama-3.1-70B`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-70b-hf`, :code:`01-ai/Yi-34B`, etc.
     - ✅︎
     - ✅︎
+  * - :code:`MambaForCausalLM`
+    - Mamba
+    - :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc.
+    - ✅︎
+    -
   * - :code:`MiniCPMForCausalLM`
     - MiniCPM
     - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.

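The docs change above lists `MambaForCausalLM` among the supported text-generation architectures. As a quick sanity check of the new support, a minimal offline-generation snippet with vLLM's standard `LLM` API might look like this; the checkpoint name comes from the table above, and the prompt and sampling settings are arbitrary example values:

```python
# Minimal sketch: generate with one of the newly listed Mamba checkpoints
# using vLLM's offline LLM API (requires a CUDA-capable host; the CPU CI
# above explicitly skips Mamba).
from vllm import LLM, SamplingParams

llm = LLM(model="state-spaces/mamba-130m-hf")
sampling = SamplingParams(temperature=0.0, max_tokens=64)

outputs = llm.generate(["The key idea behind state-space models is"], sampling)
print(outputs[0].outputs[0].text)
```
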
37 changes: 21 additions & 16 deletions tests/kernels/test_attention_selector.py
@@ -20,22 +20,22 @@ def test_env(name: str, device: str, monkeypatch):

     if device == "cpu":
         with patch("vllm.attention.selector.is_cpu", return_value=True):
-            backend = which_attn_to_use(8, 16, 8, None, torch.float16,
-                                        torch.float16, 16)
+            backend = which_attn_to_use(16, None, torch.float16, torch.float16,
+                                        16, False)
         assert backend.name == "TORCH_SDPA"
     elif device == "hip":
         with patch("vllm.attention.selector.is_hip", return_value=True):
-            backend = which_attn_to_use(8, 16, 8, None, torch.float16,
-                                        torch.float16, 16)
+            backend = which_attn_to_use(16, None, torch.float16, torch.float16,
+                                        16, False)
         assert backend.name == "ROCM_FLASH"
     elif device == "openvino":
         with patch("vllm.attention.selector.is_openvino", return_value=True):
-            backend = which_attn_to_use(8, 16, 8, None, torch.float16,
-                                        torch.float16, 16)
+            backend = which_attn_to_use(16, None, torch.float16, torch.float16,
+                                        16, False)
         assert backend.name == "OPENVINO"
     else:
-        backend = which_attn_to_use(8, 16, 8, None, torch.float16,
-                                    torch.float16, 16)
+        backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
+                                    False)
         assert backend.name == name


@@ -46,37 +46,42 @@ def test_flash_attn(monkeypatch):

     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
-        backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
+        backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported data type
-    backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16)
+    backend = which_attn_to_use(16, None, torch.float8_e4m3fn, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported kv cache data type
-    backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16)
+    backend = which_attn_to_use(16, None, torch.float16, "fp8", 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported block size
-    backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8)
+    backend = which_attn_to_use(16, None, torch.float16, None, 8, False)
     assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported sliding window
-    backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16)
+    backend = which_attn_to_use(16, 1, torch.float16, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL

     # flash-attn is not installed
     with patch.dict('sys.modules', {'vllm_flash_attn': None}):
-        backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
+        backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported head size
-    backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
+    backend = which_attn_to_use(17, None, torch.float16, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL

+    # Attention-free models should bypass env and use PlaceholderAttention
+    backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
+                                True)
+    assert backend.name != STR_FLASH_ATTN_VAL
+

 def test_invalid_env(monkeypatch):
     """Throw an exception if the backend name is invalid."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
     with pytest.raises(ValueError):
-        which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
+        which_attn_to_use(16, None, torch.float16, None, 16, False)

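The updated tests reflect a new `which_attn_to_use` call shape: the redundant head-count arguments are dropped and a trailing `is_attention_free` flag is added, so attention-free models (such as pure Mamba) skip the FlashAttention feasibility checks entirely. Below is a simplified sketch of that selection idea; the backend names and exact checks are illustrative approximations of the cases exercised in the tests above, not vLLM's actual selector:

```python
# Simplified sketch of backend selection with an `is_attention_free` flag.
# Illustrative only: backend names and feasibility checks approximate the
# cases exercised by the tests above, not vLLM's real implementation.
from enum import Enum, auto
from typing import Optional

import torch


class Backend(Enum):
    FLASH_ATTN = auto()
    XFORMERS = auto()
    PLACEHOLDER = auto()  # stand-in backend for attention-free models


def select_backend(head_size: int,
                   sliding_window: Optional[int],
                   dtype: torch.dtype,
                   kv_cache_dtype: Optional[str],
                   block_size: int,
                   is_attention_free: bool) -> Backend:
    # Attention-free models (e.g. pure Mamba) never touch the KV-cache
    # attention kernels, so they bypass every FlashAttention check.
    if is_attention_free:
        return Backend.PLACEHOLDER

    # Mirror the unsupported cases from test_flash_attn: odd head sizes,
    # non-fp16/bf16 dtypes, fp8 KV cache, small block sizes, sliding window.
    flash_ok = (head_size % 8 == 0
                and dtype in (torch.float16, torch.bfloat16)
                and kv_cache_dtype is None
                and block_size >= 16
                and sliding_window is None)
    return Backend.FLASH_ATTN if flash_ok else Backend.XFORMERS
```

With this shape, the final new test case (`is_attention_free=True`) returns the placeholder backend regardless of the other arguments, which is why it can never resolve to the FlashAttention backend.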