diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index 7907c79bd6..fe066293e9 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -764,7 +764,7 @@ def run_search(self): def _get_auto_quantize_score(grad_output, output_diff): x = grad_output.float() * output_diff.float() - return x.to(torch.float64).square().sum() + return x.clamp(-1e10, 1e10).square().sum() def _add_auto_quantize_score(grad_output, output_diff, score_tensor):