Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion auto_round/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,11 @@ def tune(args):
suffix = f"a{autoround.act_bits}"
else:
suffix = f"g{autoround.group_size}"
prefix = autoround.data_type.lower().replace("_", "") if "int" not in autoround.data_type else ""
prefix = (
autoround.data_type.lower().replace("_", "")
if "int" not in autoround.data_type or "mx" in autoround.data_type
else ""
)
export_dir = os.path.join(
args.output_dir,
model_name.split("/")[-1] + (f"-{prefix}" if prefix else "") + f"-w{autoround.bits}{suffix}",
Expand Down
6 changes: 6 additions & 0 deletions auto_round/compressors/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class BackendDataType(str, Enum):
STANDARD_FP = "fp"
MX_FP = "mx_fp"
NV_FP = "nv_fp"
MX_INT = "mx_int"


def is_standard_fp(backend):
Expand All @@ -47,6 +48,11 @@ def is_mx_fp(backend):
return BackendDataType.MX_FP in backend


def is_mx_int(backend):
    """Return True if *backend* names an MX-INT data type (contains "mx_int")."""
    # Case-insensitive substring match against the canonical enum value.
    return BackendDataType.MX_INT in backend.lower()


def is_nv_fp(backend):
    """Return True if *backend* names an NV-FP data type (contains "nv_fp")."""
    # Case-insensitive substring match against the canonical enum value.
    return BackendDataType.NV_FP in backend.lower()
Expand Down
7 changes: 6 additions & 1 deletion auto_round/experimental/qmodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from auto_round.experimental.qmodules.mx import MXFP4QuantLinear, MXFP8QuantLinear, HadamardMXFP4QuantLinear
from auto_round.experimental.qmodules.mx import (
MXFP4QuantLinear,
MXFP8QuantLinear,
MXINT4QuantLinear,
HadamardMXFP4QuantLinear,
)
from auto_round.experimental.qmodules.nvfp4 import NVFP4QuantLinear
from auto_round.experimental.qmodules.fp8_static import WeightFP8ActFP8StaticQuantLinear
109 changes: 109 additions & 0 deletions auto_round/experimental/qmodules/int4_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch

# Per-device cache of the E0M4 magnitude lookup tensor, keyed by str(device).
_DEVICE_E0M4_TENSORS = {}

# Magnitude table for the 3-bit index of a packed INT4 code (index * 0.25).
_E0M4_VALUES = [0.0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75]


def get_e0m4_tensor(device):
    """Return the float32 E0M4 lookup tensor for *device*, creating and caching it on first use."""
    key = str(device)
    cached = _DEVICE_E0M4_TENSORS.get(key)
    if cached is None:
        cached = torch.tensor(_E0M4_VALUES, dtype=torch.float32, device=device)
        _DEVICE_E0M4_TENSORS[key] = cached
    return cached


def unpack_int4_from_uint8(
    a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
) -> torch.Tensor:
    """
    Unpack a uint8 tensor holding two int4 codes per byte (low nibble first)
    into a dense (m, n) tensor of *dtype*.

    Dispatches to a CUDA- or CPU-specific implementation based on the
    device of *a*.
    """
    impl = _unpack_int4_from_uint8_cuda if a.device.type == "cuda" else _unpack_int4_from_uint8_cpu
    return impl(a, m, n, dtype)


@torch.compiler.disable()
def _unpack_int4_from_uint8_cpu(
    a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
) -> torch.Tensor:
    """CPU path: run the shared unpacking routine with torch.compile disabled."""
    return _unpack_int4_from_uint8(a, m, n, dtype=dtype)


# Compilation is deliberately left disabled for now:
# @torch.compile(fullgraph=True, dynamic=True)
def _unpack_int4_from_uint8_cuda(
    a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
) -> torch.Tensor:
    """CUDA path: currently identical to the shared unpacking routine."""
    return _unpack_int4_from_uint8(a, m, n, dtype=dtype)


# reference: https://github.com/vllm-project/vllm/pull/16362
def _unpack_int4_from_uint8(
    a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
) -> torch.Tensor:
    """
    Unpack packed uint8 nibbles into a dense (m, n) tensor.

    Each byte carries two 4-bit codes, low nibble first. A code is decoded
    sign-magnitude style: bit 3 is the sign, bits 0-2 index the magnitude
    lookup table returned by ``get_e0m4_tensor``.

    :param a: packed uint8 tensor to unpack
    :param m: row count of the unpacked result
    :param n: column count of the unpacked result
    :param dtype: dtype the dense result is cast to
    """
    assert a.dtype == torch.uint8, f"expected uint8, got {a.dtype}"

    flat = a.reshape(-1)

    # Split every byte into its two nibbles; low nibble precedes high.
    nibbles = torch.stack((flat & 0x0F, flat >> 4), dim=1).flatten()

    # Bit 3 carries the sign, bits 0-2 the magnitude index.
    negative = (nibbles & 0x08).to(torch.bool)
    magnitude_idx = (nibbles & 0x07).to(torch.long)

    # Map indices through the cached, device-resident lookup table.
    lut = get_e0m4_tensor(device=a.device)
    decoded = lut[magnitude_idx] * torch.where(negative, -1.0, 1.0)

    return decoded.reshape(m, n).to(dtype=dtype)
44 changes: 43 additions & 1 deletion auto_round/experimental/qmodules/mx.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
from auto_round.data_type.utils import get_quant_func
from auto_round.experimental.qmodules.base import QModuleBase
from auto_round.experimental.qmodules.fp4_utils import unpack_fp4_from_uint8
from auto_round.experimental.qmodules.int4_utils import unpack_int4_from_uint8
from auto_round.logger import logger
from auto_round.schemes import QuantizationScheme

__all__ = ["MXFP4QuantLinear", "MXFP8QuantLinear"]
__all__ = ["MXFP4QuantLinear", "MXFP8QuantLinear", "MXINT4QuantLinear"]

SUPPORTED_HIGHER_DTYPE = [torch.bfloat16, torch.float16, torch.float32]
E8M0_EXPONENT_BIAS = 127
Expand Down Expand Up @@ -196,6 +197,47 @@ def unpack_data(self, packed_data: torch.Tensor) -> torch.Tensor:
return unpacked_data


class MXINT4QuantLinear(MXQuantLinearBase):
    """
    Quantized linear layer using the MXINT4 quantization scheme.

    Weights are stored packed, two 4-bit codes per uint8, under the
    ``weight_packed`` attribute.
    """

    def __init__(self, *args, **kwargs):
        # Set before the base __init__ runs so the packed-weight buffer is
        # registered under the right attribute name.
        self.weight_name = "weight_packed"
        super().__init__(*args, **kwargs)

    def initialize_weights(self, weight: Optional[torch.Tensor]) -> torch.Tensor:
        """Return *weight* if given, else a zeroed packed-weight buffer."""
        if weight is not None:
            return weight
        # Two int4 values per byte -> half as many uint8 columns.
        packed_cols = self.in_features // 2
        return torch.zeros((self.out_features, packed_cols), dtype=torch.uint8)

    def dequant_weight_online(self) -> torch.Tensor:
        """Dequantize the packed weight on the fly (no-op if pre-dequantized)."""
        if self.pre_dequantized:
            return self.weight
        return self.dequant_mx_tensor(self.weight_packed, self.weight_scale)

    def unpack_data(self, packed_data: torch.Tensor) -> torch.Tensor:
        """Expand a packed (rows, cols) uint8 tensor to (rows, cols * 2) values."""
        rows, packed_cols = packed_data.shape
        return unpack_int4_from_uint8(packed_data, rows, packed_cols * 2, dtype=self.dtype)

    @classmethod
    def from_original(cls, config: Optional[QuantizationScheme], original_layer: torch.nn.Linear):
        """Build an MXINT4QuantLinear mirroring *original_layer*'s geometry."""
        logger.warning_once("MXINT quantization is still in experimental stage, the inference speed might be slow.")
        return cls(
            in_features=original_layer.in_features,
            out_features=original_layer.out_features,
            config=config,
            bias=original_layer.bias,
            dtype=original_layer.weight.dtype,
        )


class HadamardMXFP4QuantLinear(MXFP4QuantLinear):
"""
Quantized linear layer using the MXFP4 quantization scheme.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
)
from auto_round.wrapper import WrapperWALayer

from .qlinear_fp import QuantLinear
from .qlinear_fp import QuantLinear as FpQuantLinear
from .qlinear_int import QuantLinear as IntQuantLinear

__all__ = [
"pack_layer",
Expand Down Expand Up @@ -94,7 +95,8 @@ def pack_layer(name, model, backend, device=None):

bias = layer.bias is not None
##bias = True ## if using the above, llama3 lambada RTN will be NAN , TODO why?
qlayer = QuantLinear( ##pylint: disable=E1123
linear_func = FpQuantLinear if "fp" in data_type else IntQuantLinear
qlayer = linear_func( ##pylint: disable=E1123
bits,
group_size,
in_features,
Expand Down
Loading
Loading