[Quantization] Quark MXFP4 format loading #16943
Merged
Commits (6):
- 4318ff2  MXFP4 (fxmarty-amd)
- c439a13  Separate moe to another PR (BowenBao)
- edc4980  Fix VLLM_QUARK_EMU_MEM_OPT codepath (BowenBao)
- 2f72aa9  lint (BowenBao)
- e1a9b91  Relax device requirement due to emulation (BowenBao)
- 108a802  Update to comments (BowenBao)
New test file (40 additions):

# SPDX-License-Identifier: Apache-2.0
# flake8: noqa
"""Tests Quark mxfp4 models against ground truth generation
"""
import pytest

from vllm import LLM, SamplingParams

MODELS = ["amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8"]

EXPECTED_STRS_MAP = {
    "amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8": [
        '\n### Key Features\n\n* **High-throughput Inference**: vLL',
        '\nArtificial intelligence (AI) has evolved significantly since its inception in the 1',
        'Artificial intelligence (AI) and human intelligence (HI) are two distinct concepts that have been',
        'A neural network is a machine learning model inspired by the structure of the human brain. It consists of',
        '\nTitle: The Dreaming Robot\n\nAs the sun set on the bustling metropol',
        '\nThe COVID-19 pandemic has had a profound impact on global economic structures and business',
        'The Mona Lisa painting, created by Leonardo da Vinci in the early 16th',
        " everybody knows this proverbial saying, but did you know that it's not entirely accurate?",
    ]
}


@pytest.mark.skip(reason="Model to be released in the future")
@pytest.mark.quant_model
@pytest.mark.parametrize("model_name", MODELS)
def test_models(example_prompts, model_name) -> None:
    sampling_params = SamplingParams(max_tokens=20, temperature=0)
    llm = LLM(
        model=model_name,
        kv_cache_dtype="fp8",
        quantization="quark",
    )
    outputs = llm.generate(example_prompts, sampling_params)
    for i, output in enumerate(outputs):
        output_str = output.outputs[0].text
        expected_str = EXPECTED_STRS_MAP[model_name][i]
        assert expected_str == output_str, (
            f"Expected: {expected_str!r}\nvLLM: {output_str!r}")
vllm/model_executor/layers/quantization/quark/schemes/__init__.py (3 changes: 2 additions & 1 deletion):

@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0

 from .quark_scheme import QuarkScheme
+from .quark_w4a4_mxfp4 import QuarkW4A4MXFP4
 from .quark_w8a8_fp8 import QuarkW8A8Fp8
 from .quark_w8a8_int8 import QuarkW8A8Int8

-__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8"]
+__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkW4A4MXFP4"]
vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py (new file, 125 additions):

# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict, List, Optional

import torch
import torch.nn.functional as F

import vllm.envs as envs
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
    OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                           PackedvLLMParameter)
from vllm.platforms import current_platform

__all__ = ["QuarkW4A4MXFP4"]


class QuarkW4A4MXFP4(QuarkScheme):

    def __init__(self, weight_quant_spec: Dict[str, Any],
                 input_quant_spec: Dict[str, Any]):
        self.out_dtype = torch.get_default_dtype()
        self.qscheme = "per_group"
        self.weight_quant_spec = weight_quant_spec
        self.input_quant_spec = input_quant_spec
        self.emulate = not current_platform.supports_mx()

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.weight = torch.nn.Parameter(layer.weight.data,
                                          requires_grad=False)
        layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                requires_grad=False)

        if self.emulate:
            try:
                from quark.torch.export.nn.modules import realquantizer
                from quark.torch.quantization.config.config import (
                    QuantizationSpec)
            except ImportError as err:
                raise ImportError(
                    "The package `amd-quark` is required to use AMD Quark "
                    "MX-FP4 models. Please install it with `pip install "
                    "amd-quark`.") from err

            weight_quant_spec = QuantizationSpec.from_dict(
                self.weight_quant_spec)

            weight_quantizer = realquantizer.get_real_quantizer(
                qspec=weight_quant_spec,
                quantizer=None,
                real_quantized=True,
                reorder=False,
                float_dtype=self.out_dtype,
                scale_shape=layer.weight_scale.shape,
                zero_point_shape=None,
            )
            weight_quantizer.scale.data = layer.weight_scale.data

            if not envs.VLLM_QUARK_EMU_MEM_OPT:
                layer.weight = torch.nn.Parameter(
                    weight_quantizer(layer.weight.data).to(self.out_dtype),
                    requires_grad=False,
                )
            else:
                self.weight_quantizer = weight_quantizer
            layer.weight_scale = None

            # This call is necessary to release the scales memory.
            torch.cuda.empty_cache()

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes

        # WEIGHT
        weight = PackedvLLMParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            packed_dim=1,
            packed_factor=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        weight_scale = GroupQuantScaleParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // OCP_MX_BLOCK_SIZE,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        if self.emulate:
            if envs.VLLM_QUARK_EMU_MEM_OPT:
                dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype)
            else:
                dq_w = layer.weight
            qdq_x, _ = per_token_group_quant_mxfp4(x, OCP_MX_BLOCK_SIZE)
            return F.linear(qdq_x, dq_w, bias)
        else:
            raise NotImplementedError()
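As a quick illustration of the layout that `create_weights` above registers, here is a small sketch that is not part of the PR: the 4096x4096 partition sizes are hypothetical, and the `weight`/`weight_scale` names simply mirror the parameters in the diff. Each uint8 byte of the packed weight holds two FP4 values, and each uint8 scale (an E8M0 exponent in the OCP MX format) covers one block of 32 input elements.

```python
import torch

OCP_MX_BLOCK_SIZE = 32  # mirrors mxfp4_utils.OCP_MX_BLOCK_SIZE

# Hypothetical partition sizes for a 4096x4096 linear layer.
output_size_per_partition = 4096
input_size_per_partition = 4096

# Packed weight: two FP4 values per uint8 byte (packed_factor=2 on dim 1).
weight = torch.empty(output_size_per_partition,
                     input_size_per_partition // 2,
                     dtype=torch.uint8)

# One shared uint8 scale per OCP MX block of 32 input elements.
weight_scale = torch.empty(output_size_per_partition,
                           input_size_per_partition // OCP_MX_BLOCK_SIZE,
                           dtype=torch.uint8)

print(weight.shape)        # torch.Size([4096, 2048])
print(weight_scale.shape)  # torch.Size([4096, 128])
```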
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py (new file, 45 additions):

# SPDX-License-Identifier: Apache-2.0
from typing import Tuple

import torch

OCP_MX_BLOCK_SIZE = 32


def per_token_group_quant_mxfp4(x: torch.Tensor,
                                block_k: int,
                                scale_calculation_mode: str = "even"
                                ) -> Tuple[torch.Tensor, torch.Tensor]:
    try:
        from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
            fake_quantize_fp4_fp6_per_group_with_scale)
        from quark.torch.quantization.utils import (even_round,
                                                    reshape_to_blocks)
    except ImportError as err:
        raise ImportError("The package `amd-quark` is required to use "
                          "MX-FP4 models. Please install it with `pip install "
                          "amd-quark`.") from err

    axis = -1
    block_x = reshape_to_blocks(x, block_k, axis)
    amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True)
    amax = amax.squeeze(-1)

    # TODO: there are other rounding strategies supported in quark and in the
    # config.json that we do not check for here!
    if scale_calculation_mode != "even":
        raise NotImplementedError(
            f"Scale calculation mode {scale_calculation_mode} is not yet "
            "supported in MX-FP4 quantization")
    scale = even_round(amax, "fp4")

    # Apply dequantize(quantize(x)).
    x = fake_quantize_fp4_fp6_per_group_with_scale(
        x,
        scale.to(x.device),
        axis=axis,
        group_size=block_k,
        quant_dtype="fp4",
    )

    return x, scale
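To make the grouping step concrete without depending on amd-quark, below is a rough, dependency-free sketch of the per-block amax computation that `reshape_to_blocks` and `even_round` perform inside the helper above. The `blockwise_mx_scale` name is hypothetical, the power-of-two scale shown is a simplification of Quark's even rounding, and the sketch assumes the last dimension is divisible by the block size.

```python
import torch

OCP_MX_BLOCK_SIZE = 32


def blockwise_mx_scale(x: torch.Tensor, block_k: int = OCP_MX_BLOCK_SIZE):
    """Rough stand-in for reshape_to_blocks + even_round(amax, "fp4")."""
    # Group the last dimension into blocks of block_k elements.
    blocks = x.reshape(*x.shape[:-1], x.shape[-1] // block_k, block_k)
    amax = blocks.abs().amax(dim=-1)
    # One power-of-two scale per block, chosen so the block's amax lands in
    # FP4 (E2M1) range, whose largest exponent is 2. Quark's even_round may
    # round the scale differently; this is only a sketch.
    shared_exp = torch.floor(
        torch.log2(amax.clamp(min=torch.finfo(x.dtype).tiny))) - 2
    return amax, torch.exp2(shared_exp)


x = torch.randn(4, 64)
amax, scale = blockwise_mx_scale(x)
print(amax.shape, scale.shape)  # both torch.Size([4, 2])
```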
Review discussion:
Comment: Can we remove the env var and always do the weight decompression at runtime? That is the expected behavior of other quantization methods, so I feel it is strange not to do compression here.
Reply: We found it more efficient for emulation evaluations to do ahead-of-time weight dequantization. That said, this option can be removed once more efficient dequant kernels are supported. I would prefer to keep it for now, but let me know if you feel strongly about it.
Reply: Okay, we can keep it for now, but let us hope to remove it over time. We want to keep the list of environment variables from ever-growing unless there is a good reason.
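For context, a minimal sketch of how the memory-optimized emulation path discussed above might be enabled from user code. The VLLM_QUARK_EMU_MEM_OPT name comes from the diff, but setting it to "1" before importing vLLM is an assumption about how the flag is parsed, and the model id simply mirrors the one used in the test.

```python
import os

# Assumed toggle: keep the packed MXFP4 weights in memory and dequantize them
# on the fly in apply_weights, instead of dequantizing once after loading.
os.environ["VLLM_QUARK_EMU_MEM_OPT"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(
    model="amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8",
    quantization="quark",
    kv_cache_dtype="fp8",
)
out = llm.generate(["What is MXFP4?"],
                   SamplingParams(max_tokens=16, temperature=0))
print(out[0].outputs[0].text)
```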