Merged
40 changes: 40 additions & 0 deletions tests/models/quantization/test_mxfp4.py
@@ -0,0 +1,40 @@
# SPDX-License-Identifier: Apache-2.0
# flake8: noqa
"""Tests Quark mxfp4 models against ground truth generation
"""
import pytest

from vllm import LLM, SamplingParams

MODELS = ["amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8"]

EXPECTED_STRS_MAP = {
"amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8": [
'\n### Key Features\n\n* **High-throughput Inference**: vLL',
'\nArtificial intelligence (AI) has evolved significantly since its inception in the 1',
'Artificial intelligence (AI) and human intelligence (HI) are two distinct concepts that have been',
'A neural network is a machine learning model inspired by the structure of the human brain. It consists of',
'\nTitle: The Dreaming Robot\n\nAs the sun set on the bustling metropol',
'\nThe COVID-19 pandemic has had a profound impact on global economic structures and business',
'The Mona Lisa painting, created by Leonardo da Vinci in the early 16th',
" everybody knows this proverbial saying, but did you know that it's not entirely accurate?",
]
}


@pytest.mark.skip(reason="Model to be released in the future")
@pytest.mark.quant_model
@pytest.mark.parametrize("model_name", MODELS)
def test_models(example_prompts, model_name) -> None:
sampling_params = SamplingParams(max_tokens=20, temperature=0)
llm = LLM(
model=model_name,
kv_cache_dtype="fp8",
quantization="quark",
)
outputs = llm.generate(example_prompts, sampling_params)
for i, output in enumerate(outputs):
output_str = output.outputs[0].text
expected_str = EXPECTED_STRS_MAP[model_name][i]
assert expected_str == output_str, (
f"Expected: {expected_str!r}\nvLLM: {output_str!r}")
9 changes: 9 additions & 0 deletions vllm/envs.py
@@ -84,6 +84,7 @@
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
VLLM_QUARK_EMU_MEM_OPT: bool = False
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
VLLM_DISABLE_COMPILE_CACHE: bool = False
@@ -583,6 +584,14 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in
("true", "1")),

# If set, when running in Quark emulation mode, do not dequantize the
# weights at load time. Instead, dequantize weights on-the-fly during
# kernel execution.
# This allows running larger models at the cost of slower inference.
# This flag has no effect when not running in Quark emulation mode.
"VLLM_QUARK_EMU_MEM_OPT":
lambda: bool(int(os.getenv("VLLM_QUARK_EMU_MEM_OPT", "0"))),

# Divisor for dynamic query scale factor calculation for FP8 KV Cache
"Q_SCALE_CONSTANT":
lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")),
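As a usage note for the new flag: below is a minimal sketch (not part of this diff) of how `VLLM_QUARK_EMU_MEM_OPT` might be enabled when evaluating a Quark MX-FP4 checkpoint in emulation mode. The model name is the one used in the test above; the flag is only consulted when the quantization scheme runs, so exporting it before constructing the `LLM` should be sufficient.

```python
import os

# Assumption for illustration: keep the packed MX-FP4 weights and dequantize them
# on the fly (slower inference, lower memory) instead of dequantizing at load time.
os.environ["VLLM_QUARK_EMU_MEM_OPT"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(
    model="amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8",
    quantization="quark",
    kv_cache_dtype="fp8",
)
out = llm.generate(["Briefly explain MX-FP4."],
                   SamplingParams(max_tokens=20, temperature=0))
print(out[0].outputs[0].text)
```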
56 changes: 55 additions & 1 deletion vllm/model_executor/layers/quantization/quark/quark.py
@@ -5,6 +5,7 @@

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
@@ -15,13 +16,15 @@
from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501
QuarkMoEMethod)
from vllm.model_executor.layers.quantization.quark.schemes import (
QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8)
QuarkScheme, QuarkW4A4MXFP4, QuarkW8A8Fp8, QuarkW8A8Int8)
from vllm.model_executor.layers.quantization.quark.utils import (
deep_compare, should_ignore_layer)
from vllm.platforms import current_platform

__all__ = ["QuarkLinearMethod"]

logger = init_logger(__name__)


class QuarkConfig(QuantizationConfig):

@@ -67,6 +70,7 @@ def get_quant_method(self, layer: torch.nn.Module,
return QuarkLinearMethod(self)
if isinstance(layer, Attention):
return QuarkKVCacheMethod(self)

if isinstance(layer, FusedMoE):
return QuarkMoEMethod.get_moe_method(self,
module=layer,
@@ -205,6 +209,54 @@ def _is_static_tensor_w8a8(self, weight_quant: Optional[Dict[str, Any]],
# Only symmetric weight quantization supported.
return is_int8_dtype and is_tensor and is_weight_symmetric and is_static

def _is_mx_fp4(self, weight_quant: Optional[Dict[str, Any]],
input_quant: Optional[Dict[str, Any]]) -> bool:
# Confirm weights and input quantized.
if weight_quant is None or input_quant is None:
logger.debug("Quark model is not in MX-FP4 format: "
"weight_quant or input_quant not set")
return False

# Input and weight dtype needs to be fp4.
if weight_quant.get("dtype") != "fp4" or input_quant.get(
"dtype") != "fp4":
logger.debug("Quark model is not in MX-FP4 format: dtype not fp4")
return False

# Input and weight qscheme needs to be per group.
if weight_quant.get("qscheme") != "per_group" or input_quant.get(
"qscheme") != "per_group":
logger.debug("Quark model is not in MX-FP4 format: not per_group")
return False

# Input and weight group size needs to be 32.
if weight_quant.get("group_size") != 32 or input_quant.get(
"group_size") != 32:
logger.debug(
"Quark model is not in MX-FP4 format: not group_size=32")
return False

# Weights need to use static quantization.
if weight_quant.get("is_dynamic") is True:
logger.debug(
"Quark model is not in MX-FP4 format: not weight static")
return False

# Activations need to use dynamic quantization.
if input_quant.get("is_dynamic") is False:
logger.debug(
"Quark model is not in MX-FP4 format: not activation dynamic")
return False

# Activations and weight scales need to be in e8m0 format.
if weight_quant.get("scale_format") != "e8m0" or input_quant.get(
"scale_format") != "e8m0":
logger.debug(
"Quark model is not in MX-FP4 format: not scale_format e8m0")
return False

return True

def _find_matched_config(self, layer_name: str,
module: torch.nn.Module) -> Dict[str, Any]:

@@ -269,6 +321,8 @@ def _get_scheme_from_config(self, config: Dict[str, Any]) -> "QuarkScheme":
return QuarkW8A8Int8(qscheme=weight_qscheme,
is_static_input_scheme=True,
input_symmetric=input_config.get("symmetric"))
elif self._is_mx_fp4(weight_config, input_config):
return QuarkW4A4MXFP4(weight_config, input_config)

raise NotImplementedError("No quark compatible scheme was found. "
f"Weight config: {weight_config}, "
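To make the detection logic above concrete, here is a hypothetical pair of Quark config entries (not taken from a real checkpoint) that `_is_mx_fp4` would accept, and that `_get_scheme_from_config` would therefore map to `QuarkW4A4MXFP4`:

```python
# Hypothetical Quark quantization-config fragment that _is_mx_fp4 would accept.
weight_quant = {
    "dtype": "fp4",
    "qscheme": "per_group",
    "group_size": 32,
    "is_dynamic": False,   # weights use static quantization
    "scale_format": "e8m0",
}
input_quant = {
    "dtype": "fp4",
    "qscheme": "per_group",
    "group_size": 32,
    "is_dynamic": True,    # activations are quantized dynamically
    "scale_format": "e8m0",
}
# With these two dicts, _get_scheme_from_config returns
# QuarkW4A4MXFP4(weight_quant, input_quant).
```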
vllm/model_executor/layers/quantization/quark/schemes/__init__.py
@@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0

from .quark_scheme import QuarkScheme
from .quark_w4a4_mxfp4 import QuarkW4A4MXFP4
from .quark_w8a8_fp8 import QuarkW8A8Fp8
from .quark_w8a8_int8 import QuarkW8A8Int8

__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8"]
__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkW4A4MXFP4"]
125 changes: 125 additions & 0 deletions vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
@@ -0,0 +1,125 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict, List, Optional

import torch
import torch.nn.functional as F

import vllm.envs as envs
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.platforms import current_platform

__all__ = ["QuarkW4A4MXFP4"]


class QuarkW4A4MXFP4(QuarkScheme):

def __init__(self, weight_quant_spec: Dict[str, Any],
input_quant_spec: Dict[str, Any]):
self.out_dtype = torch.get_default_dtype()
self.qscheme = "per_group"
self.weight_quant_spec = weight_quant_spec
self.input_quant_spec = input_quant_spec
self.emulate = not current_platform.supports_mx()

@classmethod
def get_min_capability(cls) -> int:
return 70

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data,
requires_grad=False)
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
requires_grad=False)

if self.emulate:
try:
from quark.torch.export.nn.modules import realquantizer
from quark.torch.quantization.config.config import (
QuantizationSpec)
except ImportError as err:
raise ImportError(
"The package `amd-quark` is required to use AMD Quark "
"MX-FP4 models. Please install it with `pip install "
"amd-quark`.") from err

weight_quant_spec = QuantizationSpec.from_dict(
self.weight_quant_spec)

weight_quantizer = realquantizer.get_real_quantizer(
qspec=weight_quant_spec,
quantizer=None,
real_quantized=True,
reorder=False,
float_dtype=self.out_dtype,
scale_shape=layer.weight_scale.shape,
zero_point_shape=None,
)
weight_quantizer.scale.data = layer.weight_scale.data

if not envs.VLLM_QUARK_EMU_MEM_OPT:
layer.weight = torch.nn.Parameter(
weight_quantizer(layer.weight.data).to(self.out_dtype),
requires_grad=False,
)
else:
self.weight_quantizer = weight_quantizer
Comment on lines +64 to +70

Member:
Can we remove the env var and always do weight decompression at runtime? This is the expected behavior of other quantization methods, so it feels strange not to do the same here.

Contributor (Author):
We found ahead-of-time weight dequantization to be more efficient for emulation evaluations. That said, this can be removed once more efficient dequant kernels are supported. I would prefer to keep this option for now, but let me know if you feel strongly about it.

Member:
Okay, we can keep it for now, but let us hope to remove it over time. We want to keep the list of environment variables from ever growing unless there is a good reason.

layer.weight_scale = None

# This call is necessary to release the scales memory.
torch.cuda.empty_cache()

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes

# WEIGHT
weight = PackedvLLMParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // 2,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
packed_dim=1,
packed_factor=2,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)

# WEIGHT SCALE
weight_scale = GroupQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // OCP_MX_BLOCK_SIZE,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)

def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:

if self.emulate:
if envs.VLLM_QUARK_EMU_MEM_OPT:
dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype)
else:
dq_w = layer.weight
qdq_x, _ = per_token_group_quant_mxfp4(x, OCP_MX_BLOCK_SIZE)
return F.linear(qdq_x, dq_w, bias)
else:
raise NotImplementedError()
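The shapes registered in `create_weights` follow the OCP MX layout: two FP4 values are packed into each uint8 element of `weight` (hence `packed_factor=2`), and each group of `OCP_MX_BLOCK_SIZE = 32` input elements shares one uint8 E8M0 scale. A small sketch with hypothetical partition sizes:

```python
import torch

OCP_MX_BLOCK_SIZE = 32  # group size along the input dimension

# Hypothetical partition sizes, for illustration only.
output_size_per_partition = 4096
input_size_per_partition = 4096

# Packed FP4 weights: two 4-bit values per uint8 byte.
weight = torch.empty(
    output_size_per_partition,
    input_size_per_partition // 2,
    dtype=torch.uint8,
)

# One E8M0 scale (stored as a uint8 biased exponent) per 32-element group.
weight_scale = torch.empty(
    output_size_per_partition,
    input_size_per_partition // OCP_MX_BLOCK_SIZE,
    dtype=torch.uint8,
)

print(weight.shape)        # torch.Size([4096, 2048])
print(weight_scale.shape)  # torch.Size([4096, 128])
```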
45 changes: 45 additions & 0 deletions vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Tuple

import torch

OCP_MX_BLOCK_SIZE = 32


def per_token_group_quant_mxfp4(x: torch.Tensor,
block_k: int,
scale_calculation_mode: str = "even"
) -> Tuple[torch.Tensor, torch.Tensor]:
try:
from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
fake_quantize_fp4_fp6_per_group_with_scale)
from quark.torch.quantization.utils import (even_round,
reshape_to_blocks)
except ImportError as err:
raise ImportError("The package `amd-quark` is required to use "
"MX-FP4 models. Please install it with `pip install "
"amd-quark`.") from err

axis = -1
block_x = reshape_to_blocks(x, block_k, axis)
amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True)
amax = amax.squeeze(-1)

# TODO: there are other rounding strategies supported in quark and in the
# config.json that we do not check for here!
if scale_calculation_mode != "even":
raise NotImplementedError(
f"Scale calculation mode {scale_calculation_mode} is not yet "
"supported in MX-FP4 quantization")
scale = even_round(amax, "fp4")

# Apply dequantize(quantize(x)).
x = fake_quantize_fp4_fp6_per_group_with_scale(
x,
scale.to(x.device),
axis=axis,
group_size=block_k,
quant_dtype="fp4",
)

return x, scale
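Since `per_token_group_quant_mxfp4` delegates the actual fake quantization to amd-quark, the sketch below is a rough pure-PyTorch approximation of the same dequantize(quantize(x)) round trip, intended only as a mental model: a power-of-two (E8M0-style) scale per 32-element group and nearest-value rounding onto the FP4 (E2M1) grid. Quark's `even_round` and kernels may differ in rounding details.

```python
import torch

OCP_MX_BLOCK_SIZE = 32
FP4_E2M1_MAX = 6.0
# Representable magnitudes of the FP4 (E2M1) format.
FP4_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def fake_quant_mxfp4_reference(x: torch.Tensor,
                               block_k: int = OCP_MX_BLOCK_SIZE) -> torch.Tensor:
    """Approximate dequantize(quantize(x)) for MX-FP4 emulation.

    Assumes the trailing dimension of x is a multiple of block_k. This is
    not quark's exact implementation; it illustrates the per-group
    power-of-two scaling and FP4 grid rounding.
    """
    orig_shape = x.shape
    blocks = x.reshape(-1, block_k)

    # Per-group power-of-two scale so the largest magnitude maps near 6.0.
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = torch.exp2(torch.floor(torch.log2(amax / FP4_E2M1_MAX)))

    # Round each scaled value to the nearest representable FP4 magnitude.
    scaled = (blocks / scale).clamp(-FP4_E2M1_MAX, FP4_E2M1_MAX)
    grid = FP4_GRID.to(x.device, x.dtype)
    idx = (scaled.abs().unsqueeze(-1) - grid).abs().argmin(dim=-1)
    qdq = grid[idx] * scaled.sign() * scale
    return qdq.reshape(orig_shape)
```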
2 changes: 1 addition & 1 deletion vllm/model_executor/model_loader/utils.py
@@ -91,7 +91,7 @@ def get_model_architecture(
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported = [
"fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"
"fp8", "compressed-tensors", "gptq_marlin", "awq_marlin", "quark"
]

if (model_config.quantization is not None
7 changes: 7 additions & 0 deletions vllm/platforms/interface.py
@@ -339,6 +339,13 @@ def get_device_communicator_cls(cls) -> str:
"""
return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase" # noqa

@classmethod
def supports_mx(cls) -> bool:
"""
Returns whether the current platform supports MX types.
"""
return False

@classmethod
def supports_fp8(cls) -> bool:
"""
5 changes: 5 additions & 0 deletions vllm/platforms/rocm.py
@@ -314,6 +314,11 @@ def get_current_memory_usage(cls,
def get_device_communicator_cls(cls) -> str:
return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa

@classmethod
def supports_mx(cls) -> bool:
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
return any(gfx in gcn_arch for gfx in ["gfx95"])

@classmethod
def supports_fp8(cls) -> bool:
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName