
Add cuda kernel support for GGUF inference #11869

Draft · wants to merge 3 commits into base: main
89 changes: 87 additions & 2 deletions src/diffusers/quantizers/gguf/utils.py
@@ -20,7 +20,7 @@
import torch
import torch.nn as nn

from ...utils import is_accelerate_available
from ...utils import is_accelerate_available, is_kernels_available


if is_accelerate_available():
@@ -29,6 +29,78 @@
from accelerate.hooks import add_hook_to_module, remove_hook_from_module


can_use_cuda_kernels = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7
if can_use_cuda_kernels and is_kernels_available():
from kernels import get_kernel

ops = get_kernel("Isotr0py/ggml")
else:
ops = None

UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16}
STANDARD_QUANT_TYPES = {
gguf.GGMLQuantizationType.Q4_0,
gguf.GGMLQuantizationType.Q4_1,
gguf.GGMLQuantizationType.Q5_0,
gguf.GGMLQuantizationType.Q5_1,
gguf.GGMLQuantizationType.Q8_0,
gguf.GGMLQuantizationType.Q8_1,
}
KQUANT_TYPES = {
gguf.GGMLQuantizationType.Q2_K,
gguf.GGMLQuantizationType.Q3_K,
gguf.GGMLQuantizationType.Q4_K,
gguf.GGMLQuantizationType.Q5_K,
gguf.GGMLQuantizationType.Q6_K,
}
IMATRIX_QUANT_TYPES = {
gguf.GGMLQuantizationType.IQ1_M,
gguf.GGMLQuantizationType.IQ1_S,
gguf.GGMLQuantizationType.IQ2_XXS,
gguf.GGMLQuantizationType.IQ2_XS,
gguf.GGMLQuantizationType.IQ2_S,
gguf.GGMLQuantizationType.IQ3_XXS,
gguf.GGMLQuantizationType.IQ3_S,
gguf.GGMLQuantizationType.IQ4_XS,
gguf.GGMLQuantizationType.IQ4_NL,
}
# TODO(Isotr0py): Currently, we don't have an MMQ kernel for I-Matrix quantization.
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add an
# MMQ kernel for I-Matrix quantization.
DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
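
Taken together, these sets decide which path a given quantization type can take in `_fused_mul_mat_gguf` below. A small illustrative check follows; it is a sketch that assumes the set definitions above are in scope and that the `gguf` package is installed:

import gguf

# Illustrative only: report which paths are possible for a few types, using the
# sets defined above. I-matrix types such as IQ2_XS are covered by the
# dequantize fallback and the (currently disabled) MMVQ path, but not by MMQ.
for qtype in (
    gguf.GGMLQuantizationType.Q8_0,
    gguf.GGMLQuantizationType.Q4_K,
    gguf.GGMLQuantizationType.IQ2_XS,
):
    paths = [
        name
        for name, group in (
            ("dequant", DEQUANT_TYPES),
            ("mmvq", MMVQ_QUANT_TYPES),
            ("mmq", MMQ_QUANT_TYPES),
        )
        if qtype in group
    ]
    print(qtype.name, "->", ", ".join(paths))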


def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
# there is no need to call any kernel for fp16/bf16
if qweight_type in UNQUANTIZED_TYPES:
return x @ qweight.T

# TODO(Isotr0py): GGUF's MMQ and MMVQ implementations are designed for
# contiguous batching and are inefficient with diffusers' batching,
# so they are disabled for now.

# elif qweight_type in MMVQ_QUANT_TYPES:
# y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
# elif qweight_type in MMQ_QUANT_TYPES:
# y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])

# If no MMQ kernel is available, fall back to dequantization.
elif qweight_type in DEQUANT_TYPES:
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
y = x @ weight.T
else:
# Raise an error if the quantization type is not supported.
# Might be useful if llama.cpp adds a new quantization type.
# Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
qweight_type = gguf.GGMLQuantizationType(qweight_type)
raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
return y
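
For reference, the dequantize fallback above recovers the logical weight shape from the packed tensor before the matmul. A minimal sketch of that arithmetic, with illustrative tensor sizes and the block/type sizes read from gguf's GGML_QUANT_SIZES table:

import gguf

# Sketch of the shape recovery used in the dequantize fallback. GGUF packs each
# block of `block_size` logical elements into `type_size` bytes, so the packed
# second dimension is in_features // block_size * type_size, and the dequantized
# shape is recovered by inverting that. Sizes here are illustrative.
qtype = gguf.GGMLQuantizationType.Q4_K
block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]

out_features, in_features = 3072, 4096
packed_cols = in_features // block_size * type_size
recovered = (out_features, packed_cols // type_size * block_size)
assert recovered == (out_features, in_features)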


# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
def _create_accelerate_new_hook(old_hook):
r"""
@@ -451,11 +523,24 @@ def __init__(
) -> None:
super().__init__(in_features, out_features, bias, device)
self.compute_dtype = compute_dtype
self.device = device

def forward(self, inputs: torch.Tensor):
if ops is not None and self.weight.is_cuda and inputs.is_cuda:
return self.forward_cuda(inputs)
return self.forward_native(inputs)
Comment on lines +528 to +531

Member:
This should be fairly safe as long as we get the same values (up to some tolerance) on both the native and kernels' variants.

Contributor Author:
Yes, the kernels' variants have been tested in vLLM's kernel tests CI against a NumPy-implemented dequantize reference. (https://github.com/vllm-project/vllm/blob/110df74332785ee749af47c5a3eb634d216b8f3b/tests/kernels/quantization/test_gguf.py#L67-L83)

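Along the lines of the discussion above, a minimal parity check between the two paths could look like the sketch below; `layer` is assumed to be a CUDA-resident GGUF-quantized linear layer exposing `forward_native`/`forward_cuda` as added in this PR, and the tolerances are illustrative:

import torch

def check_gguf_linear_parity(layer, in_features, batch=2, atol=1e-2, rtol=1e-2):
    # Compare the dequantize-based native path against the fused CUDA kernel path.
    x = torch.randn(batch, in_features, device="cuda", dtype=layer.compute_dtype)
    y_native = layer.forward_native(x)
    y_cuda = layer.forward_cuda(x)
    torch.testing.assert_close(y_cuda, y_native, atol=atol, rtol=rtol)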

def forward(self, inputs):
def forward_native(self, inputs: torch.Tensor):
weight = dequantize_gguf_tensor(self.weight)
weight = weight.to(self.compute_dtype)
bias = self.bias.to(self.compute_dtype) if self.bias is not None else None

output = torch.nn.functional.linear(inputs, weight, bias)
return output

def forward_cuda(self, inputs: torch.Tensor):
quant_type = self.weight.quant_type
output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
if self.bias is not None:
output += self.bias.to(self.compute_dtype)
return output
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
@@ -78,6 +78,7 @@
is_invisible_watermark_available,
is_k_diffusion_available,
is_k_diffusion_version,
is_kernels_available,
is_librosa_available,
is_matplotlib_available,
is_nltk_available,
5 changes: 5 additions & 0 deletions src/diffusers/utils/import_utils.py
@@ -192,6 +192,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
_transformers_available, _transformers_version = _is_package_available("transformers")
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
_kernels_available, _kernels_version = _is_package_available("kernels")
_inflect_available, _inflect_version = _is_package_available("inflect")
_unidecode_available, _unidecode_version = _is_package_available("unidecode")
_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
@@ -274,6 +275,10 @@ def is_accelerate_available():
return _accelerate_available


def is_kernels_available():
return _kernels_available


def is_k_diffusion_available():
return _k_diffusion_available

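For downstream code, the new `is_kernels_available` helper can be combined with the same CUDA capability gate used in `gguf/utils.py`. A sketch of that guarded import (the kernel repository name is the one used in this PR):

import torch

from diffusers.utils import is_kernels_available

ops = None
# Only load the Hub kernel when the optional `kernels` package is installed
# and a CUDA GPU with compute capability >= 7.0 is present.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7 and is_kernels_available():
    from kernels import get_kernel

    ops = get_kernel("Isotr0py/ggml")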