Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ mkdocs.yaml @hmellor
# Linting
.markdownlint.yaml @hmellor
.pre-commit-config.yaml @hmellor
/tools/pre_commit @hmellor

# CPU
/vllm/v1/worker/cpu* @bigPYJ1151
Expand Down
34 changes: 14 additions & 20 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,38 +60,32 @@ repos:
files: ^requirements/test\.(in|txt)$
- id: mypy-local
name: Run mypy for local Python installation
entry: tools/mypy.sh 0 "local"
language: python
types: [python]
additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests, pydantic]
entry: python tools/pre_commit/mypy.py 0 "local"
stages: [pre-commit] # Don't run in CI
<<: &mypy_common
language: python
types_or: [python, pyi]
require_serial: true
additional_dependencies: [mypy==1.11.1, regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic]
- id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.9
entry: tools/mypy.sh 1 "3.9"
language: python
types: [python]
additional_dependencies: *mypy_deps
entry: python tools/pre_commit/mypy.py 1 "3.9"
<<: *mypy_common
stages: [manual] # Only run in CI
- id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.10
entry: tools/mypy.sh 1 "3.10"
language: python
types: [python]
additional_dependencies: *mypy_deps
entry: python tools/pre_commit/mypy.py 1 "3.10"
<<: *mypy_common
stages: [manual] # Only run in CI
- id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.11
entry: tools/mypy.sh 1 "3.11"
language: python
types: [python]
additional_dependencies: *mypy_deps
entry: python tools/pre_commit/mypy.py 1 "3.11"
<<: *mypy_common
stages: [manual] # Only run in CI
- id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.12
entry: tools/mypy.sh 1 "3.12"
language: python
types: [python]
additional_dependencies: *mypy_deps
entry: python tools/pre_commit/mypy.py 1 "3.12"
<<: *mypy_common
stages: [manual] # Only run in CI
- id: shellcheck
name: Lint shell scripts
Expand Down
21 changes: 0 additions & 21 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -110,27 +110,6 @@ ignore_missing_imports = true
check_untyped_defs = true
follow_imports = "silent"

# After fixing type errors resulting from follow_imports: "skip" -> "silent",
# move the directory here and remove it from tools/mypy.sh
files = [
"vllm/*.py",
"vllm/assets",
"vllm/entrypoints",
"vllm/inputs",
"vllm/logging_utils",
"vllm/multimodal",
"vllm/platforms",
"vllm/transformers_utils",
"vllm/triton_utils",
"vllm/usage",
]
# TODO(woosuk): Include the code from Megatron and HuggingFace.
exclude = [
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
# Ignore triton kernels in ops.
'vllm/attention/ops/.*\.py$'
]

