Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions vllm/model_executor/models/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,23 @@ def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
logger.debug("%s: %s -> %s", name, old_module, new_module)


def can_enable_torch_compile(vllm_config: VllmConfig) -> bool:
"""
Callable to be passed to `@support_torch_compile`'s `enable_if` argument.

Defaults to `True` but is disabled in the following situations:

- The model uses dynamic rope scaling.
"""
enable = True
text_config = vllm_config.model_config.hf_config.get_text_config()
# Dynamic rope scaling is not compatible with torch.compile
rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {}
if rope_scaling.get("rope_type") == "dynamic":
enable = False
return enable


def replace_linear_class(
linear: nn.Linear, style: Literal["colwise", "rowwise"],
quant_config: QuantizationConfig
Expand Down Expand Up @@ -641,7 +658,7 @@ def load_weights(self, weights: Iterable[tuple[str,
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


@support_torch_compile
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersModel(TransformersBase):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
Expand All @@ -653,7 +670,7 @@ class TransformersModel(TransformersBase):
})


@support_torch_compile
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(TransformersBase):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
Expand Down Expand Up @@ -709,12 +726,14 @@ def _can_concat(x: list[torch.Tensor]):
info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder)
@support_torch_compile(
# set `positions` to last dim to support Qwen-mrope
dynamic_arg_dims={
"input_ids": 0,
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
}) # set `positions` to last dim to support Qwen-mrope
},
enable_if=can_enable_torch_compile)
class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
# Backwards compatibility for prev released models. State dicts back then
# had different formats and cannot be loaded with `AutoModel` mapping as is
Expand Down