tests/compile/piecewise/test_simple.py (5 additions, 3 deletions)

@@ -11,8 +11,8 @@
 from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
-from vllm.compilation.levels import CompilationLevel
-from vllm.config import VllmConfig
+from vllm.config import CompilationLevel, VllmConfig
+from vllm.plugins import set_current_vllm_config
 from vllm.utils import direct_register_custom_op

 global_counter = 0
@@ -82,7 +82,9 @@ def test_simple_piecewise_compile():
     os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
     os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)

-    model = SillyModel(vllm_config=VllmConfig(), prefix='')
+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        model = SillyModel(vllm_config=vllm_config, prefix='')

     inputs = torch.randn(100).cuda()

tests/compile/piecewise/test_toy_llama.py (12 additions, 10 deletions)

@@ -15,12 +15,10 @@
 from torch.library import Library

 from vllm.compilation.compile_context import set_compile_context
-from vllm.compilation.config import CompilationConfig
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
-from vllm.compilation.levels import CompilationLevel
-from vllm.config import VllmConfig
-from vllm.plugins import set_compilation_config
+from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
+from vllm.plugins import set_compilation_config, set_current_vllm_config
 from vllm.utils import direct_register_custom_op

 # create a library to hold the custom op
@@ -272,9 +270,11 @@ def run_model(llama_config,
             CompilationLevel.NO_COMPILATION)
         set_compilation_config(None)

-    model = LlamaModel(config=llama_config,
-                       vllm_config=VllmConfig(),
-                       prefix="").eval().cuda()
+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        model = LlamaModel(config=llama_config,
+                           vllm_config=vllm_config,
+                           prefix="").eval().cuda()

     B = 16  # max batch size
     input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
@@ -395,9 +395,11 @@ def benchmark():
     else:
         set_compilation_config(None)

-    model = LlamaModel(config=llama_config,
-                       vllm_config=VllmConfig(),
-                       prefix="").eval().cuda().to(torch.bfloat16)
+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        model = LlamaModel(config=llama_config,
+                           vllm_config=vllm_config,
+                           prefix="").eval().cuda().to(torch.bfloat16)

     B = 256  # max batch size
     input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()

tests/compile/test_basic_correctness.py (1 addition, 1 deletion)

@@ -3,7 +3,7 @@

 import pytest

-from vllm.compilation.levels import CompilationLevel
+from vllm.config import CompilationLevel
 from vllm.utils import cuda_device_count_stateless

 from ..utils import compare_all_settings

tests/compile/test_full_graph.py (1 addition, 1 deletion)

@@ -1,6 +1,6 @@
 import pytest

-from vllm.compilation.levels import CompilationLevel
+from vllm.config import CompilationLevel

 from ..utils import fork_new_process_for_each_test
 from .utils import TEST_MODELS, check_full_graph_support

tests/compile/test_fusion.py (1 addition, 1 deletion)

@@ -3,10 +3,10 @@
 from compressed_tensors.quantization import FP8_DTYPE

 import vllm.envs as envs
-from vllm.compilation.config import CompilationConfig
 from vllm.compilation.fusion import (FusionPass, find_auto_fn,
                                      find_auto_fn_maybe)
 from vllm.compilation.reshapes import RedundantReshapesPass
+from vllm.config import CompilationConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear)

tests/compile/test_wrapper.py (3 additions, 1 deletion)

@@ -3,6 +3,7 @@
 import torch

 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
+from vllm.config import CompilationLevel


 class MyMod(torch.nn.Module):
@@ -18,7 +19,8 @@ class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
     def __init__(self, model):
         self.model = model
         compiled_callable = torch.compile(self.forward, backend="eager")
-        super().__init__(compiled_callable)
+        super().__init__(compiled_callable,
+                         compilation_level=CompilationLevel.DYNAMO_ONCE)

     def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
         # this is the function to be compiled

tests/compile/utils.py (1 addition, 1 deletion)

@@ -4,7 +4,7 @@

 from tests.quantization.utils import is_quant_method_supported
 from vllm import LLM, SamplingParams
-from vllm.compilation.levels import CompilationLevel
+from vllm.config import CompilationLevel
 from vllm.platforms import current_platform

 TEST_MODELS = [

tests/model_executor/test_enabled_custom_ops.py (26 additions, 26 deletions)

@@ -3,11 +3,13 @@

 import pytest

+from vllm.config import CompilationConfig, VllmConfig
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.activation import (GeluAndMul,
                                                    ReLUSquaredActivation,
                                                    SiluAndMul)
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.plugins import set_current_vllm_config


 # Registered subclass for test
@@ -51,42 +53,40 @@ class Relu3(ReLUSquaredActivation):
 ])
 def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
                      default_on: bool):
-    os.environ["VLLM_CUSTOM_OPS"] = env
-    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level)
-
-    # Reset default_on (computed once):
-    CustomOp.default_on.cache_clear()
-
-    assert CustomOp.default_on() == default_on
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        custom_ops=env.split(",")))
+    with set_current_vllm_config(vllm_config):
+        assert CustomOp.default_on() == default_on

-    ops_enabled = [bool(x) for x in ops_enabled]
+        ops_enabled = [bool(x) for x in ops_enabled]

-    assert RMSNorm(1024).enabled() == ops_enabled[0]
-    assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
+        assert RMSNorm(1024).enabled() == ops_enabled[0]
+        assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]

-    assert SiluAndMul().enabled() == ops_enabled[1]
-    assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
+        assert SiluAndMul().enabled() == ops_enabled[1]
+        assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]

-    assert GeluAndMul().enabled() == ops_enabled[2]
-    assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
+        assert GeluAndMul().enabled() == ops_enabled[2]
+        assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]

-    # If registered, subclasses should follow their own name
-    assert Relu3().enabled() == ops_enabled[3]
-    assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]
+        # If registered, subclasses should follow their own name
+        assert Relu3().enabled() == ops_enabled[3]
+        assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]

-    # Unregistered subclass
-    class SiluAndMul2(SiluAndMul):
-        pass
+        # Unregistered subclass
+        class SiluAndMul2(SiluAndMul):
+            pass

-    # Subclasses should not require registration
-    assert SiluAndMul2().enabled() == SiluAndMul().enabled()
+        # Subclasses should not require registration
+        assert SiluAndMul2().enabled() == SiluAndMul().enabled()


 @pytest.mark.parametrize(
     "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"])
 def test_enabled_ops_invalid(env: str):
-    os.environ["VLLM_CUSTOM_OPS"] = env
-    CustomOp.default_on.cache_clear()
-
-    with pytest.raises(AssertionError):
-        RMSNorm(1024).enabled()
+    with pytest.raises(Exception):  # noqa
+        vllm_config = VllmConfig(compilation_config=CompilationConfig(
+            custom_ops=env.split(",")))
+        with set_current_vllm_config(vllm_config):
+            RMSNorm(1024).enabled()

tests/tpu/test_compilation.py (1 addition, 1 deletion)

@@ -5,7 +5,7 @@

 import depyf

-from vllm.compilation.levels import CompilationLevel
+from vllm.config import CompilationLevel

 # disable custom dispatcher, let Dynamo takes over
 # all the control

tests/tpu/test_custom_dispatcher.py (1 addition, 1 deletion)

@@ -1,6 +1,6 @@
 import os

-from vllm.compilation.levels import CompilationLevel
+from vllm.config import CompilationLevel

 from ..utils import compare_two_settings

vllm/compilation/backends.py (13 additions, 7 deletions)

@@ -10,13 +10,12 @@
 import torch.fx as fx

 import vllm.envs as envs
+from vllm.config import CompilationConfig, CompilationLevel
 from vllm.logger import init_logger
 from vllm.utils import combine_fx_passes, weak_ref_tensors

-from .config import CompilationConfig
 from .counter import compilation_counter
 from .fusion import FusionPass
-from .levels import CompilationLevel
 from .reshapes import RedundantReshapesPass

 logger = init_logger(__name__)
@@ -392,7 +391,10 @@ class VllmBackend:
     sym_tensor_indices: List[int]
     input_buffers: List[torch.Tensor]

-    def __init__(self, post_grad_passes: Sequence[Callable] = ()):
+    def __init__(
+        self,
+        compilation_configs: CompilationConfig,
+    ):
         global global_graph_pool
         if global_graph_pool is None:
             global_graph_pool = torch.cuda.graph_pool_handle()
@@ -401,11 +403,13 @@ def __init__(self, post_grad_passes: Sequence[Callable] = ()):
         # streams, it might not be safe to share a global pool.
         # only investigate this when we use multiple streams
         self.graph_pool = global_graph_pool
-        self.post_grad_passes = post_grad_passes
+        self.post_grad_passes = []

         self.sym_tensor_indices = []
         self.input_buffers = []

+        self.compilation_configs = compilation_configs
+
         # `torch.compile` is JIT compiled, so we don't need to
         # do anything here

@@ -437,10 +441,10 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         assert not self._called, "VllmBackend can only be called once"

         self.graph = graph
-        # config is read now, because only here can
+        # config is updated now, because only here can
         # we get the sizes to capture for cudagraph
         # from compilation context
-        self.compilation_configs = CompilationConfig.select_and_init_config()
+        self.compilation_configs.init_during_runtime()
         self.add_passes_to_config()

         self.split_gm, self.piecewise_graphs = split_graph(
@@ -688,4 +692,6 @@ def select_default_backend(level: int) -> Union[str, Callable]:
         return backend_str
     assert level == CompilationLevel.PIECEWISE

-    return VllmBackend()
+    from vllm.plugins import get_current_vllm_config
+    compilation_config = get_current_vllm_config().compilation_config
+    return VllmBackend(compilation_config)
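
The changes above all follow one pattern: CompilationConfig and CompilationLevel now live in vllm.config, and callers construct layers and models inside set_current_vllm_config(...) so that CustomOp subclasses and VllmBackend read the active config (e.g. custom_ops) instead of the VLLM_CUSTOM_OPS environment variable. A minimal usage sketch of that pattern follows, assuming the imports introduced in this diff; the "+rms_norm" value and the final assertion are illustrative only, not part of the PR:

# Usage sketch (not part of the diff): build the config object up front, then
# create layers under set_current_vllm_config so they see it.
from vllm.config import CompilationConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.plugins import set_current_vllm_config

# Custom-op selection now travels on the config object instead of the
# VLLM_CUSTOM_OPS environment variable.
vllm_config = VllmConfig(compilation_config=CompilationConfig(
    custom_ops=["+rms_norm"]))

with set_current_vllm_config(vllm_config):
    # Layers built inside the context pick up vllm_config, mirroring how the
    # updated tests wrap SillyModel and LlamaModel construction.
    norm = RMSNorm(1024)
    assert norm.enabled()  # "+rms_norm" requests the custom op, so expect True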