
fix hadamard transform weight dtype, using float32 as default and in-place transformed weight #1665

Open

lkk12014402 wants to merge 5 commits into intel:main from lkk12014402:fix_hadamard_transform_dtype

Conversation

@lkk12014402
Contributor

Description

Use torch.float64 as the default dtype for the Hadamard transform weight.


Copilot AI left a comment


Pull request overview

This PR changes the default Hadamard transform weight dtype to torch.float64 and adjusts the activation/weight transform application paths to align input/output dtypes with the Hadamard matrix, restoring original input dtypes after the transform.

Changes:

  • Default Hadamard transform precision to torch.float64 and document the behavior.
  • Remove explicit precision=module.dtype / precision=module.weight.dtype when building transforms, relying on the new default.
  • Add dtype-casting logic in both the Triton (mxfp4_forward_kernel_wrapper) and non-Triton hook paths, and cast outputs back to the original input dtype.
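The activation-side casting described in the last bullet can be sketched as follows. This is a minimal illustration, not the PR's actual code: `hadamard_matrix` and `apply_transform` are hypothetical names, and the normalized Sylvester construction stands in for the library's Hadamard weights.

```python
import torch

def hadamard_matrix(n: int, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Sylvester construction of a normalized Hadamard matrix; n must be a power of two.
    H = torch.ones(1, 1, dtype=dtype)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1),
                       torch.cat([H, -H], dim=1)], dim=0)
    return H / n ** 0.5  # normalized, so H @ H.T == I

def apply_transform(x: torch.Tensor, H: torch.Tensor) -> torch.Tensor:
    # Cast the activation to the Hadamard matrix dtype, transform, then
    # restore the original dtype, mirroring the hook-path change above.
    orig_dtype = x.dtype
    return (x.to(H.dtype) @ H).to(orig_dtype)

x = torch.randn(4, 8, dtype=torch.bfloat16)
H = hadamard_matrix(8)            # float32 by default, per this PR
y = apply_transform(x, H)
assert y.dtype == torch.bfloat16  # caller-visible dtype is unchanged
```

The key point is that the compute dtype is taken from the Hadamard matrix, while the input and output dtypes stay whatever the surrounding model uses.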

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

Changed files:

  • auto_round/experimental/transform/triton/mxfp4.py: casts x to the Hadamard matrix dtype before launching the Triton kernel.
  • auto_round/experimental/transform/hadamards.py: sets the default Hadamard weight precision to float64 and adds an expanded docstring.
  • auto_round/experimental/transform/apply.py: stops passing the module dtype into transform construction; restores the original activation dtype after applying transforms.
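The in-place weight transform mentioned in the PR title could look roughly like this. A hedged sketch only: `transform_weight_inplace` is a hypothetical helper, not code from apply.py.

```python
import torch

@torch.no_grad()
def transform_weight_inplace(weight: torch.Tensor, hadamard: torch.Tensor) -> None:
    # Compute the transform in the Hadamard matrix's precision (float32 by
    # default, per this PR), then write the result back into the existing
    # weight storage so no second full-size parameter tensor is kept alive.
    weight.copy_((weight.to(hadamard.dtype) @ hadamard).to(weight.dtype))

w = torch.randn(4, 2, dtype=torch.bfloat16)
H = torch.tensor([[1.0, 1.0], [1.0, -1.0]]) / 2 ** 0.5  # float32 Hadamard
ptr = w.data_ptr()
transform_weight_inplace(w, H)
assert w.data_ptr() == ptr        # same storage: transformed in place
assert w.dtype == torch.bfloat16  # parameter dtype is preserved
```

Writing back with `copy_` keeps the parameter object and its storage intact, which matters for offline (weight) transforms where a temporary float32 copy would otherwise double memory for large layers.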

@wenhuach21
Contributor

I was mistaken. It makes sense to use higher precision for offline transformations; however, for online transformations, using torch.float64 would be significantly more costly I guess.

@lkk12014402
Contributor Author

lkk12014402 commented Apr 7, 2026

I was mistaken. It makes sense to use higher precision for offline transformations; however, for online transformations, using torch.float64 would be significantly more costly I guess.

There is no obvious performance degradation: float64 is about 10% slower than bfloat16, and after switching the dtype to float32 it is only about 1-2% slower than bfloat16. (Measured on the piqa task with lm_eval, using the hf backend.)
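The precision side of that trade-off can be checked with a quick round-trip sketch (hypothetical, not the lm_eval benchmark above): since a normalized Hadamard matrix is orthogonal, applying H and then H.T should recover the input up to the rounding error of the compute dtype, which makes the gap between bfloat16, float32, and float64 directly visible.

```python
import torch

def hadamard(n: int, dtype: torch.dtype) -> torch.Tensor:
    # Sylvester construction of a normalized Hadamard matrix (n a power of two).
    H = torch.ones(1, 1, dtype=dtype)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], 1), torch.cat([H, -H], 1)], 0)
    return H / n ** 0.5

def roundtrip_error(dtype: torch.dtype, n: int = 256) -> float:
    # H is orthogonal, so x @ H @ H.T should return x exactly; the residual
    # measures the rounding error introduced by the chosen compute dtype.
    torch.manual_seed(0)
    x = torch.randn(16, n, dtype=torch.float64)
    H = hadamard(n, dtype)
    y = ((x.to(dtype) @ H) @ H.T).to(torch.float64)
    return (y - x).abs().max().item()

assert roundtrip_error(torch.float64) < roundtrip_error(torch.float32)
assert roundtrip_error(torch.float32) < roundtrip_error(torch.bfloat16)
```

In such a round trip, float32 typically sits orders of magnitude below bfloat16 in error while staying close to bfloat16 in cost, which is consistent with the float32 default this PR settles on.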

Signed-off-by: lkk12014402 <kaokao.lv@intel.com>
@lkk12014402 lkk12014402 changed the title from "fix hadamard transform weight dtype, using float64 as default" to "fix hadamard transform weight dtype, using float32 as default and in-place transformed weight" on Apr 8, 2026
lkk12014402 and others added 2 commits April 8, 2026 08:56
Signed-off-by: lkk12014402 <kaokao.lv@intel.com>
