
Fix FusedObsFakeQuant dtype handling for bf16/fp16 and bool inputs#3152

Open
kdrozd-dev wants to merge 8 commits into main from kdrozd/fix-fused-obs-fake-quant-test

Conversation

@kdrozd-dev
Contributor

Fixes #1819

Motivation

After adding the missing use_bool and sampled_dtype arguments to the XPU test, two fused_moving_avg_obs_fake_quant test cases fail with dtype mismatches:

  • test_fused_obs_fake_quant_moving_avg_per_channel_xpu: expected scalar type Long but found Bool
  • test_fused_obs_fake_quant_moving_avg_xpu: expected scalar type Float but found BFloat16

Root cause: the XPU kernel differs from CUDA in two ways:

  1. observer_on and fake_quant_on are passed to SYCL kernels without .to(at::kLong), so bool-typed tensors crash on data_ptr<int64_t>().
  2. running_min/running_max are hardcoded to float* in the SYCL kernel, so bf16/fp16 tensors crash on data_ptr<float>().

Changes

  • FusedObsFakeQuant.cpp: Add .to(at::kLong) for observer_on and fake_quant_on before passing to SYCL kernels and _fake_quantize_per_tensor_affine_cachemask_tensor_qparams, matching the CUDA kernel.
  • FusedObsFakeQuantKernels.cpp: Template MovingAverageMinMax, CalculateMovingAverageKernelFunctor, ChooseQuantizationParamsKernelImpl, and CalcMovingAvgQparamsHelperKernelFunctor on scalar_t. Dispatch via AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, ...) to support bf16/fp16 running min/max, matching CUDA's type coverage.

Copilot AI review requested due to automatic review settings March 24, 2026 10:48
Contributor

Copilot AI left a comment


Pull request overview

Fixes XPU parity issues in fused_moving_avg_obs_fake_quant by aligning dtype handling with CUDA, especially for bool observer_on/fake_quant_on and bf16/fp16 running min/max buffers.

Changes:

  • Extend XPU hypothesis test coverage to include use_bool and sampled_dtype for fused moving-avg fake-quant tests.
  • Convert observer_on / fake_quant_on to Long before passing into XPU helpers and the per-tensor fake-quant op.
  • Template SYCL moving-average + qparams kernels on scalar_t and dispatch over floating + (bf16, fp16) to support non-fp32 running min/max.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

File Description
test/xpu/quantization/core/test_workflow_ops_xpu.py Expands hypothesis args to exercise bool flags and bf16/fp16/fp32 dtype scenarios on XPU.
src/ATen/native/quantized/FusedObsFakeQuant.cpp Casts control tensors to Long before calling SYCL helpers and per-tensor fake-quant op to match CUDA behavior.
src/ATen/native/quantized/sycl/FusedObsFakeQuantKernels.cpp Templates moving-average and qparams kernels and dispatches on scalar_t to support bf16/fp16 running min/max.


Copilot AI review requested due to automatic review settings March 24, 2026 11:26
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.



Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.



Comment on lines +41 to +45
scalar_t curr_min = x_min[i];
scalar_t curr_max = x_max[i];
scalar_t averaging_const_t = static_cast<scalar_t>(averaging_const);


Copilot AI Mar 26, 2026


MovingAverageMinMax now performs the moving-average update in scalar_t (including averaging_const cast to scalar_t). For Half/BFloat16 this can significantly reduce precision vs the prior float math, and can even make averaging_const_t underflow to 0 for small constants, changing the algorithm’s behavior. Consider doing the computation in at::opmath_type<scalar_t> (or float) and only casting back to scalar_t when writing running_min/max.

