diff --git a/examples/models/llama/README.md b/examples/models/llama/README.md
index ed0a37b4f86..34f8749b40f 100644
--- a/examples/models/llama/README.md
+++ b/examples/models/llama/README.md
@@ -418,7 +418,7 @@ python -m examples.models.llama.export_llama \
 ```

 A few notes:
-- If your model shares embedding/unembedding weights (like Llama1B and Llama3B do), you can add `--use_shared_embedding` to take advantage of this and reduce memory. When this option is enabled, you can specify whether embeddings are quantized asymmetrically or not by specifying a third argument. For example, `-E "torchao:4,32,true"` means that the embedding is quantized to 4-bits with group_size=32 and is asymmetric (this is the default behavior if you simply use `-E "torchao:4,32"`), whereas `-E "torchao:4,32,false"` means that the embedding is quantized to 4-bits with group_size=32 and is symmetric. If `--use_shared_embedding` is specified, the unembedding (i.e., the final linear layer) is quantized in the same way, but also uses 8-bit dynamically quantized activations.
+- If your model shares embedding/unembedding weights (like Llama1B and Llama3B do), you can add `--use_shared_embedding` to take advantage of this and reduce memory. When this option is enabled, you can control whether embeddings are quantized symmetrically or asymmetrically with an optional third argument. For example, `-E "torchao:4,32,true"` means that the embedding is quantized to 4 bits with group_size=32 and is symmetric (this is the default behavior if you simply use `-E "torchao:4,32"`), whereas `-E "torchao:4,32,false"` means that the embedding is quantized to 4 bits with group_size=32 and is asymmetric. If `--use_shared_embedding` is specified, the unembedding (i.e., the final linear layer) is quantized in the same way, but also uses 8-bit dynamically quantized activations.
 - To do channelwise quantization, specify group_size to 0. This works for both linear and embedding layers.

 Once the model is exported, we need to build ExecuTorch and the runner with the low-bit kernels.
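As a quick illustration of the `-E` convention described in the README note above, the following sketch (not part of this patch; the helper name `parse_embedding_quantize_arg` is made up for illustration) shows how a `torchao:bitwidth,group_size[,is_symmetric]` value decomposes into torchao settings, using the same imports this patch adds to `quantize.py`:

```python
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import MappingType


def parse_embedding_quantize_arg(spec: str) -> dict:
    """Illustrative only: split an -E/--embedding_quantize value into torchao settings."""
    fields = spec.split(":")[1].split(",") if spec.startswith("torchao:") else spec.split(",")
    bitwidth = int(fields[0])
    group_size = 0 if fields[1] in ("none", "None", "0") else int(fields[1])
    # Optional third field: "true" (the default) selects symmetric, "false" asymmetric.
    is_symmetric = fields[2].lower() == "true" if len(fields) > 2 else True
    return {
        "bitwidth": bitwidth,
        # group_size == 0 means channelwise quantization, per the README note above.
        "granularity": PerAxis(0) if group_size == 0 else PerGroup(group_size),
        "mapping_type": MappingType.SYMMETRIC if is_symmetric else MappingType.ASYMMETRIC,
    }


# "torchao:4,32" and "torchao:4,32,true" are equivalent: 4-bit weights, group size 32, symmetric.
print(parse_embedding_quantize_arg("torchao:4,32"))
print(parse_embedding_quantize_arg("torchao:4,32,false"))  # asymmetric variant
```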
diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py
index ec02f442217..1ed6e0d280b 100644
--- a/examples/models/llama/source_transformation/quantize.py
+++ b/examples/models/llama/source_transformation/quantize.py
@@ -18,6 +18,15 @@

 from sentencepiece import SentencePieceProcessor

+from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
+from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.quant_api import (
+    Int8DynamicActivationIntxWeightConfig,
+    IntxWeightOnlyConfig,
+    MappingType,
+    quantize_,
+)
+

 try:
     from fairseq2.nn.embedding import (
@@ -118,15 +127,6 @@ def quantize(  # noqa C901
         assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
         bitwidth = int(matches[0][0])

-        from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
-        from torchao.quantization.granularity import PerAxis, PerGroup
-        from torchao.quantization.quant_api import (
-            Int8DynamicActivationIntxWeightConfig,
-            MappingType,
-            quantize_,
-        )
-        from torchao.utils import unwrap_tensor_subclass
-
         with torch.no_grad():
             # Computation dtype is fixed to fp32 in the implementation of quantize_, so
             # no way to decouple checkpoint and computation dtype.
@@ -141,7 +141,6 @@ def quantize(  # noqa C901
                     layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
                 ),
             )
-            model = unwrap_tensor_subclass(model)
         if verbose:
             print("quantized model:", model)
         return model
@@ -150,14 +149,17 @@ def quantize(  # noqa C901
         if group_size is None:
             raise Exception("For 8da4w quantization, group size must be specified.")

-        from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
-        from torchao.utils import unwrap_tensor_subclass
-
-        quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
-        model = unwrap_tensor_subclass(model)
-
+        quantize_(
+            model,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4,
+                weight_granularity=(
+                    PerAxis(0) if group_size == 0 else PerGroup(group_size)
+                ),
+                weight_mapping_type=MappingType.SYMMETRIC,
+            ),
+        )
         # TODO: deal with checkpoint / computation dtype decoupling.
-
         if verbose:
             print("quantized model:", model)
         return model
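One design note on the `8da4w` hunk above: `quantize_` applies the config to matching modules in place, which is why the separate `int8_dynamic_activation_int4_weight` helper and the `unwrap_tensor_subclass` call can be dropped. The following standalone sketch (not part of the patch; the toy model and the group size of 32 are placeholders, and it assumes a torchao/PyTorch install that exposes these config names and `torch.int4`, as the patched file itself does) shows that path in isolation:

```python
import torch
import torch.nn as nn

from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    quantize_,
)

group_size = 32  # 0 would select channelwise (PerAxis) granularity instead

# Toy stand-in for the exported model; quantize_ mutates matching nn.Linear
# modules in place, so no separate unwrap step is needed afterwards.
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerAxis(0) if group_size == 0 else PerGroup(group_size),
        weight_mapping_type=MappingType.SYMMETRIC,
    ),
)
print(model)
```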
@@ -563,254 +565,32 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         )

-#########################################################################
-##### embedding table quantization ######
-
-
-def replace_embedding_weight_only_grouped_int8_per_channel(
-    module, device, bitwidth: int = 8, group_size: Optional[int] = None, packed=False
-):
-    for name, child in module.named_children():
-        # print(f"name: {name}")
-        if isinstance(child, nn.Embedding):
-            # print(f"{name, child}")
-            # print(f"weights size: {child.weight.size()}")
-            setattr(
-                module,
-                name,
-                QuantizedGroupEmbedding(
-                    device=device,
-                    vocab_size=child.weight.shape[0],
-                    embedding_dim=child.weight.shape[1],
-                    group_size=group_size,
-                    dtype=child.weight.dtype,
-                    packed=packed,
-                    bitwidth=bitwidth,
-                ),
-            )
-        else:
-            replace_embedding_weight_only_grouped_int8_per_channel(
-                child, device, bitwidth, group_size, packed
-            )
-
-
-class EmbeddingQuantHandler(QuantHandler):
-    def __init__(
-        self,
-        mod,
-        device="cpu",
-        *,
-        bitwidth: int = 8,
-        group_size: Optional[int] = None,
-        packed=False,
-        precision: Optional[torch.dtype] = None,
-    ):
-        if isinstance(packed, str):
-            packed = packed == "True"
-        self.mod = mod
-        self.device = device
-        self.group_size = group_size
-        self.bitwidth = bitwidth
-        self.packed = packed
-        # Dtype of the weights right before quantization.
-        self.precision = precision
-        if (bitwidth not in [2, 4]) and packed:
-            raise RuntimeError("pack only works with bitsize 2, 4")
-
-    @torch.no_grad()
-    def create_quantized_state_dict(self, packed=False) -> Dict:
-        cur_state_dict = self.mod.state_dict()
-
-        if self.bitwidth == 2:
-            range_min = -2
-            range_max = 1
-        elif self.bitwidth == 4:
-            range_min = -8
-            range_max = 7
-        elif self.bitwidth == 8:
-            range_min = -128
-            range_max = 127
-        else:
-            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
-
-        for fqn, mod in self.mod.named_modules():
-            if isinstance(mod, nn.Embedding):
-                # print("****")
-                # print(f"Embedding identified: {fqn, mod}")
-                # print(f"weights size: {mod.weight.size()}")
-                # print(f"quantize {fqn}...")
-
-                print(
-                    f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
-                )
-                weight, scales, _ = dynamically_quantize_per_channel(
-                    (
-                        mod.weight.to(dtype=self.precision)
-                        if self.precision
-                        else mod.weight
-                    ),
-                    range_min,
-                    range_max,
-                    torch.int8,
-                    self.group_size,
-                    scales_dtype=mod.weight.dtype,
-                )
-
-                if packed:
-                    if self.bitwidth == 2:
-                        if weight.shape[-1] % 4 != 0:
-                            raise RuntimeError("automatic padding not implemented yet")
-                        weight_range_shifted = weight.add(2).view(torch.uint8)
-                        weight_view = weight_range_shifted.view(
-                            weight.shape[0], weight.shape[1] // 4, 4
-                        )
-                        weight_0 = weight_view[:, :, 0]
-                        weight_1 = weight_view[:, :, 1] << 2
-                        weight_2 = weight_view[:, :, 2] << 4
-                        weight_3 = weight_view[:, :, 3] << 6
-                        weight_packed = weight_0 + weight_1 + weight_2 + weight_3
-                        weight = weight_packed
-                    elif self.bitwidth == 4:
-                        if weight.shape[-1] % 2 != 0:
-                            raise RuntimeError("automatic padding not implemented yet")
-                        weight_range_shifted = weight.add(8).view(torch.uint8)
-                        weight_view = weight_range_shifted.view(
-                            weight.shape[0], weight.shape[1] // 2, 2
-                        )
-                        weight_even = weight_view[:, :, 0] * 16  # left shift 4
-                        weight_odd = weight_view[:, :, 1]
-                        weight_packed = weight_even + weight_odd
-                        weight = weight_packed
-
-                weight = weight.to(device=self.device)
-                scales = scales.to(device=self.device)
-                # Update state dict
-                cur_state_dict[f"{fqn}.weight"] = weight
-                # squeeze makes group_size=rowsize unidimensional
-                cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)
-
-        return cur_state_dict
-
-    def convert_for_runtime(self) -> nn.Module:
-        replace_embedding_weight_only_grouped_int8_per_channel(
-            self.mod, self.device, self.bitwidth, self.group_size, self.packed
-        )
-        return self.mod
-
-    def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict(self.packed)
-        self.convert_for_runtime()
-        self.mod.load_state_dict(model_updated_state_dict, assign=True)
-        return self.mod
-
-
-class QuantizedGroupEmbedding(torch.nn.Module):
-    def __init__(
-        self,
-        device,
-        vocab_size: int,
-        embedding_dim: int,
-        group_size: Optional[int] = None,
-        dtype=torch.half,
-        packed=False,
-        bitwidth: int = 8,
-    ) -> None:
-        super().__init__()
-        if group_size is None or group_size == 0:
-            group_size = embedding_dim
-        self.group_size = group_size
-        self.dtype = dtype
-        self.packed = packed
-        self.bitwidth = bitwidth
-        if not packed:
-            self.register_buffer(
-                "weight",
-                torch.zeros(
-                    (vocab_size, embedding_dim), dtype=torch.int8, device=device
-                ),
-            )
-        else:  # packed
-            if bitwidth == 2:
-                self.register_buffer(
-                    "weight",
-                    torch.zeros(
-                        (vocab_size, embedding_dim // 4),
-                        dtype=torch.uint8,
-                        device=device,
-                    ),
-                )
-            elif bitwidth == 4:
-                self.register_buffer(
-                    "weight",
-                    torch.zeros(
-                        (vocab_size, embedding_dim // 2),
-                        dtype=torch.uint8,
-                        device=device,
-                    ),
-                )
-
-        groups_per_row = (embedding_dim + group_size - 1) // group_size
-        if groups_per_row > 1:
-            self.register_buffer(
-                "scales",
-                torch.ones(
-                    (vocab_size, groups_per_row), dtype=torch.float16, device=device
-                ),
-            )
-        else:
-            self.register_buffer(
-                "scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
-            )
-
-    @torch.no_grad()
-    def forward(self, indices: torch.Tensor) -> torch.Tensor:
-        if not self.packed:  # 8bit
-            return torch.ops.quantized_decomposed.embedding_byte.dtype(
-                self.weight, self.scales, None, -128, 127, indices, dtype=self.dtype
-            )
-        else:  # packed
-            if self.bitwidth == 2:
-                return torch.ops.quantized_decomposed.embedding_2bit.dtype(
-                    self.weight, self.scales, None, -2, 1, indices, dtype=self.dtype
-                )
+############################ Source Transform Start #######################

-            # Remaining case (always return to make pyre happy)
-            assert self.bitwidth == 4
-            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
-                self.weight, self.scales, None, -8, 7, indices, dtype=self.dtype
-            )
+def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
+    use_torchao = args.embedding_quantize.startswith("torchao:")
+    if use_torchao:
+        quant_args = args.embedding_quantize.split(":")[1].split(",")
+    else:
+        quant_args = args.embedding_quantize.split(",")

-############################ Source Transform Start #######################
+    bitwidth = int(quant_args[0])
+    group_size = quant_args[1]
+    if group_size in ["none", "None", "0"]:
+        group_size = 0
+    group_size = int(group_size)
+    is_symmetric = quant_args[2].lower() == "true" if len(quant_args) > 2 else True
+    weight_dtype = getattr(torch, f"int{bitwidth}")
+    granularity = PerAxis(0) if group_size == 0 else PerGroup(group_size)
+    mapping_type = MappingType.SYMMETRIC if is_symmetric else MappingType.ASYMMETRIC

-def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
-    if args.embedding_quantize.startswith("torchao:"):
+    if use_torchao:
         from torchao.experimental.quant_api import (
             EmbeddingQuantizer,
             SharedEmbeddingQuantizer,
         )
-        from torchao.quantization.granularity import PerAxis, PerGroup
-        from torchao.quantization.quant_api import MappingType
-
-        quant_args = args.embedding_quantize.split(":")[1].split(",")
-        if len(quant_args) == 2:
-            bitwidth, group_size = quant_args
-            is_asymmetric = True
-        else:
-            bitwidth, group_size, is_asymmetric = quant_args
-
-        if group_size in ["none", "None", "0"]:
-            group_size = 0
-
-        group_size = int(group_size)
-        bitwidth = int(bitwidth)
-        is_asymmetric = bool(is_asymmetric)
-        weight_dtype = getattr(torch, f"int{bitwidth}")
-        granularity = PerAxis(0) if group_size == 0 else PerGroup(group_size)
-        mapping_type = (
-            MappingType.ASYMMETRIC if is_asymmetric else MappingType.SYMMETRIC
-        )

         def _torchao_embedding_quantizer(model):
             with torch.no_grad():
@@ -831,20 +611,24 @@ def _torchao_embedding_quantizer(model):

         return _torchao_embedding_quantizer

-    bitwidth, group_size = args.embedding_quantize.split(",")
-    if group_size == "none" or group_size == "None" or group_size == "0":
-        group_size = None
-    else:
-        group_size = int(group_size)
-    bitwidth = int(bitwidth)
-    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
-    return lambda model: EmbeddingQuantHandler(
-        model,
-        bitwidth=bitwidth,
-        group_size=group_size,
-        packed=(bitwidth in [2, 4]),
-        precision=torch_dtype,
-    ).quantized_model()
+    def _quantize_embedding(model):
+        assert weight_dtype in [
+            torch.int2,
+            torch.int4,
+            torch.int8,
+        ], "Only 2, 4, or 8-bit embeddings are supported unless using torchao"
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=weight_dtype,
+                granularity=granularity,
+                mapping_type=mapping_type,
+            ),
+            lambda m, fqn: isinstance(m, nn.Embedding),
+        )
+        return model
+
+    return _quantize_embedding


 def get_quant_weight_transform(
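Finally, a usage sketch of the new embedding transform end to end. This is not part of the patch: `SimpleNamespace` stands in for the export script's argparse namespace, the toy model is a placeholder, and it assumes the script is run from the ExecuTorch repo root with a torchao build that provides `IntxWeightOnlyConfig` and the sub-byte `torch.int4` dtype. Only `nn.Embedding` modules are quantized, matching the filter passed to `quantize_` in `_quantize_embedding`:

```python
from types import SimpleNamespace

import torch.nn as nn

from examples.models.llama.source_transformation.quantize import (
    get_quant_embedding_transform,
)

# Toy stand-in for the Llama module: an embedding table plus a linear layer.
model = nn.Sequential(nn.Embedding(1000, 64), nn.Linear(64, 64))

# "4,32" = 4-bit embedding weights with group_size=32 (the non-torchao path);
# "torchao:4,32,true" would exercise the torchao EmbeddingQuantizer path instead.
args = SimpleNamespace(embedding_quantize="4,32")
model = get_quant_embedding_transform(args)(model)

# The nn.Embedding weight is now 4-bit grouped-quantized; the nn.Linear is untouched.
print(model)
```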