fix hadamard transform weight dtype, using float32 as default and in-place transformed weight. #1665
lkk12014402 wants to merge 5 commits into intel:main
Conversation
Pull request overview
This PR changes the default Hadamard transform weight dtype to torch.float64 and adjusts the activation/weight transform application paths to align input/output dtypes with the Hadamard matrix, restoring original input dtypes after the transform.
Changes:
- Default the Hadamard transform `precision` to `torch.float64` and document the behavior.
- Remove the explicit `precision=module.dtype` / `precision=module.weight.dtype` when building transforms, relying on the new default.
- Add dtype-casting logic in both the Triton (`mxfp4_forward_kernel_wrapper`) and non-Triton hook paths, and cast outputs back to the original input dtype.
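The casting behavior described above follows a common pattern: cast the activation to the Hadamard matrix dtype, apply the transform, then restore the original dtype. A minimal sketch (the function name `apply_hadamard` is hypothetical, not the PR's actual API):

```python
import torch


def apply_hadamard(x: torch.Tensor, hadamard: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of the dtype-casting path: cast x to the
    Hadamard matrix dtype, multiply, then cast back to x's dtype."""
    orig_dtype = x.dtype
    out = x.to(hadamard.dtype) @ hadamard
    return out.to(orig_dtype)
```

This keeps the matmul in the (higher) precision of the Hadamard matrix while callers still see tensors in their original dtype.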
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `auto_round/experimental/transform/triton/mxfp4.py` | Casts `x` to the Hadamard matrix dtype before launching the Triton kernel. |
| `auto_round/experimental/transform/hadamards.py` | Sets the default Hadamard weight precision to float64 and adds an expanded docstring. |
| `auto_round/experimental/transform/apply.py` | Stops passing the module dtype into transform construction; restores the original activation dtype after applying transforms. |
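For context, a Hadamard transform weight of the kind `hadamards.py` builds can be constructed with Sylvester's recursion; this sketch (the helper name and signature are hypothetical, not the file's actual API) shows a configurable `precision` defaulting to float32, matching the dtype the PR title settles on:

```python
import math

import torch


def hadamard_matrix(n: int, precision: torch.dtype = torch.float32) -> torch.Tensor:
    """Hypothetical helper: build a normalized n x n Hadamard matrix
    (n must be a power of two) via Sylvester's construction."""
    assert n > 0 and (n & (n - 1)) == 0, "n must be a power of two"
    H = torch.ones(1, 1, dtype=precision)
    while H.shape[0] < n:
        # Double the size: [[H, H], [H, -H]]
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H / math.sqrt(n)  # normalize so the matrix is orthogonal
```

Because the normalized matrix is orthogonal, applying it twice recovers the input, which makes it cheap to sanity-check the chosen precision.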
Signed-off-by: lkk12014402 <kaokao.lv@intel.com>
for more information, see https://pre-commit.ci
I was mistaken. It makes sense to use higher precision for offline transformations; however, for online transformations, using torch.float64 would likely be significantly more costly.
There is no obvious performance degradation. Using float64 is about 10% slower than bfloat16, and after replacing the dtype with float32, it is about 1–2% slower than bfloat16. (I measured this by running the piqa task with lm_eval, using the hf backend.)
Description
Use a float64 Hadamard transform weight as the default.