Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 2 additions & 12 deletions vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,13 +294,12 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):

def __init__(self, module: torch.fx.GraphModule,
compile_submod_names: list[str], vllm_config: VllmConfig,
graph_pool, vllm_backend: "VllmBackend"):
vllm_backend: "VllmBackend"):
super().__init__(module)
from torch._guards import detect_fake_mode
self.fake_mode = detect_fake_mode()
self.compile_submod_names = compile_submod_names
self.compilation_config = vllm_config.compilation_config
self.graph_pool = graph_pool
self.vllm_config = vllm_config
self.vllm_backend = vllm_backend
# When True, it annoyingly dumps the torch.fx.Graph on errors.
Expand Down Expand Up @@ -359,7 +358,6 @@ def call_module(self, target: torch.fx.node.Target,
runnable=piecewise_backend,
vllm_config=self.vllm_config,
runtime_mode=CUDAGraphMode.PIECEWISE,
graph_pool=self.graph_pool,
cudagraph_options=CUDAGraphOptions(
debug_log_enable=piecewise_backend.is_first_graph,
gc_disable=not piecewise_backend.is_first_graph,
Expand Down Expand Up @@ -405,7 +403,6 @@ class VllmBackend:

vllm_config: VllmConfig
compilation_config: CompilationConfig
graph_pool: Any
_called: bool = False
# the graph we compiled
graph: fx.GraphModule
Expand Down Expand Up @@ -433,13 +430,6 @@ def __init__(
# them, e.g. backbone (default), eagle_head, etc.
self.prefix = prefix or model_tag

global_graph_pool = current_platform.get_global_graph_pool()

# TODO: in the future, if we want to use multiple
# streams, it might not be safe to share a global pool.
# only investigate this when we use multiple streams
self.graph_pool = global_graph_pool

# Passes to run on the graph post-grad.
self.post_grad_pass_manager = PostGradPassManager()

Expand Down Expand Up @@ -586,7 +576,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
# propagate the split graph to the piecewise backend,
# compile submodules with symbolic shapes
PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
self.vllm_config, self.graph_pool,
self.vllm_config,
self).run(*example_inputs)

graph_path = os.path.join(local_cache_dir, "computation_graph.py")
Expand Down
5 changes: 1 addition & 4 deletions vllm/compilation/base_static_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class AbstractStaticGraphWrapper(Protocol):
"""

def __init__(self, runnable: Callable, vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode, graph_pool: Any, **kwargs):
runtime_mode: CUDAGraphMode, **kwargs):
"""
Initializes the StaticGraphWrapper class with graph capturing and
execution-related configurations.
Expand All @@ -25,9 +25,6 @@ def __init__(self, runnable: Callable, vllm_config: VllmConfig,
graph runtime. See CUDAGraphMode in vllm/config.py.
Note that only the subset enum `NONE`, `PIECEWISE` and `FULL`
are used as concrete runtime mode for cudagraph dispatching.
graph_pool (Any):
Graph memory pool handle, e.g.,
`torch.cuda.graph_pool_handle()`.
Keyword Args:
kwargs: Additional keyword arguments for platform-specific
configurations.
Expand Down
8 changes: 4 additions & 4 deletions vllm/compilation/cuda_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,9 @@ def __init__(self,
runnable: Callable,
vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode,
graph_pool: Any = None,
cudagraph_options: Optional[CUDAGraphOptions] = None):
self.runnable = runnable
self.vllm_config = vllm_config
self.graph_pool = graph_pool
self.runtime_mode = runtime_mode
self.compilation_config = vllm_config.compilation_config

Expand All @@ -81,8 +79,10 @@ def __init__(self,
# runtime_mode must not be NONE (no cudagraph); otherwise there is no
# need to initialize a CUDAGraphWrapper.
assert self.runtime_mode != CUDAGraphMode.NONE
if self.graph_pool is None:
self.graph_pool = current_platform.get_global_graph_pool()
# TODO: if we ever want to use multiple streams, it might not be safe
# to share a global pool. Only investigate this once multiple streams
# are actually used.
self.graph_pool = current_platform.get_global_graph_pool()

if cudagraph_options is None:
cudagraph_options = CUDAGraphOptions()
Expand Down