Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmake/external_projects/flashmla.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ else()
FetchContent_Declare(
flashmla
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
GIT_TAG 5f65b85703c7ed75fda01e06495077caad207c3f
GIT_TAG 28417e516fcbf6257a422ba117ef5b6f44da5682
GIT_PROGRESS TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
Expand Down Expand Up @@ -66,6 +66,7 @@ if(FLASH_MLA_ARCHS)
${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_metadata.cu
)

set(FlashMLA_INCLUDES
Expand Down
6 changes: 6 additions & 0 deletions vllm/attention/ops/flashmla.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ def get_mla_metadata(
(num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
- num_splits: (batch_size + 1), dtype torch.int32.
"""
if is_fp8_kvcache and topk is None:
return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
cache_seqlens,
num_q_tokens_per_head_k,
num_heads_k,
)
return torch.ops._flashmla_C.get_mla_decoding_metadata(
cache_seqlens,
num_q_tokens_per_head_k,
Expand Down
2 changes: 2 additions & 0 deletions vllm/v1/attention/backends/mla/flashmla.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def __init__(

self.cg_buf_tile_scheduler_metadata = None
self.cg_buf_num_splits = None
self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8")

device_properties = torch.cuda.get_device_properties(self.device)
num_sms = device_properties.multi_processor_count
Expand Down Expand Up @@ -123,6 +124,7 @@ def _build_decode(
seq_lens_device,
self.num_q_heads,
1, # MQA for the decode path
is_fp8_kvcache=self.is_fp8_kvcache,
)

# TODO: we can disambiguate between decode and mixed-prefill decode here
Expand Down