Comment on lines +126 to +130
AT_DISPATCH_FLOATING_TYPES_AND2(
at::kBFloat16, at::kHalf, x.scalar_type(), "MovingAverageMinMax", [&] {
scalar_t* x_min_data = x_min.data_ptr<scalar_t>();
scalar_t* x_max_data = x_max.data_ptr<scalar_t>();
scalar_t* running_min_data = running_min.data_ptr<scalar_t>();

Copilot AI Mar 26, 2026


_calculate_moving_average dispatches scalar_t based on x.scalar_type(), but then unconditionally does running_min.data_ptr<scalar_t>()/running_max.data_ptr<scalar_t>(). Since this function doesn’t enforce that running_min/running_max have the same dtype as x, this can misinterpret the backing storage and crash or produce wrong results when the stats tensors are float (common for observers) while x is bf16/fp16. Consider dispatching on running_min.scalar_type() (and checking running_max matches), or adding a TORCH_CHECK that running_min/running_max dtype equals x (and documenting that contract), or explicitly casting the reduction outputs/stats tensors to a consistent dtype before taking data_ptr.

@github-actions

Performance outliers, please check!

  • 🔴 [-1, 80%): likely a regression
| Category | Model | Target vs. Baseline [Eager] | Target vs. Baseline [Inductor] |
| --- | --- | --- | --- |
| huggingface_float16_training | DistilBertForMaskedLM | 0.721518 | 0.711041 |
| huggingface_float16_training | BartForCausalLM | 0.737013 | 0.720776 |
| huggingface_bfloat16_training | DistilBertForMaskedLM | 0.763197 | 0.721157 |
| huggingface_float16_training | PegasusForCausalLM | 0.718918 | 0.726372 |
| huggingface_bfloat16_training | RobertaForCausalLM | 0.762177 | 0.726823 |
| huggingface_float16_training | MBartForCausalLM | 0.753265 | 0.730381 |
| huggingface_float16_training | DistillGPT2 | 0.701985 | 0.736635 |
| huggingface_bfloat16_training | BartForCausalLM | 0.772665 | 0.737958 |
| huggingface_float16_training | TrOCRForCausalLM | 0.709859 | 0.738806 |
| huggingface_bfloat16_training | MBartForCausalLM | 0.772593 | 0.739028 |
| huggingface_float16_training | RobertaForCausalLM | 0.732433 | 0.743094 |
| huggingface_bfloat16_training | BertForMaskedLM | 0.770050 | 0.744690 |
| huggingface_bfloat16_training | DistillGPT2 | 0.752220 | 0.748556 |
| huggingface_float16_training | PLBartForCausalLM | 0.747021 | 0.755293 |
| torchbench_bfloat16_training | mnasnet1_0 | 0.964247 | 0.757236 |
| huggingface_bfloat16_training | PegasusForCausalLM | 0.762049 | 0.759374 |
| huggingface_float16_training | BertForMaskedLM | 0.731741 | 0.759760 |
| huggingface_bfloat16_training | TrOCRForCausalLM | 0.773608 | 0.761799 |
| huggingface_bfloat16_training | PLBartForCausalLM | 0.777813 | 0.765756 |
| huggingface_bfloat16_training | ElectraForCausalLM | 0.785216 | 0.766246 |
| huggingface_float16_training | MegatronBertForCausalLM | 0.779540 | 0.768168 |
| huggingface_float16_training | YituTechConvBert | 0.709883 | 0.773631 |
| huggingface_float16_training | OPTForCausalLM | 0.804749 | 0.777220 |
| huggingface_float16_training | XGLMForCausalLM | 0.770287 | 0.777361 |
| huggingface_bfloat16_training | OPTForCausalLM | 0.837103 | 0.778041 |
| huggingface_float16_training | XLNetLMHeadModel | 0.698486 | 0.778531 |
| huggingface_bfloat16_training | MegatronBertForCausalLM | 0.824654 | 0.781909 |
| huggingface_bfloat16_training | T5Small | 0.791417 | 0.782196 |
| huggingface_float16_training | T5Small | 0.764197 | 0.782534 |
| huggingface_float16_training | T5ForConditionalGeneration | 0.767568 | 0.782913 |
| huggingface_bfloat16_training | LayoutLMForMaskedLM | 0.764797 | 0.783911 |
| huggingface_bfloat16_training | T5ForConditionalGeneration | 0.793016 | 0.787365 |
| huggingface_float16_training | LayoutLMForMaskedLM | 0.729120 | 0.787475 |
| huggingface_bfloat16_training | YituTechConvBert | 0.738646 | 0.788666 |
| huggingface_float16_training | AllenaiLongformerBase | 0.653033 | 0.792223 |
| huggingface_bfloat16_training | GPT2ForSequenceClassification | 0.780081 | 0.795420 |
| huggingface_float16_training | ElectraForCausalLM | 0.743803 | 0.805850 |
| huggingface_float16_training | GPT2ForSequenceClassification | 0.740645 | 0.807811 |
| torchbench_bfloat16_training | densenet121 | 0.790149 | 0.819276 |
| huggingface_float16_training | AlbertForMaskedLM | 0.764744 | 0.825174 |
  • 🟡 [80%, 90%): may be fluctuation
| Category | Model | Target vs. Baseline [Eager] | Target vs. Baseline [Inductor] |
| --- | --- | --- | --- |
| torchbench_bfloat16_training | resnet18 | 0.915819 | 0.800610 |
| huggingface_bfloat16_training | DebertaV2ForMaskedLM | 0.855793 | 0.801020 |
| huggingface_bfloat16_training | XGLMForCausalLM | 0.816172 | 0.806046 |
| huggingface_bfloat16_training | M2M100ForConditionalGeneration | 0.854172 | 0.814082 |
| torchbench_bfloat16_training | dcgan | 0.851483 | 0.819891 |
| huggingface_float16_training | DebertaV2ForMaskedLM | 0.861442 | 0.827771 |
| huggingface_float16_training | BlenderbotForCausalLM | 0.816512 | 0.828027 |
| huggingface_float16_training | M2M100ForConditionalGeneration | 0.846012 | 0.830743 |
| torchbench_bfloat16_training | mobilenet_v3_large | 0.951522 | 0.841212 |
| huggingface_bfloat16_training | AlbertForMaskedLM | 0.802540 | 0.844707 |
| huggingface_bfloat16_training | BlenderbotForCausalLM | 0.835404 | 0.853859 |
| torchbench_bfloat16_training | resnext50_32x4d | 0.952660 | 0.865722 |
| huggingface_bfloat16_training | GoogleFnet | 0.840805 | 0.908283 |
| huggingface_float16_training | GoogleFnet | 0.840881 | 0.908894 |

@kdrozd-dev
Contributor Author

Aligns with upstream change: https://github.com/pytorch/pytorch/pull/162620/changes

@kdrozd-dev kdrozd-dev requested a review from CuiYifeng March 27, 2026 11:39


Development

Successfully merging this pull request may close these issues.

[BMG-Windows][PT2.8]Torch-xpu-ops UT got TypeError
