diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 9a467e6b21ee..b59820e452fd 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -626,16 +626,30 @@ def _get_torchao_quant_type_to_method(cls): if is_torchao_available(): # TODO(aryan): Support sparsify - from torchao.quantization import ( - float8_dynamic_activation_float8_weight, - float8_static_activation_float8_weight, - float8_weight_only, - int4_weight_only, - int8_dynamic_activation_int4_weight, - int8_dynamic_activation_int8_weight, - int8_weight_only, - uintx_weight_only, - ) + # torchao >= 0.16.0 renamed snake_case functions to CamelCase Config classes. + # Use the new API when available, falling back to the old API for older versions. + if is_torchao_version(">=", "0.16.0"): + from torchao.quantization import ( + Float8DynamicActivationFloat8WeightConfig as float8_dynamic_activation_float8_weight, + Float8StaticActivationFloat8WeightConfig as float8_static_activation_float8_weight, + Float8WeightOnlyConfig as float8_weight_only, + Int4WeightOnlyConfig as int4_weight_only, + Int8DynamicActivationInt4WeightConfig as int8_dynamic_activation_int4_weight, + Int8DynamicActivationInt8WeightConfig as int8_dynamic_activation_int8_weight, + Int8WeightOnlyConfig as int8_weight_only, + UIntXWeightOnlyConfig as uintx_weight_only, + ) + else: + from torchao.quantization import ( + float8_dynamic_activation_float8_weight, + float8_static_activation_float8_weight, + float8_weight_only, + int4_weight_only, + int8_dynamic_activation_int4_weight, + int8_dynamic_activation_int8_weight, + int8_weight_only, + uintx_weight_only, + ) if is_torchao_version("<=", "0.14.1"): from torchao.quantization import fpx_weight_only