From e9456e056a8c207834aa09d0d66bc2b0d364a467 Mon Sep 17 00:00:00 2001 From: Jason Date: Sun, 22 Mar 2026 01:03:59 +0100 Subject: [PATCH] fix: support torchao >= 0.16.0 by importing renamed CamelCase Config classes torchao 0.15.0 deprecated the snake_case quantization functions (int4_weight_only, float8_weight_only, etc.) with a deprecation warning, and 0.16.0 removed them entirely, replacing them with CamelCase Config classes (Int4WeightOnlyConfig, Float8WeightOnlyConfig, etc.). Add a version guard in TorchAoConfig._get_torchao_quant_type_to_method: - torchao >= 0.16.0: import new CamelCase Config classes, aliased to the old snake_case names so the rest of the method remains unchanged - torchao < 0.16.0: keep importing the old snake_case functions as before Fixes #13286 --- .../quantizers/quantization_config.py | 34 +++++++++++++------ 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 9a467e6b21ee..b59820e452fd 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -626,16 +626,30 @@ def _get_torchao_quant_type_to_method(cls): if is_torchao_available(): # TODO(aryan): Support sparsify - from torchao.quantization import ( - float8_dynamic_activation_float8_weight, - float8_static_activation_float8_weight, - float8_weight_only, - int4_weight_only, - int8_dynamic_activation_int4_weight, - int8_dynamic_activation_int8_weight, - int8_weight_only, - uintx_weight_only, - ) + # torchao >= 0.16.0 renamed snake_case functions to CamelCase Config classes. + # Use the new API when available, falling back to the old API for older versions. + if is_torchao_version(">=", "0.16.0"): + from torchao.quantization import ( + Float8DynamicActivationFloat8WeightConfig as float8_dynamic_activation_float8_weight, + Float8StaticActivationFloat8WeightConfig as float8_static_activation_float8_weight, + Float8WeightOnlyConfig as float8_weight_only, + Int4WeightOnlyConfig as int4_weight_only, + Int8DynamicActivationInt4WeightConfig as int8_dynamic_activation_int4_weight, + Int8DynamicActivationInt8WeightConfig as int8_dynamic_activation_int8_weight, + Int8WeightOnlyConfig as int8_weight_only, + UIntXWeightOnlyConfig as uintx_weight_only, + ) + else: + from torchao.quantization import ( + float8_dynamic_activation_float8_weight, + float8_static_activation_float8_weight, + float8_weight_only, + int4_weight_only, + int8_dynamic_activation_int4_weight, + int8_dynamic_activation_int8_weight, + int8_weight_only, + uintx_weight_only, + ) if is_torchao_version("<=", "0.14.1"): from torchao.quantization import fpx_weight_only