diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py
index 41d351712961..31f6ec3e7321 100644
--- a/src/diffusers/quantizers/gguf/utils.py
+++ b/src/diffusers/quantizers/gguf/utils.py
@@ -20,7 +20,7 @@
 import torch
 import torch.nn as nn
 
-from ...utils import is_accelerate_available
+from ...utils import is_accelerate_available, is_kernels_available
 
 
 if is_accelerate_available():
@@ -29,6 +29,78 @@
     from accelerate.hooks import add_hook_to_module, remove_hook_from_module
 
 
+can_use_cuda_kernels = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7
+if can_use_cuda_kernels and is_kernels_available():
+    from kernels import get_kernel
+
+    ops = get_kernel("Isotr0py/ggml")
+else:
+    ops = None
+
+UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16}
+STANDARD_QUANT_TYPES = {
+    gguf.GGMLQuantizationType.Q4_0,
+    gguf.GGMLQuantizationType.Q4_1,
+    gguf.GGMLQuantizationType.Q5_0,
+    gguf.GGMLQuantizationType.Q5_1,
+    gguf.GGMLQuantizationType.Q8_0,
+    gguf.GGMLQuantizationType.Q8_1,
+}
+KQUANT_TYPES = {
+    gguf.GGMLQuantizationType.Q2_K,
+    gguf.GGMLQuantizationType.Q3_K,
+    gguf.GGMLQuantizationType.Q4_K,
+    gguf.GGMLQuantizationType.Q5_K,
+    gguf.GGMLQuantizationType.Q6_K,
+}
+IMATRIX_QUANT_TYPES = {
+    gguf.GGMLQuantizationType.IQ1_M,
+    gguf.GGMLQuantizationType.IQ1_S,
+    gguf.GGMLQuantizationType.IQ2_XXS,
+    gguf.GGMLQuantizationType.IQ2_XS,
+    gguf.GGMLQuantizationType.IQ2_S,
+    gguf.GGMLQuantizationType.IQ3_XXS,
+    gguf.GGMLQuantizationType.IQ3_S,
+    gguf.GGMLQuantizationType.IQ4_XS,
+    gguf.GGMLQuantizationType.IQ4_NL,
+}
+# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
+# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
+# MMQ kernel for I-Matrix quantization.
+DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
+
+
+def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
+    # there is no need to call any kernel for fp16/bf16
+    if qweight_type in UNQUANTIZED_TYPES:
+        return x @ qweight.T
+
+    # TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
+    # contiguous batching and inefficient with diffusers' batching,
+    # so we disabled it now.
+
+    # elif qweight_type in MMVQ_QUANT_TYPES:
+    #     y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
+    # elif qweight_type in MMQ_QUANT_TYPES:
+    #     y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
+
+    # If there is no available MMQ kernel, fallback to dequantize
+    elif qweight_type in DEQUANT_TYPES:
+        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
+        shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
+        weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
+        y = x @ weight.T
+    else:
+        # Raise an error if the quantization type is not supported.
+        # Might be useful if llama.cpp adds a new quantization type.
+        # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
+        qweight_type = gguf.GGMLQuantizationType(qweight_type)
+        raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
+    return y
+
+
 # Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
 def _create_accelerate_new_hook(old_hook):
     r"""
@@ -451,11 +523,24 @@ def __init__(
     ) -> None:
         super().__init__(in_features, out_features, bias, device)
         self.compute_dtype = compute_dtype
+        self.device = device
+
+    def forward(self, inputs: torch.Tensor):
+        if ops is not None and self.weight.is_cuda and inputs.is_cuda:
+            return self.forward_cuda(inputs)
+        return self.forward_native(inputs)
 
-    def forward(self, inputs):
+    def forward_native(self, inputs: torch.Tensor):
         weight = dequantize_gguf_tensor(self.weight)
         weight = weight.to(self.compute_dtype)
         bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
 
         output = torch.nn.functional.linear(inputs, weight, bias)
         return output
+
+    def forward_cuda(self, inputs: torch.Tensor):
+        quant_type = self.weight.quant_type
+        output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
+        if self.bias is not None:
+            output += self.bias.to(self.compute_dtype)
+        return output
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 2df05cb8eb36..72b12badf269 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -78,6 +78,7 @@
     is_invisible_watermark_available,
     is_k_diffusion_available,
     is_k_diffusion_version,
+    is_kernels_available,
     is_librosa_available,
     is_matplotlib_available,
     is_nltk_available,
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index f12e9de33172..6174d5b72c32 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -192,6 +192,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
 _torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
 _transformers_available, _transformers_version = _is_package_available("transformers")
 _hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
+_kernels_available, _kernels_version = _is_package_available("kernels")
 _inflect_available, _inflect_version = _is_package_available("inflect")
 _unidecode_available, _unidecode_version = _is_package_available("unidecode")
 _k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
@@ -274,6 +275,10 @@ def is_accelerate_available():
     return _accelerate_available
 
 
+def is_kernels_available():
+    return _kernels_available
+
+
 def is_k_diffusion_available():
     return _k_diffusion_available
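Usage sketch, not part of the diff: a minimal way to exercise the new path end to end, assuming the existing `GGUFQuantizationConfig` / `from_single_file` loading API and that the layer patched above is `GGUFLinear` in `diffusers.quantizers.gguf.utils`; the checkpoint URL is only a placeholder. With the `kernels` package installed and a CUDA GPU of compute capability >= 7, `ops` resolves to the `Isotr0py/ggml` hub kernel and `forward()` dispatches to `forward_cuda()`; otherwise the previous dequantize path (`forward_native()`) runs unchanged.

import torch

from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig
from diffusers.quantizers.gguf.utils import ops

# `ops` is non-None only when `kernels` is installed and a CUDA GPU with
# compute capability >= 7 is visible; otherwise the native path is used.
print("CUDA GGUF kernels active:", ops is not None)

# Placeholder checkpoint: any GGUF single-file checkpoint supported by diffusers works here.
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
).to("cuda")
# On CUDA, each quantized linear layer now routes through forward_cuda(), which
# calls _fused_mul_mat_gguf() and dequantizes via the "Isotr0py/ggml" hub kernel.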