Fix ECC correction range: ±0.5 ULP → ±1 ULP #88
josejg wants to merge 1 commit into HomebrewML:main
Conversation
ECC with RNE produces errors within ±0.5 ULP, but with stochastic rounding the error range increases to ±1 ULP. Since heavyball defaults to SR, the `- 1` in `_log_ulp(x) - 1` limited the int8 correction to ±0.5 ULP. This caused frequent clamping of the ECC correction terms, introducing a per-step bias that accumulates through the EMA feedback loop. Removing the `- 1` from both `decode` and `compute_correction` solves it: the correction range doubles to ±1 ULP and clamping drops to 0%.
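The clamping effect can be illustrated with a toy model (not heavyball's actual encoder; the int8 scales and the uniform-error assumption are illustrative only):

```python
import random

ULP = 2.0 ** -7  # bf16 spacing for values in [1, 2)

def clamp_rate(errors, scale):
    # Fraction of quantized int8 codes that saturate under the given scale.
    codes = [round(e / scale) for e in errors]
    return sum(c > 127 or c < -128 for c in codes) / len(codes)

rng = random.Random(0)
# Under stochastic rounding, quantization errors spread across (-1 ULP, +1 ULP)
# (assumed uniform here purely for illustration).
errors = [rng.uniform(-ULP, ULP) for _ in range(10_000)]

old_rate = clamp_rate(errors, ULP / 256)  # int8 range covers only ~±0.5 ULP
new_rate = clamp_rate(errors, ULP / 128)  # int8 range covers ~±1 ULP
```

Under this model roughly half of all SR errors saturate the old ±0.5 ULP range, while the doubled range clamps almost nothing.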
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: f5b526a141
```diff
     def decode(self, x):
-        ls = (_log_ulp(x) - 1).float()
+        ls = _log_ulp(x).float()
```
Preserve legacy ECC scale when decoding checkpoints
Updating decode to use _log_ulp(x) makes previously saved ECC corrections decode at twice their original magnitude, because older checkpoints encoded correction in ULP/2 units. The load path (ChainOpt._restore_ecc_dtypes) only restores dtypes and does not rescale or recompute existing ::ecc tensors, so resuming training from checkpoints created before this change will restore incorrect optimizer/parameter state values when ECC is enabled.
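If backward compatibility with old checkpoints is wanted, a one-time migration could compensate for the scale change: legacy codes were encoded in ULP/2 units, and the new decode multiplies by a scale twice as large, so halving each stored code preserves the original correction magnitude at the cost of one LSB. A minimal sketch, assuming ECC tensors are int8 codes stored under keys with a `::ecc` suffix (the function name, flag, and dict layout are hypothetical, not heavyball's actual load path):

```python
def migrate_legacy_ecc(state_dict: dict, legacy: bool) -> dict:
    # Hypothetical migration: halve legacy int8 codes (truncating toward zero)
    # so they decode to the same correction under the doubled scale.
    if not legacy:
        return state_dict
    out = {}
    for key, codes in state_dict.items():
        if key.endswith("::ecc"):
            codes = [int(c / 2) for c in codes]  # stays within int8 range
        out[key] = codes
    return out
```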
```diff
         narrow_f32 = self._bf16_to_f32(narrow) if narrow.dtype == torch.bfloat16 else narrow.float()
         e = fp32 - narrow_f32
-        ls = (_log_ulp(narrow) - 1).float()
+        ls = _log_ulp(narrow).float()
```
Keep ECC normalization at half ULP for deterministic encode
_ULPState.encode still uses deterministic fp32.to(target.dtype) rounding (not stochastic_round_), so quantization error remains bounded by ±0.5 ULP. Normalizing with _log_ulp(narrow) here doubles the correction quantization step and reduces ECC reconstruction precision for every newly written state (including fp16 ECC modes, which never use stochastic rounding), introducing avoidable accuracy loss.
Ah, @ClashLuke I wrote the PR before the recent RNE changes, so I think this is not fully correct after rebasing. Both strategies work; the only difference is whether you prefer the main buffers to be biased or not. My intuition is that RNE + ULP/2 has one more bit of effective precision.
Thank you for the fix and the detailed analysis on the ULP bound.
With round-to-nearest the quantization error is bounded by ±0.5 ULP, hence the `_log_ulp - 1`. However, when doing stochastic rounding the error can be larger: e.g. if x is representable in FP32 but x + eps is not representable in BF16, then x + eps will be rounded up with some low probability, and the error between the rounded value and x + eps will be close to 1 ULP instead of 0.5 ULP.
So, to handle stochastic rounding correctly (the default in HeavyBall), the ECC needs to be computed using ±1 ULP.
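The two bounds can be checked with a small pure-Python emulation of fp32 → bf16 rounding (a sketch; heavyball's actual `stochastic_round_` implementation may differ in details):

```python
import random
import struct

def f32_bits(x: float) -> int:
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bits_f32(b: int) -> float:
    return struct.unpack("<f", struct.pack("<I", b & 0xFFFFFFFF))[0]

def bf16_rne(x: float) -> float:
    # Round-to-nearest-even: add 0x7FFF plus the tie-breaking bit, truncate low 16 bits.
    b = f32_bits(x)
    b += 0x7FFF + ((b >> 16) & 1)
    return bits_f32(b & 0xFFFF0000)

def bf16_sr(x: float, rng: random.Random) -> float:
    # Stochastic rounding: add uniform 16-bit noise, then truncate low 16 bits.
    b = f32_bits(x) + rng.randrange(1 << 16)
    return bits_f32(b & 0xFFFF0000)

rng = random.Random(0)
x = 1.0 + 2.0 ** -16  # representable in fp32, sits just above the bf16 value 1.0
ulp = 2.0 ** -7       # bf16 spacing in [1, 2)

rne_err = abs(bf16_rne(x) - x)
sr_err = max(abs(bf16_sr(x, rng) - x) for _ in range(10_000))
```

RNE always maps this x to 1.0 (error well under 0.5 ULP), while SR occasionally rounds it up to the next bf16 value, producing an error of nearly a full ULP.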
Tested with the `precision_toy.py` script from the blog post.