diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index a8735407e..074ffcf1c 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -560,15 +560,10 @@ def __init__( # apply hadamard transform if hadamard_config: from auto_round.experimental.transform.apply import apply_hadamard_transform - from auto_round.experimental.utils import check_supported_schemes, normalize_hadamard_config + from auto_round.experimental.utils import normalize_hadamard_config - check_supported_schemes(self.scheme) - - self.model = apply_hadamard_transform( - self.model, hadamard_config, need_calibration=True if self.iters > 0 else False - ) - - self.hadamard_config = normalize_hadamard_config(hadamard_config) + self.hadamard_config = normalize_hadamard_config(hadamard_config, self.scheme) + self.model = apply_hadamard_transform(self.model, self.hadamard_config, scheme=self.scheme) def _gen_auto_scheme(self) -> dict[str, dict]: if self.mllm: diff --git a/auto_round/experimental/qmodules/__init__.py b/auto_round/experimental/qmodules/__init__.py index 3862e0293..377784055 100644 --- a/auto_round/experimental/qmodules/__init__.py +++ b/auto_round/experimental/qmodules/__init__.py @@ -13,5 +13,5 @@ # limitations under the License. 
from auto_round.experimental.qmodules.mx import MXFP4QuantLinear, MXFP8QuantLinear, HadamardMXFP4QuantLinear -from auto_round.experimental.qmodules.nvfp4 import NVFP4QuantLinear +from auto_round.experimental.qmodules.nvfp4 import NVFP4QuantLinear, HadamardNVFP4QuantLinear from auto_round.experimental.qmodules.fp8_static import WeightFP8ActFP8StaticQuantLinear diff --git a/auto_round/experimental/qmodules/nvfp4.py b/auto_round/experimental/qmodules/nvfp4.py index 81aea8b54..c82846f44 100644 --- a/auto_round/experimental/qmodules/nvfp4.py +++ b/auto_round/experimental/qmodules/nvfp4.py @@ -204,3 +204,21 @@ def unpack_data(self, packed_data: torch.Tensor) -> torch.Tensor: m, half_n = packed_data.shape unpacked_data = unpack_fp4_from_uint8(packed_data, m, half_n * 2, dtype=self.dtype) return unpacked_data + + +class HadamardNVFP4QuantLinear(NVFP4QuantLinear): + """ + Quantized linear layer using the NVFP4 quantization scheme. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.enable_transform = True + self.register_buffer( + "hadamard_matrix", + torch.empty( + self.group_size, + self.group_size, + dtype=self.dtype, + ), + ) diff --git a/auto_round/experimental/transform/apply.py b/auto_round/experimental/transform/apply.py index 6980d75e4..d99e0b928 100644 --- a/auto_round/experimental/transform/apply.py +++ b/auto_round/experimental/transform/apply.py @@ -4,7 +4,7 @@ import torch import tqdm -from auto_round.experimental.qmodules.mx import MXQuantLinearBase +from auto_round.experimental.qmodules.base import QModuleBase from auto_round.experimental.transform.hadamard_config import HadamardConfig from auto_round.experimental.transform.hadamards import build_hadamard_transform from auto_round.experimental.utils import is_triton_kernel_available, normalize_hadamard_config @@ -15,10 +15,11 @@ def apply_hadamard_transform( model: torch.nn.Module, config: str | dict | HadamardConfig | None, - need_calibration: bool = False, location: 
str = "weight", use_tqdm=True, desc=None, + data_type="mx_fp", + scheme="MXFP4", ): """ Apply a transform configuration to a model. @@ -53,21 +54,21 @@ def apply_hadamard_transform( ``config.transform_type``. """ - config = normalize_hadamard_config(config) + config = normalize_hadamard_config(config, scheme) if not isinstance(config, HadamardConfig): config = HadamardConfig(**config) modules_config = [ (name, module, config) for name, module in model.named_modules() - if isinstance(module, torch.nn.Linear) or isinstance(module, MXQuantLinearBase) + if isinstance(module, torch.nn.Linear) or isinstance(module, QModuleBase) ] desc = f"Applying {config.hadamard_type} transforms" if desc is None else desc for name, module, config in tqdm.tqdm(modules_config, desc=desc, disable=(not use_tqdm)): if "lm_head" in name: continue - _apply_to_module(model, module, config, need_calibration, location) + _apply_to_module(model, module, config, location, data_type) # attach config to model for compression/serialization setattr(model, "hadamard_config", config) @@ -79,8 +80,8 @@ def _apply_to_module( model: torch.nn.Module, module: torch.nn.Module, config: HadamardConfig, - need_calibration: bool = False, location: str = "weight", + data_type: str = "mx_fp", ): """ Create transforms and apply them to the module @@ -100,7 +101,6 @@ def _apply_to_module( location="input", inverse=True, device="cpu", - precision=module.dtype, ) if config.hadamard_type != "random_hadamard": @@ -108,13 +108,14 @@ def _apply_to_module( else: hadamard_weight = None - if is_triton_kernel_available(): + if is_triton_kernel_available(data_type): from auto_round.experimental.transform.triton.mxfp4 import mxfp4_forward_kernel_wrapper def input_hook(self, args): input = args[0] # transform(input) orig_shape = input.shape + orig_dtype = input.dtype x_flat = input.contiguous().flatten(end_dim=-2) qdq_input, _ = mxfp4_forward_kernel_wrapper( x_flat, @@ -122,7 +123,7 @@ def input_hook(self, args): hadamard_weight 
if hadamard_weight is not None else self.hadamard_matrix.T ), # this matrix from w_transform, needs transpose ) - return qdq_input.reshape(orig_shape) + return qdq_input.reshape(orig_shape).to(orig_dtype) # for fused transform + quantization kernel module.pre_dequantized_input = True @@ -135,13 +136,22 @@ def input_hook(self, args): input = args[0] ori_shape = input.shape + orig_dtype = input.dtype if hadamard_weight is not None: input = input.view(-1, hadamard_weight.shape[0]) - return _multihead_matmul(input, hadamard_weight.to(input.device)).view(ori_shape) + return ( + (_multihead_matmul(input.to(hadamard_weight.dtype), hadamard_weight.to(input.device))) + .view(ori_shape) + .to(orig_dtype) + ) else: input = input.view(-1, self.hadamard_matrix.shape[0]) - return _multihead_matmul(input, self.hadamard_matrix.T).view(ori_shape) + return ( + (_multihead_matmul(input.to(self.hadamard_matrix.dtype), self.hadamard_matrix.T)) + .view(ori_shape) + .to(orig_dtype) + ) # for fused transform + quantization kernel module.pre_dequantized_input = False @@ -156,7 +166,6 @@ def input_hook(self, args): **config.dict(), location="weight", device=module.weight.device, - precision=module.weight.dtype, ) # need save random hadamard matrix needed when inference @@ -167,31 +176,22 @@ def input_hook(self, args): patch_quantlinear(config.hadamard_type) - if need_calibration: - # for training, the weight changes with every forward pass - # for autoround tuning: patch wrapper linear qdq_weight func - from auto_round.experimental.transform.patch_modules import ( - patch_wrapperlinear_to_apply_transform, - patch_wrapperwalayer_forward_to_apply_transform, - ) - - input_hadamard_transform = build_hadamard_transform( - **config.dict(), - location="input", - inverse=True, - device=module.weight.device, - precision=module.weight.dtype, - ) - - patch_wrapperlinear_to_apply_transform(weight_hadamard_transform, input_hadamard_transform) - 
patch_wrapperwalayer_forward_to_apply_transform(input_hadamard_transform) + # for autoround tuning: weight not tuning + # for rtn: weight transformed before saving + from auto_round.experimental.transform.patch_modules import ( + patch_wrapperlinear_to_apply_transform, + patch_wrapperwalayer_forward_to_apply_transform, + ) - else: - # transform is no longer needed (unfusing is not supported) - # delattr(module, transform_name) - # fuse transform into weight - with torch.no_grad(): - getattr(module, "weight").copy_(weight_hadamard_transform(module.weight).to(module.weight.device)) + input_hadamard_transform = build_hadamard_transform( + **config.dict(), + location="input", + inverse=True, + device=module.weight.device, + ) + + patch_wrapperlinear_to_apply_transform(weight_hadamard_transform, input_hadamard_transform) + patch_wrapperwalayer_forward_to_apply_transform(input_hadamard_transform) else: # TODO: apply transform to output/q/k diff --git a/auto_round/experimental/transform/hadamards.py b/auto_round/experimental/transform/hadamards.py index 712232a9a..0d29e5cb0 100644 --- a/auto_round/experimental/transform/hadamards.py +++ b/auto_round/experimental/transform/hadamards.py @@ -34,11 +34,25 @@ def __init__( self, block_size: int = 32, device: torch.device = None, - precision: torch.dtype = None, + precision: torch.dtype = torch.float32, location: str = "weight", module_type: type[torch.nn.Module] = torch.nn.Linear, inverse: bool = False, ): + """Initialize a Hadamard transform module. + + Args: + block_size: Size of each Hadamard block. The input tensor is reshaped + to ``(-1, block_size)`` before applying the transform. + device: Device on which to create the Hadamard matrix. + precision: Data type used for the Hadamard matrix weights, using float64 as default. + location: Target location used by ``apply_transform_weight`` when + applying the transform. + module_type: Module type associated with the transform application, + typically ``torch.nn.Linear``. 
+ inverse: Whether to build the inverse form of the transform. + """ + super().__init__() self.size = block_size self.scale = 1 / math.sqrt(self.size) @@ -78,10 +92,14 @@ def forward(self, x: torch.Tensor): class RandomHadamardTransform(HadamardTransform): def __init__( self, - *args, + block_size: int = 32, + device: torch.device = None, + precision: torch.dtype = None, + location: str = "weight", + module_type: type[torch.nn.Module] = torch.nn.Linear, + inverse: bool = False, seed: int | None = None, generator: torch.Generator | None = None, - **kwargs, ): if generator is not None: self.generator = generator @@ -89,7 +107,15 @@ def __init__( self.generator = torch.Generator() if seed is not None: self.generator.manual_seed(seed) - super().__init__(*args, **kwargs) + + super().__init__( + block_size=block_size, + device=device, + precision=precision, + location=location, + module_type=module_type, + inverse=inverse, + ) def _create_weight( self, diff --git a/auto_round/experimental/transform/patch_modules.py b/auto_round/experimental/transform/patch_modules.py index 934ebea9d..e099a518d 100644 --- a/auto_round/experimental/transform/patch_modules.py +++ b/auto_round/experimental/transform/patch_modules.py @@ -32,67 +32,30 @@ def _qdq_weight_patched(self, value, min_scale, max_scale): # keep original behavior for >=16bit to avoid changing semantics unexpectedly return orig_qdq_weight(self, value, min_scale, max_scale) - min_scale.data.clamp_(0, 1.0) - max_scale.data.clamp_(0, 1.0) + if getattr(self, "applied_weight_hadamard", None) is None: + with torch.no_grad(): + weight = self.orig_layer.weight + if weight.device.type == "meta": + weight = self.orig_layer.get_weight().to(self.device) - weight = self.orig_layer.weight - if weight.device.type == "meta": - weight = self.orig_layer.get_weight().to(self.device) + is_conv1d = type(self.orig_layer) == transformers.pytorch_utils.Conv1D + if is_conv1d: + weight = weight.t().contiguous() + new_weight = w_transform(weight)
+ if is_conv1d: + new_weight = new_weight.t().contiguous() + self.orig_layer.weight.data.copy_(new_weight) + self.applied_weight_hadamard = True - is_conv1d = type(self.orig_layer) == transformers.pytorch_utils.Conv1D - if is_conv1d: - weight = weight.t() + return orig_qdq_weight(self, value, min_scale, max_scale) - weight = weight.to(self.device) - - weight_t = w_transform(weight) - - quant_kwargs = {} - if hasattr(self.orig_layer, "super_bits"): - quant_kwargs["super_bits"] = self.orig_layer.super_bits - quant_kwargs["super_group_size"] = self.orig_layer.super_group_size - - weight_q, scale, zp = self.weight_quant_func( - weight_t, - bits=self.orig_layer.bits, - group_size=self.orig_layer.group_size, - v=value, - min_scale=min_scale, - max_scale=max_scale, - scale_dtype=self.orig_layer.scale_dtype, - tensor_min=self.weight_min, - tensor_max=self.weight_max, - data_type=self.data_type, - q_scale_thresh=self.q_scale_thresh, - imatrix=self.orig_layer.imatrix.to(self.device) if hasattr(self.orig_layer, "imatrix") else None, - global_scale=getattr(self, "weight_global_scale", None), - **quant_kwargs, - ) - - weight_q = weight_q.to(dtype=weight.dtype) - - if is_conv1d: - weight_q = weight_q.t() - - return weight_q, scale, zp + orig_qdq_act = WrapperLinear._qdq_act def _qdq_act_patched(self, x, act_max_scale, act_max=None): - # transform = getattr(self.orig_layer, transform_attr) x = inp_transform(x) - act_max_scale.data.clamp_(0, 1.0) - x, scale, zp = self.act_quant_func( - x, - bits=self.orig_layer.act_bits, - group_size=self.orig_layer.act_group_size, - scale_dtype=self.orig_layer.scale_dtype, - q_scale_thresh=self.q_scale_thresh, - data_type=self.act_data_type, - max_scale=act_max_scale, - tensor_max=act_max, - global_scale=getattr(self, "input_global_scale", None), - ) - return x, scale, zp + + return orig_qdq_act(self, x, act_max_scale, act_max) WrapperLinear._qdq_weight = _qdq_weight_patched WrapperLinear._qdq_act = _qdq_act_patched diff --git
a/auto_round/experimental/transform/triton/mxfp4.py b/auto_round/experimental/transform/triton/mxfp4.py index c26413248..8028c167b 100644 --- a/auto_round/experimental/transform/triton/mxfp4.py +++ b/auto_round/experimental/transform/triton/mxfp4.py @@ -161,6 +161,11 @@ def mxfp4_forward_kernel_wrapper( if hadamard_matrix.device != device: hadamard_matrix = hadamard_matrix.to(device) + dtype = hadamard_matrix.dtype + + if x.dtype != dtype: + x = x.to(dtype) + # Make sure inputs are contiguous x = x.contiguous() hadamard_matrix = hadamard_matrix.contiguous() diff --git a/auto_round/experimental/transform/utils/hadamard.py b/auto_round/experimental/transform/utils/hadamard.py index 5ec6bccbd..5c7ade385 100644 --- a/auto_round/experimental/transform/utils/hadamard.py +++ b/auto_round/experimental/transform/utils/hadamard.py @@ -70,8 +70,8 @@ def random_hadamard_matrix( :param gen: Optional generator random values :return: randomly generated hadamard matrix """ - Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=dtype) # cpu - Q = Q.to(device=device) + Q = torch.randint(low=0, high=2, size=(size,), generator=gen) # cpu + Q = Q.to(device=device, dtype=dtype) Q = Q * 2 - 1 Q = torch.diag(Q) return _matmul_hadU(Q) diff --git a/auto_round/experimental/utils.py b/auto_round/experimental/utils.py index 39a7ff135..b32132c7a 100644 --- a/auto_round/experimental/utils.py +++ b/auto_round/experimental/utils.py @@ -16,11 +16,12 @@ import torch +from auto_round.compressors.utils import is_nv_fp from auto_round.experimental.transform.hadamard_config import HadamardConfig from auto_round.experimental.transform.hadamards import HADAMARDS from auto_round.utils import logger -SUPPORTED_QUANTIZATION_SCHEMES = ["MXFP4"] +SUPPORTED_QUANTIZATION_SCHEMES = ["MXFP4", "NVFP4"] def per_tensor_fp8_qdq( @@ -114,10 +115,12 @@ def clean_model_parameters_and_buffers_(model: torch.nn.Module, name_tuple: tupl _clean_param_or_buff_if_exists(module, name_tuple) -def 
is_triton_kernel_available() -> bool: +def is_triton_kernel_available(data_type: str) -> bool: """ Best-effort check for whether Triton kernel path can be used. """ + if is_nv_fp(data_type): + return False try: import triton # pylint: disable=E0401 except Exception: @@ -134,60 +137,108 @@ def is_triton_kernel_available() -> bool: return True -def normalize_hadamard_config(hadamard_config: str | dict | HadamardConfig | None) -> dict[str, Any]: +def normalize_hadamard_config(hadamard_config: str | dict | HadamardConfig | None, scheme: str) -> dict[str, Any]: """ Normalize and validate `hadamard_config`. Supported input types: - - None -> {} - - dict -> validated via HadamardConfig + - None -> {} + - dict -> validated via HadamardConfig - HadamardConfig -> validated & converted to dict - - str -> shorthand for `transform_type` in TRANSFORMS keys - - On any validation failure, raises ValueError/TypeError. + - str -> shorthand for `hadamard_type` in HADAMARDS keys + + Additional behavior: + - If block_size is not set by user: + - MXFP4 -> default block_size to 32 + - NVFP4 -> default block_size to 16 + - other schemes -> emit a warning + - If block_size is set but does not match the recommended value: + - MXFP4 expects 32 + - NVFP4 expects 16 + - emit a warning """ + + def _normalize_scheme(s: str) -> str: + return s.strip().upper() + + def _apply_scheme_block_size(cfg_dict: dict[str, Any], block_size_explicitly_set: bool) -> dict[str, Any]: + normalized_scheme = _normalize_scheme(scheme) + block_size = cfg_dict.get("block_size") + + if not block_size_explicitly_set or block_size is None: + if normalized_scheme == "MXFP4": + cfg_dict["block_size"] = 32 + logger.warning("block_size is not set for scheme 'MXFP4'; defaulting to 32.") + elif normalized_scheme == "NVFP4": + cfg_dict["block_size"] = 16 + logger.warning("block_size is not set for scheme 'NVFP4'; defaulting to 16.") + else: + logger.warning( + f"block_size is not set and cannot be inferred for scheme 
{scheme!r}; " + "please set block_size explicitly in hadamard_config if needed." + ) + else: + if normalized_scheme == "MXFP4" and block_size != 32: + logger.warning(f"scheme is 'MXFP4' but block_size={block_size}; recommended value is 32.") + elif normalized_scheme == "NVFP4" and block_size != 16: + logger.warning(f"scheme is 'NVFP4' but block_size={block_size}; recommended value is 16.") + + return cfg_dict + # 1) None -> {} if hadamard_config is None: return {} - # 2) Already a HadamardConfig instance + # 2) HadamardConfig instance if isinstance(hadamard_config, HadamardConfig): - # Ensure it passes its own validation and convert to dict - cfg = HadamardConfig.model_validate(hadamard_config).model_dump() - return cfg + raw_cfg_dict = hadamard_config.model_dump(exclude_unset=True) + block_size_explicitly_set = "block_size" in raw_cfg_dict - # 3) dict -> validate via HadamardConfig + cfg_dict = dict(raw_cfg_dict) + cfg_dict = _apply_scheme_block_size(cfg_dict, block_size_explicitly_set) + + try: + return HadamardConfig.model_validate(cfg_dict).model_dump() + except Exception as e: + raise ValueError(f"Invalid HadamardConfig: {e}") from e + + # 3) dict if isinstance(hadamard_config, dict): + block_size_explicitly_set = "block_size" in hadamard_config + + cfg_dict = dict(hadamard_config) + cfg_dict = _apply_scheme_block_size(cfg_dict, block_size_explicitly_set) + try: - cfg = HadamardConfig.model_validate(hadamard_config).model_dump() + return HadamardConfig.model_validate(cfg_dict).model_dump() except Exception as e: raise ValueError(f"Invalid hadamard_config dict: {e}") from e - return cfg - # 4) str -> shorthand for transform_type + # 4) str -> shorthand for hadamard_type if isinstance(hadamard_config, str): key = hadamard_config.strip() if not key: return {} if key == "default": - cfg = HadamardConfig() - return cfg.model_dump() + cfg_dict = {} + cfg_dict = _apply_scheme_block_size(cfg_dict, block_size_explicitly_set=False) + try: + return 
HadamardConfig.model_validate(cfg_dict).model_dump() + except Exception as e: + raise ValueError(f"Invalid default hadamard_config after scheme adjustment: {e}") from e if key not in HADAMARDS: - raise ValueError( - f"Invalid hadamard_config string: {key!r}. " f"Expected one of {sorted(HADAMARDS.keys())}." - ) + raise ValueError(f"Invalid hadamard_config string: {key!r}. Expected one of {sorted(HADAMARDS.keys())}.") cfg_dict = {"hadamard_type": key} + cfg_dict = _apply_scheme_block_size(cfg_dict, block_size_explicitly_set=False) try: - cfg = HadamardConfig.model_validate(cfg_dict).model_dump() + return HadamardConfig.model_validate(cfg_dict).model_dump() except Exception as e: raise ValueError(f"hadamard_config built from string {key!r} is invalid for HadamardConfig: {e}") from e - return cfg - raise TypeError( "hadamard_config must be one of: None, dict, HadamardConfig, or str " f"(got {type(hadamard_config).__name__})" ) diff --git a/auto_round/inference/backend.py b/auto_round/inference/backend.py index d98545679..609091ce3 100644 --- a/auto_round/inference/backend.py +++ b/auto_round/inference/backend.py @@ -781,6 +781,10 @@ def dynamic_import_inference_linear(backend, config): return ar_qmodules.HadamardMXFP4QuantLinear return ar_qmodules.MXFP4QuantLinear if "torch_nvfp4" in backend: + hadamard_config = getattr(config, "hadamard_config", None) + if hadamard_config is not None and hadamard_config: + if hadamard_config["hadamard_type"] == "random_hadamard": + return ar_qmodules.HadamardNVFP4QuantLinear return ar_qmodules.NVFP4QuantLinear if "auto_round_kernel" in backend or "ark" in backend: diff --git a/auto_round/inference/convert_model.py b/auto_round/inference/convert_model.py index a5b9096b3..2e2002af0 100644 --- a/auto_round/inference/convert_model.py +++ b/auto_round/inference/convert_model.py @@ -687,7 +687,11 @@ def convert_hf_model(model: nn.Module, target_device: str = "cpu") -> tuple[nn.M hadamard_type=hadamard_config["hadamard_type"], ) # apply to 
activation model = apply_hadamard_transform( - model, act_hadamard_config, location="input", desc="Register pre forward hook for hadamard transform" + model, + act_hadamard_config, + location="input", + desc="Register pre forward hook for hadamard transform", + data_type=quantization_config.data_type, ) # Suggest a better backend if available