Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 80 additions & 16 deletions auto_round/utils/weight_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,12 +487,25 @@ def detect_layer(self, module: torch.nn.Module) -> bool:
return "Float" in compressor_name
return True

if hasattr(module, "quantization_scheme"):
from compressed_tensors.quantization.utils import is_module_quantized # pylint: disable=E0401

if is_module_quantized(module) and module.quantization_status.value == "compressed":
q_scheme = module.quantization_scheme
if (
q_scheme.weights.num_bits == 8
and q_scheme.weights.type == "float"
and q_scheme.input_activations.num_bits == 8
and q_scheme.input_activations.type == "float"
):
return True

# Check for FP8Linear layer type
if module.__class__.__name__ == "FP8Linear":
return True

# Fallback: Check for FP8 dtype (for torch.nn.Linear with FP8 weights)
if type(module) == torch.nn.Linear and module.weight is not None:
if type(module) == torch.nn.Linear and getattr(module, "weight", None) is not None:
if str(module.weight.dtype).startswith("torch.float8"):
return True

Expand All @@ -506,6 +519,12 @@ def convert_layer(
to_cpu: bool = False,
) -> torch.nn.Module:
"""Convert a single FP8/CompressedLinear layer to a standard Linear layer."""
if hasattr(layer, "quantization_scheme") and layer.__class__.__name__ == "Linear":
from compressed_tensors.compressors.base import decompress_module # pylint: disable=E0401

decompress_module(layer)
return layer

from auto_round.schemes import QuantizationScheme
from auto_round.utils.device import is_gaudi2

Expand Down Expand Up @@ -562,11 +581,24 @@ class MXFP4Handler(WeightTypeHandler):

def detect_layer(self, module: torch.nn.Module) -> bool:
"""Check if a module is an MXFP4 CompressedLinear layer."""
if module.__class__.__name__ != "CompressedLinear":
return False
if hasattr(module, "compressor") and module.compressor is not None:
compressor_name = module.compressor.__class__.__name__
return "MXFP4" in compressor_name
if module.__class__.__name__ == "CompressedLinear":
if hasattr(module, "compressor") and module.compressor is not None:
compressor_name = module.compressor.__class__.__name__
return "MXFP4" in compressor_name
if hasattr(module, "quantization_scheme"):
from compressed_tensors.quantization.utils import is_module_quantized # pylint: disable=E0401

if is_module_quantized(module) and module.quantization_status.value == "compressed":
q_scheme = module.quantization_scheme
if (
q_scheme.weights.num_bits == 4
and q_scheme.weights.type == "float"
and q_scheme.weights.group_size == 32
and q_scheme.input_activations.num_bits == 4
and q_scheme.input_activations.type == "float"
and q_scheme.input_activations.group_size == 32
):
return True
return False

