Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 28 additions & 29 deletions auto_round/experimental/transform/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ def _apply_to_module(
location="input",
inverse=True,
device="cpu",
precision=module.dtype,
)

if config.hadamard_type != "random_hadamard":
Expand All @@ -115,14 +114,15 @@ def input_hook(self, args):
input = args[0]
# transform(input)
orig_shape = input.shape
orig_dtype = input.dtype
x_flat = input.contiguous().flatten(end_dim=-2)
qdq_input, _ = mxfp4_forward_kernel_wrapper(
x_flat,
(
hadamard_weight if hadamard_weight is not None else self.hadamard_matrix.T
), # this matrix from w_transform, needs transpose
)
return qdq_input.reshape(orig_shape)
return qdq_input.reshape(orig_shape).to(orig_dtype)

# for fused transform + quantization kernel
module.pre_dequantized_input = True
Expand All @@ -135,13 +135,22 @@ def input_hook(self, args):
input = args[0]

ori_shape = input.shape
orig_dtype = input.dtype

if hadamard_weight is not None:
input = input.view(-1, hadamard_weight.shape[0])
return _multihead_matmul(input, hadamard_weight.to(input.device)).view(ori_shape)
return (
(_multihead_matmul(input.to(hadamard_weight.dtype), hadamard_weight.to(input.device)))
.view(ori_shape)
.to(orig_dtype)
)
else:
input = input.view(-1, self.hadamard_matrix.shape[0])
return _multihead_matmul(input, self.hadamard_matrix.T).view(ori_shape)
return (
(_multihead_matmul(input.to(self.hadamard_matrix.dtype), self.hadamard_matrix.T))
.view(ori_shape)
.to(orig_dtype)
)

# for fused transform + quantization kernel
module.pre_dequantized_input = False
Expand All @@ -156,7 +165,6 @@ def input_hook(self, args):
**config.dict(),
location="weight",
device=module.weight.device,
precision=module.weight.dtype,
)

# need save random hadamard matrix needed when inference
Expand All @@ -167,31 +175,22 @@ def input_hook(self, args):

patch_quantlinear(config.hadamard_type)

if need_calibration:
# for training, the weight changes with every forward pass
# for autoround tuning: patch wrapper linear qdq_weight func
from auto_round.experimental.transform.patch_modules import (
patch_wrapperlinear_to_apply_transform,
patch_wrapperwalayer_forward_to_apply_transform,
)

input_hadamard_transform = build_hadamard_transform(
**config.dict(),
location="input",
inverse=True,
device=module.weight.device,
precision=module.weight.dtype,
)

patch_wrapperlinear_to_apply_transform(weight_hadamard_transform, input_hadamard_transform)
patch_wrapperwalayer_forward_to_apply_transform(input_hadamard_transform)
# for autoround tuning: weight not tuning
# for rtn: weight transformed before saving
from auto_round.experimental.transform.patch_modules import (
patch_wrapperlinear_to_apply_transform,
patch_wrapperwalayer_forward_to_apply_transform,
)

else:
# transform is no longer needed (unfusing is not supported)
# delattr(module, transform_name)
# fuse transform into weight
with torch.no_grad():
getattr(module, "weight").copy_(weight_hadamard_transform(module.weight).to(module.weight.device))
input_hadamard_transform = build_hadamard_transform(
**config.dict(),
location="input",
inverse=True,
device=module.weight.device,
)

patch_wrapperlinear_to_apply_transform(weight_hadamard_transform, input_hadamard_transform)
patch_wrapperwalayer_forward_to_apply_transform(input_hadamard_transform)

else:
# TODO: apply transform to output/q/k
Expand Down
16 changes: 15 additions & 1 deletion auto_round/experimental/transform/hadamards.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,25 @@ def __init__(
self,
block_size: int = 32,
device: torch.device = None,
precision: torch.dtype = None,
precision: torch.dtype = torch.float32,
location: str = "weight",
module_type: type[torch.nn.Module] = torch.nn.Linear,
inverse: bool = False,
):
"""Initialize a Hadamard transform module.

Args:
block_size: Size of each Hadamard block. The input tensor is reshaped
to ``(-1, block_size)`` before applying the transform.
device: Device on which to create the Hadamard matrix.
            precision: Data type used for the Hadamard matrix weights; defaults to ``torch.float32``.
location: Target location used by ``apply_transform_weight`` when
applying the transform.
module_type: Module type associated with the transform application,
typically ``torch.nn.Linear``.
inverse: Whether to build the inverse form of the transform.
"""

