diff --git a/auto_round/__main__.py b/auto_round/__main__.py index ab4da0b68..264524b75 100644 --- a/auto_round/__main__.py +++ b/auto_round/__main__.py @@ -756,7 +756,11 @@ def tune(args): suffix = f"a{autoround.act_bits}" else: suffix = f"g{autoround.group_size}" - prefix = autoround.data_type.lower().replace("_", "") if "int" not in autoround.data_type else "" + prefix = ( + autoround.data_type.lower().replace("_", "") + if "int" not in autoround.data_type or "mx" in autoround.data_type + else "" + ) export_dir = os.path.join( args.output_dir, model_name.split("/")[-1] + (f"-{prefix}" if prefix else "") + f"-w{autoround.bits}{suffix}", diff --git a/auto_round/compressors/utils.py b/auto_round/compressors/utils.py index 58f64f683..0b9bb1599 100644 --- a/auto_round/compressors/utils.py +++ b/auto_round/compressors/utils.py @@ -35,6 +35,7 @@ class BackendDataType(str, Enum): STANDARD_FP = "fp" MX_FP = "mx_fp" NV_FP = "nv_fp" + MX_INT = "mx_int" def is_standard_fp(backend): @@ -47,6 +48,11 @@ def is_mx_fp(backend): return BackendDataType.MX_FP in backend +def is_mx_int(backend): + backend = backend.lower() + return BackendDataType.MX_INT in backend + + def is_nv_fp(backend): backend = backend.lower() return BackendDataType.NV_FP in backend diff --git a/auto_round/experimental/qmodules/__init__.py b/auto_round/experimental/qmodules/__init__.py index 3862e0293..df20d4afa 100644 --- a/auto_round/experimental/qmodules/__init__.py +++ b/auto_round/experimental/qmodules/__init__.py @@ -12,6 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from auto_round.experimental.qmodules.mx import MXFP4QuantLinear, MXFP8QuantLinear, HadamardMXFP4QuantLinear +from auto_round.experimental.qmodules.mx import ( + MXFP4QuantLinear, + MXFP8QuantLinear, + MXINT4QuantLinear, + HadamardMXFP4QuantLinear, +) from auto_round.experimental.qmodules.nvfp4 import NVFP4QuantLinear from auto_round.experimental.qmodules.fp8_static import WeightFP8ActFP8StaticQuantLinear diff --git a/auto_round/experimental/qmodules/int4_utils.py b/auto_round/experimental/qmodules/int4_utils.py new file mode 100644 index 000000000..eeed82e5e --- /dev/null +++ b/auto_round/experimental/qmodules/int4_utils.py @@ -0,0 +1,109 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch + +_DEVICE_E0M4_TENSORS = {} + +# Constants for INT4 values +_E0M4_VALUES = [0.0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75] + + +def get_e0m4_tensor(device): + """Get device-specific E0M4 lookup tensor, creating it if needed.""" + device_str = str(device) + if device_str not in _DEVICE_E0M4_TENSORS: + _DEVICE_E0M4_TENSORS[device_str] = torch.tensor(_E0M4_VALUES, dtype=torch.float32, device=device) + return _DEVICE_E0M4_TENSORS[device_str] + + +def unpack_int4_from_uint8( + a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16 +) -> torch.Tensor: + """ + Unpacks uint8 values into int4. Each uint8 contains two int4 values + (low nibble first). The 4-bit indices are mapped to int4 values using kE0M4ToFloat. + """ + if a.device.type == "cuda": + return _unpack_int4_from_uint8_cuda(a, m, n, dtype) + else: + return _unpack_int4_from_uint8_cpu(a, m, n, dtype) + + +@torch.compiler.disable() +def _unpack_int4_from_uint8_cpu( + a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16 +) -> torch.Tensor: + return _unpack_int4_from_uint8(a, m, n, dtype) + + +# @torch.compile(fullgraph=True, dynamic=True) +def _unpack_int4_from_uint8_cuda( + a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16 +) -> torch.Tensor: + return _unpack_int4_from_uint8(a, m, n, dtype) + + +# reference: : https://github.com/vllm-project/vllm/pull/16362 +def _unpack_int4_from_uint8( + a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16 +) -> torch.Tensor: + """ + Unpacks uint8 values into int4. Each uint8 consists of two int4 values + (i.e. first four bits correspond to one int4 value, last four correspond to a + consecutive int4 value). The bits represent an index, which are mapped to an int4 + value. + + :param a: tensor to unpack + :param m: original dim 0 size of the unpacked tensor + :param n: original dim 1 size of the unpacked tensor + :param dtype: dense dtype to cast the unpacked tensor to + """ + assert a.dtype == torch.uint8, f"expected uint8, got {a.dtype}" + + # Vectorized nibble processing + a_flat = a.flatten() + high = (a_flat & 0xF0) >> 4 # Upper nibbles + low = a_flat & 0x0F # Lower nibbles + + # Combine nibbles for batch processing + combined = torch.stack((low, high), dim=1).flatten() + + # Vectorized sign and magnitude extraction + signs = (combined & 0x08).to(torch.bool) # Sign bits + abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices + + # Device-aware lookup and sign application + kE0M4 = get_e0m4_tensor(device=a.device) + values = kE0M4[abs_vals] * torch.where(signs, -1.0, 1.0) + + # Reshape to final form + return values.reshape(m, n).to(dtype=dtype) diff --git a/auto_round/experimental/qmodules/mx.py b/auto_round/experimental/qmodules/mx.py index b5bc3e939..449e5f348 100644 --- a/auto_round/experimental/qmodules/mx.py +++ b/auto_round/experimental/qmodules/mx.py @@ -20,10 +20,11 @@ from auto_round.data_type.utils import get_quant_func from auto_round.experimental.qmodules.base import QModuleBase from auto_round.experimental.qmodules.fp4_utils import unpack_fp4_from_uint8 +from auto_round.experimental.qmodules.int4_utils import unpack_int4_from_uint8 from auto_round.logger import logger from auto_round.schemes import QuantizationScheme -__all__ = ["MXFP4QuantLinear", "MXFP8QuantLinear"] +__all__ = ["MXFP4QuantLinear", "MXFP8QuantLinear", "MXINT4QuantLinear"] SUPPORTED_HIGHER_DTYPE = [torch.bfloat16, torch.float16, torch.float32] E8M0_EXPONENT_BIAS = 127 @@ -196,6 +197,47 @@ def unpack_data(self, packed_data: torch.Tensor) -> torch.Tensor: return unpacked_data +class MXINT4QuantLinear(MXQuantLinearBase): + """ + Quantized linear layer using the MXINT4 quantization scheme. + """ + + def __init__(self, *args, **kwargs): + self.weight_name = "weight_packed" + super().__init__(*args, **kwargs) + + def initialize_weights(self, weight: Optional[torch.Tensor]) -> torch.Tensor: + weight_dtype = torch.uint8 + weight_in_features = self.in_features // 2 + return torch.zeros((self.out_features, weight_in_features), dtype=weight_dtype) if weight is None else weight + + def dequant_weight_online(self) -> torch.Tensor: + if self.pre_dequantized: + return self.weight + dq_weight = self.dequant_mx_tensor(self.weight_packed, self.weight_scale) + return dq_weight + + def unpack_data(self, packed_data: torch.Tensor) -> torch.Tensor: + m, half_n = packed_data.shape + unpacked_data = unpack_int4_from_uint8(packed_data, m, half_n * 2, dtype=self.dtype) + return unpacked_data + + @classmethod + def from_original(cls, config: Optional[QuantizationScheme], original_layer: torch.nn.Linear): + """ + Create an `MXQuantLinear` layer from an original linear layer. + """ + logger.warning_once("MXINT quantization is still in experimental stage, the inference speed might be slow.") + qdq_linear = cls( + in_features=original_layer.in_features, + out_features=original_layer.out_features, + config=config, + bias=original_layer.bias, + dtype=original_layer.weight.dtype, + ) + return qdq_linear + + class HadamardMXFP4QuantLinear(MXFP4QuantLinear): """ Quantized linear layer using the MXFP4 quantization scheme. diff --git a/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py b/auto_round/export/export_to_autoround/export_to_nvfp_mx.py similarity index 97% rename from auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py rename to auto_round/export/export_to_autoround/export_to_nvfp_mx.py index 502c49676..81ffb3e51 100644 --- a/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py +++ b/auto_round/export/export_to_autoround/export_to_nvfp_mx.py @@ -44,7 +44,8 @@ ) from auto_round.wrapper import WrapperWALayer -from .qlinear_fp import QuantLinear +from .qlinear_fp import QuantLinear as FpQuantLinear +from .qlinear_int import QuantLinear as IntQuantLinear __all__ = [ "pack_layer", @@ -94,7 +95,8 @@ def pack_layer(name, model, backend, device=None): bias = layer.bias is not None ##bias = True ## if using the above, llama3 lambada RTN will be NAN , TODO why? - qlayer = QuantLinear( ##pylint: disable=E1123 + linear_func = FpQuantLinear if "fp" in data_type else IntQuantLinear + qlayer = linear_func( ##pylint: disable=E1123 bits, group_size, in_features, diff --git a/auto_round/export/export_to_autoround/qlinear_int.py b/auto_round/export/export_to_autoround/qlinear_int.py new file mode 100644 index 000000000..62c730410 --- /dev/null +++ b/auto_round/export/export_to_autoround/qlinear_int.py @@ -0,0 +1,202 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import numpy as np +import torch +import torch.nn as nn +import transformers + +import auto_round.envs as envs +from auto_round.compressors.utils import BackendDataType +from auto_round.data_type.mxfp import FP32_EXPONENT_BIAS, FP32_MIN_NORMAL +from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad +from auto_round.utils import get_packing_device, logger + +# from auto_round.utils import get_weight_compress_dtype +E8M0_EXPONENT_BIAS = 127 +E8M0_EXPONENT_NAN_VAL = 255 + +__all__ = ["QuantLinear"] + +FLOAT_TO_E0M4 = [ + 0.0, + 0.25, + 0.5, + 0.75, + 1.0, + 1.25, + 1.5, + 1.75, +] + + +class QuantLinear(nn.Module): + """ + MXFP quantized linear layer. + """ + + QUANT_TYPE = "MXINT" + + def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, data_type="mx_int4", **kwargs): + super().__init__() + if bits not in [4]: + raise NotImplementedError("Only 4 bits are supported.") + if group_size != 32: + raise NotImplementedError(f"Only group_size 32 are supported for {BackendDataType.MX_INT} data type.") + if infeatures % group_size != 0: + raise NotImplementedError( + f"in_feature must be divisible by {group_size} for {BackendDataType.MX_INT} data type." + ) + self.infeatures = infeatures + self.outfeatures = outfeatures + self.bits = bits + self.data_type = data_type + self.sym = kwargs.get("sym", True) + self.group_size = group_size if group_size != -1 else infeatures + self.maxq = 2**self.bits - 1 + self.act_bits = kwargs.get("act_bits", None) + + weight_name = "weight_packed" + weight_infeatures = infeatures if self.bits == 8 else infeatures // 2 + weight_dtype = torch.uint8 + ## TODO check the dtype of weight_packed and weight_scale + self.register_buffer( + weight_name, + torch.zeros((outfeatures, weight_infeatures), dtype=weight_dtype), + ) + self.register_buffer( + "weight_scale", + torch.zeros( + (outfeatures, math.ceil(infeatures / self.group_size)), + dtype=torch.float16, ## TODO update to correct scale dtype for different bits + ), + ) + if bias: + self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16)) + else: + self.bias = None + + self.trainable = trainable + + def post_init(self): + pass + + def pack(self, linear, scales, zeros=None, g_idx=None, global_scale=None, input_global_scale=None, device=None): + device = get_packing_device(device) + if getattr(linear, "bias", None) is not None: + self.bias = linear.bias.detach().to(torch.float16) + + W = linear.weight.data.detach().to(device) + if type(linear) == nn.Conv2d: + W = W.flatten(1) + if type(linear) == transformers.pytorch_utils.Conv1D: + W = W.t() + + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(W, self.group_size) + scales = scales.to(device) + scaled_tensor = tensor / (2 ** scales.reshape(tensor.shape[0], -1)) + scaled_tensor = revert_tensor_by_pad(scaled_tensor, orig_shape=orig_shape, pad_len=pad_len) + final_scale = (scales + E8M0_EXPONENT_BIAS).clamp(0, E8M0_EXPONENT_NAN_VAL).to(torch.uint8) + + self.weight_scale = final_scale + compress_dtype = torch.uint8 + self.weight_packed = pack_int4_to_uint8(scaled_tensor) + + +def pack_int4_to_uint8(scaled_tensor: torch.Tensor): + if scaled_tensor.device.type == "cuda": + return pack_int4_to_uint8_cuda(scaled_tensor) + else: + return pack_int4_to_uint8_cpu(scaled_tensor) + + +# The torch.compile with dynamic=True is incompatible with multiple threads +# https://github.com/pytorch/pytorch/issues/126024 +@torch.compiler.disable() +def pack_int4_to_uint8_cpu(x: torch.Tensor) -> torch.Tensor: + return _pack_int4_to_uint8(x) + + +# Adapted from https://github.com/neuralmagic/compressed-tensors/pull/400 + + +def _get_packing_fn(): + if envs.AR_ENABLE_COMPILE_PACKING: + logger.warning_once( + "Compiled INT4 to UINT8 packing may be incompatible with multi-threading." + " Disable it by setting AR_ENABLE_COMPILE_PACKING=0" + ) + return torch.compile(fullgraph=True, dynamic=True)(_pack_int4_to_uint8) + else: + return torch.compiler.disable()(_pack_int4_to_uint8) + + +def pack_int4_to_uint8_cuda(x: torch.Tensor) -> torch.Tensor: + """ + Packs a tensor with values in the int4 range into uint8. + + :param x: tensor to pack + returns: a packed tensor in uint8 + """ + pack_fn = _get_packing_fn() + return pack_fn(x) + + +def _pack_int4_to_uint8(x: torch.Tensor) -> torch.Tensor: + + m, n = x.shape + device = x.device + + # Create lookup table for INT4 values to indices + # Map the absolute values to 0-7 indices + kE0M4 = torch.tensor(FLOAT_TO_E0M4, device=device, dtype=x.dtype) + + # Find closest valid INT4 value index for each element + abs_x = torch.abs(x) + abs_diff_x = torch.abs(abs_x.unsqueeze(-1) - kE0M4) # [m, n, 8] + abs_indices = torch.argmin(abs_diff_x, dim=-1) # [m, n] + + # Apply sign bit (bit 3) to get final 4-bit representation + indices = abs_indices + (torch.signbit(x).to(torch.long) << 3) + + # Reshape to prepare for packing pairs of values + indices = indices.reshape(-1) + + # Handle odd length by padding if necessary + if indices.numel() % 2 != 0: + indices = torch.cat([indices, torch.zeros(1, dtype=torch.long, device=device)]) + + # Reshape to pair consecutive elements + indices = indices.reshape(-1, 2) + + # Pack pairs of 4-bit values into 8-bit values + packed = (indices[:, 0] | (indices[:, 1] << 4)).to(torch.uint8) + + return packed.reshape(m, n // 2) diff --git a/auto_round/formats.py b/auto_round/formats.py index 03213cef6..60df0131a 100644 --- a/auto_round/formats.py +++ b/auto_round/formats.py @@ -31,6 +31,7 @@ is_dynamic_afp8, is_dynamic_wint8aint8, is_mx_fp, + is_mx_int, is_nv_fp, is_standard_fp, is_static_wfp8afp8, @@ -69,6 +70,8 @@ class AutoRoundExportFormat(str, Enum): NV_FP4_WITH_STATIC_GS = "nv_fp4_with_static_gs" INT8_W8A8 = "int8_w8a8" FP8_BLOCK = "fp8_block" + MXINT4 = "mxint4" + MX_INT = "mx_int" if TYPE_CHECKING: @@ -1077,6 +1080,7 @@ class AutoRoundFormat(OutputFormat): "FP8_STATIC", "BF16", "FP8_BLOCK", + "MXINT4", ] format_name = "auto_round" @@ -1085,7 +1089,7 @@ def __init__(self, format: str, ar: BaseCompressor): self.backend = None if format == "auto_round": - if ar.sym and "int" in ar.data_type: + if ar.sym and "int" in ar.data_type and "mx" not in ar.data_type: self.backend = AutoGPTQFormat("auto_round:auto_gptq", ar) elif ar.bits == 4 and not ar.sym and "int" in ar.data_type: if ar.layer_config is None: @@ -1098,6 +1102,8 @@ def __init__(self, format: str, ar: BaseCompressor): self.backend = AutoAWQFormat("auto_round:auto_awq", ar) elif is_nv_fp(ar.data_type) or is_mx_fp(ar.data_type): self.backend = AutoRoundFormat(ar.data_type, ar) + elif is_mx_int(ar.data_type) and ar.bits == 4: # only add mx_int4 now + self.backend = AutoRoundFormat(ar.data_type, ar) elif is_static_wfp8afp8(ar): # static wfp8afp8 self.backend = AutoRoundFormat(AutoRoundExportFormat.FP8_STATIC.value, ar) elif ar.data_type.startswith("fp") and ar.bits == 8 and ar.act_bits >= 16: # woq fp8 @@ -1155,7 +1161,11 @@ def pack_layer(self, layer_name, model, device=None, **kwargs): f"auto_round:{AutoRoundExportFormat.MX_FP_RCEIL.value}", f"auto_round:{AutoRoundExportFormat.NV_FP4_WITH_STATIC_GS.value}", ]: - from auto_round.export.export_to_autoround.export_to_nvfp_mxfp import pack_layer + from auto_round.export.export_to_autoround.export_to_nvfp_mx import pack_layer + + pack_func = pack_layer + elif self.output_format in [f"auto_round:{AutoRoundExportFormat.MX_INT.value}"]: + from auto_round.export.export_to_autoround.export_to_nvfp_mx import pack_layer pack_func = pack_layer elif self.output_format in [ @@ -1196,7 +1206,7 @@ def save_quantized( ) backend = self.get_backend_name() if re.search(f"{AutoRoundExportFormat.MX_FP.value}|{AutoRoundExportFormat.NV_FP.value}", backend): - from auto_round.export.export_to_autoround.export_to_nvfp_mxfp import save_quantized_as_fp + from auto_round.export.export_to_autoround.export_to_nvfp_mx import save_quantized_as_fp backend = "auto_round:llm_compressor" export_func = save_quantized_as_fp @@ -1205,6 +1215,11 @@ def save_quantized( backend = "auto_round:fp8_static" if serialization_dict.get("act_bits", 16) == 8 else None export_func = save_quantized_as_autoround + elif re.search(f"{AutoRoundExportFormat.MX_INT.value}", backend): + from auto_round.export.export_to_autoround.export_to_nvfp_mx import save_quantized_as_fp + + backend = "auto_round:mx_int4" + export_func = save_quantized_as_fp else: from auto_round.export.export_to_autoround.export import save_quantized_as_autoround diff --git a/auto_round/inference/backend.py b/auto_round/inference/backend.py index c5b10c4dc..31591b5ea 100644 --- a/auto_round/inference/backend.py +++ b/auto_round/inference/backend.py @@ -114,6 +114,7 @@ class BackendInfo: MX_TENSOR_DATA_TYPES = [ "mx_fp", "mx_fp_rceil", + "mx_int", ] @@ -303,6 +304,26 @@ def fp8_static_scheme_checker( requirements=["auto-round>0.7.0"], ) +# MXINT4 +BackendInfos["auto_round:torch_mxint4"] = BackendInfo( + device=["cuda", "cpu"], + packing_format=["auto_round:mx_int4"], + sym=[True], + compute_dtype=["float32", "float16", "bfloat16"], + data_type=MX_TENSOR_DATA_TYPES, + group_size=[32], + bits=[4], + act_bits=[4], + act_group_size=[32], + act_sym=[True], + act_data_type=MX_TENSOR_DATA_TYPES, + act_dynamic=[True], + priority=0, + checkers=[mxfp_nvfp_feature_checker], + alias=["auto_round", "torch"], + requirements=["auto-round>0.12.0"], +) + # NVFP4 BackendInfos["auto_round:torch_nvfp4"] = BackendInfo( @@ -774,6 +795,8 @@ def dynamic_import_inference_linear(backend, config): return ar_qmodules.WeightFP8ActFP8StaticQuantLinear if "torch_mxfp8" in backend: return ar_qmodules.MXFP8QuantLinear + if "torch_mxint4" in backend: + return ar_qmodules.MXINT4QuantLinear if "torch_mxfp4" in backend: hadamard_config = getattr(config, "hadamard_config", None) if hadamard_config is not None and hadamard_config: diff --git a/auto_round/inference/convert_model.py b/auto_round/inference/convert_model.py index a5b9096b3..9f2d6fca3 100644 --- a/auto_round/inference/convert_model.py +++ b/auto_round/inference/convert_model.py @@ -447,6 +447,7 @@ def _create_quant_layer(layer, layer_backend, config, in_features, out_features) or AutoRoundExportFormat.MXFP8.value in layer_backend or AutoRoundExportFormat.MXFP4.value in layer_backend or AutoRoundExportFormat.NVFP4.value in layer_backend + or AutoRoundExportFormat.MXINT4.value in layer_backend ): return QuantLinear.from_original(config, layer) diff --git a/auto_round/schemes.py b/auto_round/schemes.py index 5318b1fec..2e0641554 100644 --- a/auto_round/schemes.py +++ b/auto_round/schemes.py @@ -228,6 +228,18 @@ def is_preset_scheme(name: str) -> bool: } ) +MXINT4 = QuantizationScheme.from_dict( + { + "bits": 4, + "group_size": 32, + "data_type": "mx_int", + "act_bits": 4, + "act_data_type": "mx_int", + "act_group_size": 32, + "act_sym": True, + "act_dynamic": True, + } +) NVFP4 = QuantizationScheme.from_dict( { @@ -330,6 +342,7 @@ def is_preset_scheme(name: str) -> bool: "W4A16_MIXED": W4A16, "INT8_W8A8": INT8_W8A8, "FP8_BLOCK": FP8_BLOCK, + "MXINT4": MXINT4, } from auto_round.export.export_to_gguf.config import GGUF_CONFIG diff --git a/docs/step_by_step.md b/docs/step_by_step.md index a076e9acb..e0573adb3 100644 --- a/docs/step_by_step.md +++ b/docs/step_by_step.md @@ -157,7 +157,7 @@ adopted within the community, **only 4-bits quantization is supported**. Please | Format | Supported Schemes | |:---|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| **auto_round** | W4A16, W2A16, W3A16, W8A16, W2A16G64, W2A16G32, `MXFP4`, `MXFP8`, `MXFP4_RCEIL`, `MXFP8_RCEIL`, `NVFP4`, `FPW8A16`, `FP8_STATIC`, `FP8_BLOCK`, `BF16` | +| **auto_round** | W4A16, W2A16, W3A16, W8A16, W2A16G64, W2A16G32, `MXFP4`, `MXFP8`, `MXFP4_RCEIL`, `MXFP8_RCEIL`, `NVFP4`, `FPW8A16`, `FP8_STATIC`, `FP8_BLOCK`, `BF16`, `MXINT4` | | **auto_awq** | W4A16, BF16 | | **auto_gptq** | W4A16, W2A16, W3A16, W8A16,W2A16G64, W2A16G32, BF16 | | **llm_compressor** | NVFP4, `MXFP4`, `MXFP8`, `FPW8A16`, `FP8_STATIC`, FP8_BLOCK | diff --git a/docs/step_by_step_CN.md b/docs/step_by_step_CN.md index b7cd57f64..a85cbd5d4 100644 --- a/docs/step_by_step_CN.md +++ b/docs/step_by_step_CN.md @@ -147,7 +147,7 @@ AutoRound 支持多种量化配置: | 格式 | 支持的量化方案 | |:-------------- |:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| **auto_round** | W4A16、W2A16、W3A16、W8A16、W2A16G64、W2A16G32、`MXFP4`、`MXFP8`、`MXFP4_RCEIL`、`MXFP8_RCEIL`、`NVFP4`、`FPW8A16`、`FP8_STATIC`、`FP8_BLOCK`、`BF16` | +| **auto_round** | W4A16、W2A16、W3A16、W8A16、W2A16G64、W2A16G32、`MXFP4`、`MXFP8`、`MXFP4_RCEIL`、`MXFP8_RCEIL`、`NVFP4`、`FPW8A16`、`FP8_STATIC`、`FP8_BLOCK`、`BF16`, `MXINT4` | | **auto_awq** | W4A16、BF16 | | **auto_gptq** | W4A16、W2A16、W3A16、W8A16、W2A16G64、W2A16G32、BF16 | | **llm_compressor** | NVFP4、`MXFP4`、`MXFP8`、`FPW8A16`、`FP8_STATIC`、FP8_STATIC | diff --git a/test/test_cpu/quantization/test_mx_quant_linear.py b/test/test_cpu/quantization/test_mx_quant_linear.py index c2e9a3c00..1ec5cb729 100644 --- a/test/test_cpu/quantization/test_mx_quant_linear.py +++ b/test/test_cpu/quantization/test_mx_quant_linear.py @@ -4,13 +4,16 @@ from auto_round.data_type.utils import get_quant_func from auto_round.experimental import qmodules as ar_qmodules from auto_round.export.export_to_autoround.qlinear_fp import QuantLinear as _MXFPLinear +from auto_round.export.export_to_autoround.qlinear_int import QuantLinear as _MXINTLinear from auto_round.formats import AutoRoundExportFormat from auto_round.schemes import PRESET_SCHEMES mx_schemes = [AutoRoundExportFormat.MXFP8.value, AutoRoundExportFormat.MXFP4.value] +mx_int_schemes = [AutoRoundExportFormat.MXINT4.value] QMODULE_MAPPING = { AutoRoundExportFormat.MXFP8.value: ar_qmodules.MXFP8QuantLinear, AutoRoundExportFormat.MXFP4.value: ar_qmodules.MXFP4QuantLinear, + AutoRoundExportFormat.MXINT4.value: ar_qmodules.MXINT4QuantLinear, } @@ -107,3 +110,86 @@ def test_mxquantlinear_from_original_and_forward(scheme): # Assert that the outputs are close within a tolerance assert diff_amax < 5e-1, f"Outputs differ too much for scheme {scheme}!" + + +@pytest.mark.parametrize("scheme", mx_int_schemes) +@torch.inference_mode() +def test_mxint_quantlinear_from_original_and_forward(scheme): + """ + Test MXINT4 quantization schemes by creating quantized layers + from an original torch.nn.Linear layer and validating their forward pass. + """ + # Set random seed for reproducibility + torch.manual_seed(42) + + # Define layer dimensions + in_features = 64 + out_features = 512 + + # Create an original torch.nn.Linear layer + original_layer = torch.nn.Linear(in_features, out_features, bias=False) + + # Select the quantization scheme + config = PRESET_SCHEMES[scheme.upper()] + + # Define weight scale shape + weight_scale_shape = (out_features, in_features // config.group_size) + + # Quantize the weights using the quantization function + qdq_func, _ = get_quant_func(dtype=config.data_type, bits=config.bits, sym=config.sym) + qdq_weight, shared_exp, _ = qdq_func( + tensor=original_layer.weight, + bits=config.bits, + group_size=config.group_size, + data_type=config.data_type + str(config.bits), + ) + shared_exp = shared_exp.reshape(weight_scale_shape) + + # Pack the weights using the QuantLinear class + mxint_lin = _MXINTLinear( + bits=config.bits, + group_size=config.group_size, + infeatures=in_features, + outfeatures=out_features, + bias=original_layer.bias is not None, + data_type=config.data_type, + ) + mxint_lin.pack(linear=original_layer, scales=shared_exp) + + # Create an MXQuantLinear layer from the original layer + QuantLinearClass = QMODULE_MAPPING[scheme] + mxint_layer = QuantLinearClass.from_original( + config=config, + original_layer=original_layer, + ) + + # Copy the packed weights and scales to the quantized layer + packed_weight = mxint_lin.weight_packed + if config.bits == 4: + mxint_layer.weight_packed.data.copy_(packed_weight) + else: + raise ValueError("Only 4-bit quantization are supported.") + mxint_layer.weight_scale.data.copy_(mxint_lin.weight_scale) + + # Validate layer attributes + assert mxint_layer.in_features == original_layer.in_features + assert mxint_layer.out_features == original_layer.out_features + + # Generate a random input tensor + input_tensor = torch.randn((4, in_features), dtype=torch.float32) + + # Perform a forward pass with both layers + original_output = original_layer(input_tensor) + mx_output = mxint_layer(input_tensor) + + # Compute the difference between the outputs + diff = mx_output - original_output + # Note: Remove NaN values, as we might get NaN when casting scales to FP8 + diff = diff[~torch.isnan(diff)] + diff_amax = diff.abs().max() + + # Print the maximum difference for debugging + print(f"Scheme: {scheme}, Max Difference: {diff_amax}") + + # Assert that the outputs are close within a tolerance + assert diff_amax < 5e-1, f"Outputs differ too much for scheme {scheme}!" diff --git a/test/test_cpu/quantization/test_mxfp_save_load.py b/test/test_cpu/quantization/test_mxfp_save_load.py index 5e12edc68..25e5a2428 100644 --- a/test/test_cpu/quantization/test_mxfp_save_load.py +++ b/test/test_cpu/quantization/test_mxfp_save_load.py @@ -28,11 +28,12 @@ AutoRoundExportFormat.MXFP8.value: ar_schemes.MXFP8, AutoRoundExportFormat.MXFP4.value: ar_schemes.MXFP4, } +MX_TENSOR_DATA_TYPES_FP = [i for i in MX_TENSOR_DATA_TYPES if "int" not in i] @pytest.mark.parametrize("scheme_name", testing_scheme_name_lst) -@pytest.mark.parametrize("weight_data_type", MX_TENSOR_DATA_TYPES) -@pytest.mark.parametrize("act_data_type", MX_TENSOR_DATA_TYPES) +@pytest.mark.parametrize("weight_data_type", MX_TENSOR_DATA_TYPES_FP) +@pytest.mark.parametrize("act_data_type", MX_TENSOR_DATA_TYPES_FP) @torch.inference_mode() def test_e2e_quant_and_load(scheme_name, weight_data_type, act_data_type): # Use a temporary directory for saving the quantized model