[tool.isort]
skip_glob = [
".buildkite/*",
Expand Down
35 changes: 0 additions & 35 deletions tools/mypy.sh

This file was deleted.

140 changes: 140 additions & 0 deletions tools/pre_commit/mypy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Run mypy on changed files.

This script is designed to be used as a pre-commit hook. It runs mypy
on files that have been changed. It groups files into different mypy calls
based on their directory to avoid import following issues.

Usage:
python tools/pre_commit/mypy.py <ci> <python_version> <changed_files...>

Args:
ci: "1" if running in CI, "0" otherwise. In CI, follow_imports is set to
"silent" for the main group of files.
python_version: Python version to use (e.g., "3.10") or "local" to use
the local Python version.
changed_files: List of changed files to check.
"""

import subprocess
import sys
from typing import Optional

import regex as re

FILES = [
    "vllm/*.py",
    "vllm/assets",
    "vllm/entrypoints",
    "vllm/inputs",
    "vllm/logging_utils",
    "vllm/multimodal",
    "vllm/platforms",
    "vllm/transformers_utils",
    "vllm/triton_utils",
    "vllm/usage",
]

# After fixing errors resulting from changing follow_imports
# from "skip" to "silent", move the following directories to FILES
SEPARATE_GROUPS = [
    "tests",
    "vllm/attention",
    "vllm/compilation",
    "vllm/distributed",
    "vllm/engine",
    "vllm/executor",
    "vllm/inputs",
    "vllm/lora",
    "vllm/model_executor",
    "vllm/plugins",
    "vllm/worker",
    "vllm/v1",
]

# TODO(woosuk): Include the code from Megatron and HuggingFace.
EXCLUDE = [
    "vllm/model_executor/parallel_utils",
    "vllm/model_executor/models",
    "vllm/model_executor/layers/fla/ops",
    # Ignore triton kernels in ops.
    "vllm/attention/ops",
]


def _glob_to_regex(entry: str) -> str:
    """Translate a FILES/EXCLUDE entry into a regex fragment.

    Entries are path prefixes where ``*`` means "any run of
    non-separator characters"; every other character (notably ``.``)
    is matched literally. Without this translation, ``"vllm/*.py"``
    interpreted as a raw regex (``/*`` = zero or more slashes, ``.`` =
    any character) fails to match top-level modules such as
    ``vllm/config.py``, so they were silently never type-checked.
    """
    return re.escape(entry).replace(r"\*", "[^/]*")


def group_files(changed_files: list[str]) -> dict[str, list[str]]:
    """
    Group changed files into different mypy calls.

    Files under ``EXCLUDE`` are dropped entirely, files matching a
    ``FILES`` entry go into the main ("") group, and the remaining
    files are assigned to the first matching ``SEPARATE_GROUPS``
    directory (files matching none are ignored).

    Args:
        changed_files: List of changed files.

    Returns:
        A dictionary mapping file group names to lists of changed files.
    """
    # Escape the entries so glob characters behave as path globs rather
    # than regex metacharacters, and group the alternation explicitly.
    exclude_pattern = re.compile(f"^({'|'.join(map(_glob_to_regex, EXCLUDE))})")
    files_pattern = re.compile(f"^({'|'.join(map(_glob_to_regex, FILES))})")
    file_groups: dict[str, list[str]] = {"": []}
    file_groups.update({k: [] for k in SEPARATE_GROUPS})
    for changed_file in changed_files:
        # Skip files which should be ignored completely
        if exclude_pattern.match(changed_file):
            continue
        # Group files by mypy call
        if files_pattern.match(changed_file):
            file_groups[""].append(changed_file)
        else:
            for directory in SEPARATE_GROUPS:
                if changed_file.startswith(directory):
                    file_groups[directory].append(changed_file)
                    break
    return file_groups


def mypy(targets: list[str], python_version: Optional[str],
         follow_imports: Optional[str], file_group: str) -> int:
    """
    Run mypy on the given targets.

    Args:
        targets: List of files or directories to check.
        python_version: Python version to use (e.g., "3.10") or None to use
            the default mypy version.
        follow_imports: Value for the --follow-imports option or None to use
            the default mypy behavior.
        file_group: The file group name for logging purposes.

    Returns:
        The return code from mypy.
    """
    cmd = ["mypy"]
    if python_version is not None:
        cmd.extend(["--python-version", python_version])
    if follow_imports is not None:
        cmd.extend(["--follow-imports", follow_imports])
    # Echo the command (minus the file list, which can be long) so the
    # pre-commit output shows which group is being checked.
    print(f"$ {' '.join(cmd)} {file_group}")
    completed = subprocess.run(cmd + targets, check=False)
    return completed.returncode


def main() -> int:
    """Parse CLI arguments, group the changed files, and run mypy.

    Expects ``sys.argv`` as: ci-flag ("1" in CI), python version (or
    "local"), then the changed files. Returns the bitwise OR of all
    mypy return codes (0 iff every group passed).
    """
    in_ci = sys.argv[1] == "1"
    python_version = sys.argv[2]
    if python_version == "local":
        # Resolve "local" to the interpreter actually running this hook.
        python_version = f"{sys.version_info.major}.{sys.version_info.minor}"

    exit_code = 0
    for group_name, files in group_files(sys.argv[3:]).items():
        if not files:
            continue
        # In CI the main ("") group follows imports silently; everywhere
        # else imports are skipped to avoid cross-group noise.
        imports_mode = None if in_ci and group_name == "" else "skip"
        exit_code |= mypy(files, python_version, imports_mode, group_name)
    return exit_code


if __name__ == "__main__":
sys.exit(main())
4 changes: 2 additions & 2 deletions vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1468,7 +1468,7 @@ def get_metrics(self) -> list["Metric"]:

def _validate_and_add_requests(
self,
prompts: Union[PromptType, Sequence[PromptType]],
prompts: Union[PromptType, Sequence[PromptType], DataPrompt],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]],
*,
Expand All @@ -1478,7 +1478,7 @@ def _validate_and_add_requests(
) -> None:
if isinstance(prompts, (str, dict)):
# Convert a single prompt to a list.
prompts = [prompts]
prompts = [prompts] # type: ignore[list-item]

num_requests = len(prompts)
if isinstance(params, Sequence) and len(params) != num_requests:
Expand Down
2 changes: 1 addition & 1 deletion vllm/entrypoints/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def _validate_and_normalize_truncate_tokens(
if truncate_prompt_tokens < 0:
truncate_prompt_tokens = self.model_config.max_model_len

if max_length is not None and truncate_prompt_tokens > max_length:
if max_length is not None and truncate_prompt_tokens > max_length: # type: ignore[operator]
raise ValueError(
f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
f"cannot be greater than max_length ({max_length}). "
Expand Down
9 changes: 5 additions & 4 deletions vllm/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,9 +609,10 @@ async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool):
# If every request uses identical kwargs we can run a single
# batched tokenizer call for a big speed-up.
if can_batch and len(prompts) > 1:
encode_fn = partial(self.tokenizer, prompts, **kwargs)
batch_encode_fn = partial(self.tokenizer, prompts,
**kwargs)
results = await self._loop.run_in_executor(
self._executor, encode_fn)
self._executor, batch_encode_fn)

for i, fut in enumerate(result_futures):
if not fut.done():
Expand Down Expand Up @@ -947,7 +948,7 @@ def get_open_port() -> int:

def get_open_ports_list(count: int = 5) -> list[int]:
"""Get a list of open ports."""
ports = set()
ports = set[int]()
while len(ports) < count:
ports.add(get_open_port())
return list(ports)
Expand Down Expand Up @@ -1337,7 +1338,7 @@ def as_list(maybe_list: Iterable[T]) -> list[T]:

def as_iter(obj: Union[T, Iterable[T]]) -> Iterable[T]:
if isinstance(obj, str) or not isinstance(obj, Iterable):
obj = [obj]
return [obj] # type: ignore[list-item]
return obj


Expand Down
7 changes: 3 additions & 4 deletions vllm/utils/tensor_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ def __init__(
self.dims = dims
self.dynamic_dims = dynamic_dims if dynamic_dims else set()

def resolve(self, **bindings: dict[str,
int]) -> tuple[Union[int, str], ...]:
resolved = []
def resolve(self, **bindings: int) -> tuple[Union[int, str], ...]:
resolved = list[Union[int, str]]()
for dim in self.dims:
if isinstance(dim, str) and dim in bindings:
resolved.append(bindings[dim])
Expand Down Expand Up @@ -159,7 +158,7 @@ def _validate_tensor_shape_expected(

def validate(self) -> None:
type_hints = get_type_hints(self.__class__, include_extras=True)
shape_env = {}
shape_env = dict[str, int]()

for field_name, field_type in type_hints.items():
# Check if field is missing
Expand Down