super().__init__()
self.size = block_size
self.scale = 1 / math.sqrt(self.size)
Expand Down
71 changes: 17 additions & 54 deletions auto_round/experimental/transform/patch_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,67 +32,30 @@ def _qdq_weight_patched(self, value, min_scale, max_scale):
# keep original behavior for >=16bit to avoid changing semantics unexpectedly
return orig_qdq_weight(self, value, min_scale, max_scale)

min_scale.data.clamp_(0, 1.0)
max_scale.data.clamp_(0, 1.0)
if getattr(self, "applied_weight_hadamard", None) is None:
with torch.no_grad():
weight = self.orig_layer.weight
if weight.device.type == "meta":
weight = self.orig_layer.get_weight().to(self.device)

weight = self.orig_layer.weight
if weight.device.type == "meta":
weight = self.orig_layer.get_weight().to(self.device)
is_conv1d = type(self.orig_layer) == transformers.pytorch_utils.Conv1D
if is_conv1d:
weight = weight.t().continuous()
new_weight = w_transform(weight)
if is_conv1d:
new_weight = weight.t().continuous()
self.orig_layer.weight.data.copy_(new_weight)
self.applied_weight_hadamard = True

is_conv1d = type(self.orig_layer) == transformers.pytorch_utils.Conv1D
if is_conv1d:
weight = weight.t()
return orig_qdq_weight(self, value, min_scale, max_scale)

weight = weight.to(self.device)

weight_t = w_transform(weight)

quant_kwargs = {}
if hasattr(self.orig_layer, "super_bits"):
quant_kwargs["super_bits"] = self.orig_layer.super_bits
quant_kwargs["super_group_size"] = self.orig_layer.super_group_size

weight_q, scale, zp = self.weight_quant_func(
weight_t,
bits=self.orig_layer.bits,
group_size=self.orig_layer.group_size,
v=value,
min_scale=min_scale,
max_scale=max_scale,
scale_dtype=self.orig_layer.scale_dtype,
tensor_min=self.weight_min,
tensor_max=self.weight_max,
data_type=self.data_type,
q_scale_thresh=self.q_scale_thresh,
imatrix=self.orig_layer.imatrix.to(self.device) if hasattr(self.orig_layer, "imatrix") else None,
global_scale=getattr(self, "weight_global_scale", None),
**quant_kwargs,
)

weight_q = weight_q.to(dtype=weight.dtype)

if is_conv1d:
weight_q = weight_q.t()

return weight_q, scale, zp
orig_qdq_act = WrapperLinear._qdq_act

def _qdq_act_patched(self, x, act_max_scale, act_max=None):

# transform = getattr(self.orig_layer, transform_attr)
x = inp_transform(x)
act_max_scale.data.clamp_(0, 1.0)
x, scale, zp = self.act_quant_func(
x,
bits=self.orig_layer.act_bits,
group_size=self.orig_layer.act_group_size,
scale_dtype=self.orig_layer.scale_dtype,
q_scale_thresh=self.q_scale_thresh,
data_type=self.act_data_type,
max_scale=act_max_scale,
tensor_max=act_max,
global_scale=getattr(self, "input_global_scale", None),
)
return x, scale, zp

return orig_qdq_act(self, x, act_max_scale, act_max)

WrapperLinear._qdq_weight = _qdq_weight_patched
WrapperLinear._qdq_act = _qdq_act_patched
Expand Down
5 changes: 5 additions & 0 deletions auto_round/experimental/transform/triton/mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ def mxfp4_forward_kernel_wrapper(
if hadamard_matrix.device != device:
hadamard_matrix = hadamard_matrix.to(device)

dtype = hadamard_matrix.dtype

if x.dtype != dtype:
x = x.to(dtype)

# Make sure inputs are contiguous
x = x.contiguous()
hadamard_matrix = hadamard_matrix.contiguous()
Expand Down
4 changes: 2 additions & 2 deletions auto_round/experimental/transform/utils/hadamard.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def random_hadamard_matrix(
:param gen: Optional generator random values
:return: randomly generated hadamard matrix
"""
Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=dtype) # cpu
Q = Q.to(device=device)
Q = torch.randint(low=0, high=2, size=(size,), generator=gen) # cpu
Q = Q.to(device=device, dtype=dtype)
Q = Q * 2 - 1
Q = torch.diag(Q)
return _matmul_hadU(Q)
Expand Down