Fix FusedObsFakeQuant dtype handling for bf16/fp16 and bool inputs#3152
kdrozd-dev wants to merge 8 commits into main
Conversation
Pull request overview
Fixes XPU parity issues in fused_moving_avg_obs_fake_quant by aligning dtype handling with CUDA, especially for bool observer_on/fake_quant_on and bf16/fp16 running min/max buffers.
Changes:
- Extend XPU hypothesis test coverage to include `use_bool` and `sampled_dtype` for fused moving-avg fake-quant tests.
- Convert `observer_on`/`fake_quant_on` to `Long` before passing into XPU helpers and the per-tensor fake-quant op.
- Template SYCL moving-average + qparams kernels on `scalar_t` and dispatch over floating types + (bf16, fp16) to support non-fp32 running min/max.
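A minimal numpy sketch of why the `Long` conversion matters (numpy arrays stand in for the ATen tensor buffers; variable names are illustrative): a Bool tensor stores one byte per element, so a kernel reading through an int64 pointer strides the wrong byte width unless the tensor is cast first.

```python
import numpy as np

# Illustrative only: numpy arrays stand in for ATen tensor storage.
observer_on = np.array([True, False, True, True], dtype=np.bool_)
assert observer_on.itemsize == 1  # 1 byte/element; an int64 pointer would stride 8

# The fix mirrors CUDA: cast to Long up front so an int64-typed pointer is valid.
observer_on_long = observer_on.astype(np.int64)
assert observer_on_long.itemsize == 8
assert observer_on_long.tolist() == [1, 0, 1, 1]
```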
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| test/xpu/quantization/core/test_workflow_ops_xpu.py | Expands hypothesis args to exercise bool flags and bf16/fp16/fp32 dtype scenarios on XPU. |
| src/ATen/native/quantized/FusedObsFakeQuant.cpp | Casts control tensors to Long before calling SYCL helpers and per-tensor fake-quant op to match CUDA behavior. |
| src/ATen/native/quantized/sycl/FusedObsFakeQuantKernels.cpp | Templates moving-average and qparams kernels and dispatches on scalar_t to support bf16/fp16 running min/max. |
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
```cpp
scalar_t curr_min = x_min[i];
scalar_t curr_max = x_max[i];
float adjusted_min = std::isinf(running_min[i])
scalar_t averaging_const_t = static_cast<scalar_t>(averaging_const);
```
MovingAverageMinMax now performs the moving-average update in scalar_t (including averaging_const cast to scalar_t). For Half/BFloat16 this can significantly reduce precision vs the prior float math, and can even make averaging_const_t underflow to 0 for small constants, changing the algorithm’s behavior. Consider doing the computation in at::opmath_type<scalar_t> (or float) and only casting back to scalar_t when writing running_min/max.
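A small numpy sketch of the underflow case this comment describes (numpy's fp16 stands in for Half; the constant is chosen to sit below fp16's smallest subnormal, roughly 6e-8):

```python
import numpy as np

# An averaging constant that survives in fp32 flushes to zero once cast
# to fp16, because 1e-8 is below fp16's smallest subnormal (~5.96e-8).
c_fp32 = np.float32(1e-8)
c_fp16 = np.float16(1e-8)
assert c_fp32 > 0
assert c_fp16 == 0.0

# With the constant underflowed, the moving-average update is a no-op:
running, x = np.float16(0.0), np.float16(10.0)
step = running + c_fp16 * (x - running)  # update term is exactly 0
assert step == running
```

Doing the multiply-accumulate in `at::opmath_type<scalar_t>` (float for Half/BFloat16) and casting only the final result back avoids this failure mode.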
```cpp
AT_DISPATCH_FLOATING_TYPES_AND2(
    at::kBFloat16, at::kHalf, x.scalar_type(), "MovingAverageMinMax", [&] {
      scalar_t* x_min_data = x_min.data_ptr<scalar_t>();
      scalar_t* x_max_data = x_max.data_ptr<scalar_t>();
      scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
```
_calculate_moving_average dispatches scalar_t based on x.scalar_type(), but then unconditionally does running_min.data_ptr<scalar_t>()/running_max.data_ptr<scalar_t>(). Since this function doesn’t enforce that running_min/running_max have the same dtype as x, this can misinterpret the backing storage and crash or produce wrong results when the stats tensors are float (common for observers) while x is bf16/fp16. Consider dispatching on running_min.scalar_type() (and checking running_max matches), or adding a TORCH_CHECK that running_min/running_max dtype equals x (and documenting that contract), or explicitly casting the reduction outputs/stats tensors to a consistent dtype before taking data_ptr.
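A numpy sketch of the storage misinterpretation the comment warns about (numpy's `view` plays the role of taking `data_ptr<scalar_t>()` against a buffer of a different dtype; the exact garbage values depend on endianness, so only the mismatch is asserted):

```python
import numpy as np

# fp32 observer stats read through an fp16-typed pointer: the raw bytes
# are reinterpreted, so both the element count and the values are wrong.
running_min = np.array([0.5, -1.25], dtype=np.float32)

misread = running_min.view(np.float16)  # 2 fp32 values -> 4 fp16 "values"
assert misread.shape == (4,)
assert not np.array_equal(misread[:2].astype(np.float32), running_min)
```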
Performance outliers, please check!
Aligns with upstream change: https://github.com/pytorch/pytorch/pull/162620/changes
Fixes #1819
Motivation
After adding the missing `use_bool` and `sampled_dtype` arguments to the XPU test, two `fused_moving_avg_obs_fake_quant` test cases fail with dtype mismatches:

- `test_fused_obs_fake_quant_moving_avg_per_channel_xpu` — expected scalar type Long but found Bool
- `test_fused_obs_fake_quant_moving_avg_xpu` — expected scalar type Float but found BFloat16

Root cause: the XPU kernel differs from CUDA in two ways:
- `observer_on` and `fake_quant_on` are passed to SYCL kernels without `.to(at::kLong)`, so bool-typed tensors crash on `data_ptr<int64_t>()`.
- `running_min`/`running_max` are hardcoded to `float*` in the SYCL kernel, so bf16/fp16 tensors crash on `data_ptr<float>()`.

Changes
- `FusedObsFakeQuant.cpp`: Add `.to(at::kLong)` for `observer_on` and `fake_quant_on` before passing to SYCL kernels and `_fake_quantize_per_tensor_affine_cachemask_tensor_qparams`, matching the CUDA kernel.
- `FusedObsFakeQuantKernels.cpp`: Template `MovingAverageMinMax`, `CalculateMovingAverageKernelFunctor`, `ChooseQuantizationParamsKernelImpl`, and `CalcMovingAvgQparamsHelperKernelFunctor` on `scalar_t`. Dispatch via `AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, ...)` to support bf16/fp16 running min/max, matching CUDA's type coverage.
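Conceptually, `AT_DISPATCH_FLOATING_TYPES_AND2` selects a kernel instantiation from the runtime dtype. A rough Python analogy (a set plays the macro's switch over scalar types; numpy dtypes stand in for `scalar_t`, and bf16 is omitted because numpy has no equivalent):

```python
import numpy as np

# Rough analogy for AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, ...):
# pick the "templated" kernel body based on the runtime dtype.
def moving_avg_kernel(running, x, c):
    # "templated" body: all arithmetic happens in the dispatched dtype
    return (running + c * (x - running)).astype(running.dtype)

# float, double, plus Half; bf16 has no numpy analog, so it is omitted here
SUPPORTED = {np.float32, np.float64, np.float16}

def dispatch_moving_avg(running, x, c):
    if running.dtype.type not in SUPPORTED:
        raise TypeError(f"unsupported dtype {running.dtype}")
    return moving_avg_kernel(running, x, np.asarray(c, dtype=running.dtype))

out = dispatch_moving_avg(np.float16([1.0]), np.float16([2.0]), 0.5)
assert out.dtype == np.float16 and float(out[0]) == 1.5
```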