Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions src/diffusers/quantizers/quantization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,16 +626,30 @@ def _get_torchao_quant_type_to_method(cls):

if is_torchao_available():
# TODO(aryan): Support sparsify
from torchao.quantization import (
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_weight_only,
uintx_weight_only,
)
# torchao >= 0.16.0 renamed snake_case functions to CamelCase Config classes.
# Use the new API when available, falling back to the old API for older versions.
if is_torchao_version(">=", "0.16.0"):
from torchao.quantization import (
Float8DynamicActivationFloat8WeightConfig as float8_dynamic_activation_float8_weight,
Float8StaticActivationFloat8WeightConfig as float8_static_activation_float8_weight,
Float8WeightOnlyConfig as float8_weight_only,
Int4WeightOnlyConfig as int4_weight_only,
Int8DynamicActivationInt4WeightConfig as int8_dynamic_activation_int4_weight,
Int8DynamicActivationInt8WeightConfig as int8_dynamic_activation_int8_weight,
Int8WeightOnlyConfig as int8_weight_only,
UIntXWeightOnlyConfig as uintx_weight_only,
)
else:
from torchao.quantization import (
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_weight_only,
uintx_weight_only,
)

if is_torchao_version("<=", "0.14.1"):
from torchao.quantization import fpx_weight_only
Expand Down