def convert_layer(
Expand Down Expand Up @@ -638,11 +670,24 @@ class MXFP8Handler(WeightTypeHandler):

def detect_layer(self, module: torch.nn.Module) -> bool:
"""Check if a module is an MXFP8 CompressedLinear layer."""
if module.__class__.__name__ != "CompressedLinear":
return False
if hasattr(module, "compressor") and module.compressor is not None:
compressor_name = module.compressor.__class__.__name__
return "MXFP8" in compressor_name
if module.__class__.__name__ == "CompressedLinear":
if hasattr(module, "compressor") and module.compressor is not None:
compressor_name = module.compressor.__class__.__name__
return "MXFP8" in compressor_name
if hasattr(module, "quantization_scheme"):
from compressed_tensors.quantization.utils import is_module_quantized # pylint: disable=E0401

if is_module_quantized(module) and module.quantization_status.value == "compressed":
q_scheme = module.quantization_scheme
if (
q_scheme.weights.num_bits == 8
and q_scheme.weights.type == "float"
and q_scheme.weights.group_size == 32
and q_scheme.input_activations.num_bits == 8
and q_scheme.input_activations.type == "float"
and q_scheme.input_activations.group_size == 32
):
return True
return False

def convert_layer(
Expand Down Expand Up @@ -710,11 +755,24 @@ class NVFP4Handler(WeightTypeHandler):

def detect_layer(self, module: torch.nn.Module) -> bool:
"""Check if a module is an NVFP4 CompressedLinear layer."""
if module.__class__.__name__ != "CompressedLinear":
return False
if hasattr(module, "compressor") and module.compressor is not None:
compressor_name = module.compressor.__class__.__name__
return "NVFP4" in compressor_name
if module.__class__.__name__ == "CompressedLinear":
if hasattr(module, "compressor") and module.compressor is not None:
compressor_name = module.compressor.__class__.__name__
return "NVFP4" in compressor_name
if hasattr(module, "quantization_scheme"):
from compressed_tensors.quantization.utils import is_module_quantized # pylint: disable=E0401

if is_module_quantized(module) and module.quantization_status.value == "compressed":
q_scheme = module.quantization_scheme
if (
q_scheme.weights.num_bits == 4
and q_scheme.weights.type == "float"
and q_scheme.weights.group_size == 16
and q_scheme.input_activations.num_bits == 4
and q_scheme.input_activations.type == "float"
and q_scheme.input_activations.group_size == 16
):
return True
return False

def convert_layer(
Expand All @@ -725,6 +783,12 @@ def convert_layer(
to_cpu: bool = False,
) -> torch.nn.Module:
"""Convert an NVFP4 CompressedLinear layer to a standard Linear layer."""
if hasattr(layer, "quantization_scheme") and layer.__class__.__name__ == "Linear":
from compressed_tensors.compressors.base import decompress_module # pylint: disable=E0401

decompress_module(layer)
return layer

from auto_round.schemes import QuantizationScheme
from auto_round.utils.device import is_gaudi2

Expand Down
27 changes: 18 additions & 9 deletions test/test_cpu/advanced/test_low_precision_input_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,41 +15,50 @@ class TestCompressedTensor:
mxfp4_model_path = "QuixiAI/Llama-3.2-1B-MXFP4"
fp8_block_model_path = "RedHatAI/Qwen3-0.6B-FP8-BLOCK"

@pytest.mark.skip(reason="CompressedLinear removed in compressed_tensors PR #610, see #1578")
def test_fp8_block(self):
model = get_tiny_model(get_model_path(self.fp8_block_model_path))
assert (
type(model.model.layers[0].mlp.up_proj).__name__ == "CompressedLinear"
model.model.layers[0].mlp.up_proj.weight.dtype == torch.float8_e4m3fn
), "Original weight is not in FP8 format"
assert hasattr(
model.model.layers[0].mlp.up_proj, "quantization_scheme"
), "Model does not contain CompressedLinear layers"
detected_types = check_and_mark_quantized_module(model)
assert ModuleWeightType.FP8 in detected_types
model = convert_module_to_hp_if_necessary(model)
assert (
type(model.model.layers[0].mlp.up_proj) is torch.nn.Linear
model.model.layers[0].mlp.up_proj.weight.dtype == torch.bfloat16
), "CompressedLinear layer was not converted to Linear"

@pytest.mark.skip(reason="CompressedLinear removed in compressed_tensors PR #610, see #1578")
@pytest.mark.skip(
reason="NVFP4 models are currently not supported due to issues with the compressed_tensors library. See https://github.com/vllm-project/compressed-tensors/issues/642"
)
def test_nvfp4(self):
model = get_tiny_model(get_model_path(self.nvfp4_model_path))
assert (
type(model.model.layers[0].mlp.up_proj).__name__ == "CompressedLinear"
model.model.layers[0].mlp.up_proj.weight_packed.dtype == torch.uint8
), "Original weight is not in FP8 format"
assert hasattr(
model.model.layers[0].mlp.up_proj, "quantization_scheme"
), "Model does not contain CompressedLinear layers"
detected_types = check_and_mark_quantized_module(model)
assert ModuleWeightType.NVFP4 in detected_types
model = convert_module_to_hp_if_necessary(model)
assert (
type(model.model.layers[0].mlp.up_proj) is torch.nn.Linear
model.model.layers[0].mlp.up_proj.weight.dtype == torch.bfloat16
), "CompressedLinear layer was not converted to Linear"

@pytest.mark.skip(reason="CompressedLinear removed in compressed_tensors PR #610, see #1578")
def test_mxfp4(self):
model = get_tiny_model(get_model_path(self.mxfp4_model_path))
assert (
type(model.model.layers[0].mlp.up_proj).__name__ == "CompressedLinear"
model.model.layers[0].mlp.up_proj.weight_packed.dtype == torch.uint8
), "Original weight is not in FP8 format"
assert hasattr(
model.model.layers[0].mlp.up_proj, "quantization_scheme"
), "Model does not contain CompressedLinear layers"
detected_types = check_and_mark_quantized_module(model)
assert ModuleWeightType.MXFP4 in detected_types
model = convert_module_to_hp_if_necessary(model)
assert (
type(model.model.layers[0].mlp.up_proj) is torch.nn.Linear
model.model.layers[0].mlp.up_proj.weight.dtype == torch.bfloat16
), "CompressedLinear layer was not converted to Linear"
Loading