46 changes: 45 additions & 1 deletion tests/v1/core/test_kv_cache_utils.py
@@ -3,14 +3,16 @@
import pytest
import torch

from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256
from vllm.utils import GiB_bytes, sha256
# disable yapf here as it formats differently than isort such that both fail
# yapf: disable
from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType,
FreeKVCacheBlockQueue, KVCacheBlock,
PrefixCachingMetrics,
estimate_max_model_len,
generate_block_hash_extra_keys,
hash_block_tokens,
hash_request_tokens,
@@ -426,3 +428,45 @@ def new_kv_cache_spec(block_size=16,
]
with pytest.raises(AssertionError):
unify_kv_cache_configs(diff_kv_cache_config)


@pytest.mark.parametrize(
("model_id", "max_model_len", "want_estimated_max_len"), [
("Qwen/Qwen1.5-7B", 16385, 16384),
("Qwen/Qwen1.5-7B", 16383, 16383),
])
def test_estimate_max_model_len(model_id, max_model_len,
want_estimated_max_len):
# Create a VllmConfig
model_config = ModelConfig(
model_id,
task="generate",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
max_model_len=max_model_len,
)
scheduler_config = SchedulerConfig(max_num_batched_tokens=32768)

vllm_config = VllmConfig(
model_config=model_config,
scheduler_config=scheduler_config,
)

# Create KV cache specs
kv_cache_spec = {}
for i in range(32):
layer_name = f"layer_{i}"
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=16,
num_kv_heads=32,
head_size=128,
dtype=torch.float16,
use_mla=False,
)
    # Estimate the maximum model length; a 16384-token model_len needs 8 GiB here
estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
8 * GiB_bytes)
assert estimated_max_len == want_estimated_max_len
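
As a back-of-the-envelope check on the 8 GiB budget used above (a sketch, assuming FullAttentionSpec charges 2 bytes per fp16 element for both keys and values across every head and layer):

GiB = 1 << 30
bytes_per_token_per_layer = 2 * 32 * 128 * 2      # (K and V) * num_kv_heads * head_size * fp16 bytes
bytes_per_token = 32 * bytes_per_token_per_layer  # 32 layers -> 512 KiB per token
print(16384 * bytes_per_token / GiB)  # 8.0     -> 16384 tokens exactly fill 8 GiB
print(16385 * bytes_per_token / GiB)  # ~8.0005 -> one more token no longer fits

This matches the first parametrized case; in the second, max_model_len=16383 caps the search below 16384, so the estimate is simply 16383.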
65 changes: 61 additions & 4 deletions vllm/v1/core/kv_cache_utils.py
@@ -8,7 +8,7 @@

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import sha256
from vllm.utils import GiB_bytes, sha256
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
KVCacheTensor, SlidingWindowSpec)
@@ -459,6 +459,54 @@ def hash_request_tokens(hash_function: Any, block_size: int,
return ret


def estimate_max_model_len(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> int:
"""
Estimates the maximum model length that can fit in the available memory
using binary search.

Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.

Returns:
The estimated maximum model length that can fit in the available memory.
"""

# Define a function to check if a given model length fits in memory
def fits_in_memory(model_len: int) -> bool:
# Modify the max_model_len for this calculation
vllm_config.model_config.max_model_len = model_len
# Calculate memory needed for the given model length
memory_needed = sum(
(layer_spec.max_memory_usage_bytes(vllm_config)
for layer_spec in kv_cache_spec.values()),
start=0,
)
return memory_needed <= available_memory

# Binary search for the maximum model length
current_max = vllm_config.model_config.max_model_len
left, right = 1, current_max

# If even the smallest model length doesn't fit, return 0
if not fits_in_memory(left):
return 0

# Binary search for the maximum model length that fits
result = 1
while left <= right:
mid = (left + right) // 2
if fits_in_memory(mid):
result = mid
left = mid + 1
else:
right = mid - 1
return result
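
estimate_max_model_len relies on KV cache usage growing monotonically with model length, so a binary search over a fits-in-memory predicate finds the largest feasible length. A minimal self-contained sketch of the same pattern, with illustrative names and numbers rather than vLLM's classes:

def largest_feasible(upper_bound: int, fits) -> int:
    """Largest n in [1, upper_bound] for which the monotone fits(n) holds, else 0."""
    left, right = 1, upper_bound
    if not fits(left):
        return 0
    best = 1
    while left <= right:
        mid = (left + right) // 2
        if fits(mid):
            best = mid        # mid fits; try something larger
            left = mid + 1
        else:
            right = mid - 1   # mid does not fit; shrink the range
    return best

# Example: 512 KiB of KV cache per token against an 8 GiB budget, as in the
# test above, yields 16384.
bytes_per_token = 512 * 1024
print(largest_feasible(1 << 20, lambda n: n * bytes_per_token <= 8 * (1 << 30)))  # 16384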


def check_enough_kv_cache_memory(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int):
@@ -486,12 +534,21 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
needed_memory += layer_spec.max_memory_usage_bytes(vllm_config)

if needed_memory > available_memory:
# Estimate the maximum model length that can fit in the available memory
estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
available_memory)
estimated_msg = ""
if estimated_max_len > 0:
estimated_msg = " Based on the available memory,"
f" the estimated maximum model length is {estimated_max_len}."

raise ValueError(
f"To serve at least one request with the models's max seq len "
f"({max_model_len}), ({needed_memory/1024/1024/1024:.2f} GiB KV "
f"({max_model_len}), ({needed_memory/GiB_bytes:.2f} GiB KV "
f"cache is needed, which is larger than the available KV cache "
f"memory ({available_memory/1024/1024/1024:.2f} GiB). Try "
f"increasing `gpu_memory_utilization` or decreasing "
f"memory ({available_memory/GiB_bytes:.2f} GiB)."
f"{estimated_msg} "
f" Try increasing `gpu_memory_utilization` or decreasing "
f"`max_model_len` when initializing the engine.")

