[Bug]: Quantized MXFP model export in llmc format failed with inference #1657

@chensuyue

Description

Problem Description

Quantized and saved the mxfp4/mxfp8 model in llmc format. Quantization passes, but inference with transformers fails.

Reproduction Steps

Run with PR intel/neural-compressor#2441:
export CUDA_VISIBLE_DEVICES=5
bash run_quant.sh --topology=Llama-3.1-8B --dtype=mxfp8 --input_model=/models/Llama-3.1-8B-Instruct --output_model=/data3/jenkins/saved_models/Llama-3.1-8B_mxfp8_llmc --export_format=llm_compressor
bash run_quant.sh --topology=Llama-3.1-8B --dtype=mxfp4 --input_model=/models/Llama-3.1-8B-Instruct --output_model=/data3/jenkins/saved_models/Llama-3.1-8B_mxfp4_llmc --export_format=llm_compressor

python run_prompt.py --saved_model_path=/data3/jenkins/saved_models/Llama-3.1-8B_mxfp8_llmc
python run_prompt.py --saved_model_path=/data3/jenkins/saved_models/Llama-3.1-8B_mxfp4_llmc

Environment Information

HW: 1x A100 card.
Main SW list:
torch 2.9.0
transformers 4.57.6
compressed-tensors 0.14.0.1
auto-round de0650e

Error Logs

File "/home/uttest/miniforge3/envs/jenkins-key-model-llama/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                          
  File "/home/uttest/miniforge3/envs/jenkins-key-model-llama/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl                 
    return forward_call(*args, **kwargs)                                                                                                                                                                                             
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uttest/miniforge3/envs/jenkins-key-model-llama/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)                                                                                                                                                                                                     
           ^^^^^^^^^^^^^^^^^^^^^          
  File "/home/uttest/miniforge3/envs/jenkins-key-model-llama/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 294, in forward
    hidden_states, _ = self.self_attn(                                                                                                                                                                                               
                       ^^^^^^^^^^^^^^^                                                                                                                                                                                               
  File "/home/uttest/miniforge3/envs/jenkins-key-model-llama/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)                                                                                                                                                                                          
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                          
  File "/home/uttest/miniforge3/envs/jenkins-key-model-llama/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl                 
    return forward_call(*args, **kwargs)                                                                                                                                                                                             
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                
  File "/home/uttest/miniforge3/envs/jenkins-key-model-llama/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func         
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/uttest/miniforge3/envs/jenkins-key-model-llama/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 236, in forward
    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uttest/miniforge3/envs/jenkins-key-model-llama/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uttest/miniforge3/envs/jenkins-key-model-llama/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uttest/miniforge3/envs/jenkins-key-model-llama/lib/python3.12/site-packages/compressed_tensors/quantization/lifecycle/forward.py", line 359, in wrapped_forward
    input_ = forward_quantize(self, input_, "input", scheme.input_activations)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uttest/miniforge3/envs/jenkins-key-model-llama/lib/python3.12/site-packages/compressed_tensors/quantization/lifecycle/forward.py", line 416, in forward_quantize
    scale, zero_point = compute_dynamic_scales_and_zp(
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uttest/miniforge3/envs/jenkins-key-model-llama/lib/python3.12/site-packages/compressed_tensors/quantization/utils/helpers.py", line 194, in compute_dynamic_scales_and_zp
    return calculate_qparams(min_val, max_val, args, global_scale=global_scale)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uttest/miniforge3/envs/jenkins-key-model-llama/lib/python3.12/site-packages/compressed_tensors/quantization/utils/helpers.py", line 82, in calculate_qparams
    scales = generate_mxfp4_scales(x=max_val_pos)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/uttest/miniforge3/envs/jenkins-key-model-llama/lib/python3.12/site-packages/compressed_tensors/quantization/utils/mxfp4_utils.py", line 100, in generate_mxfp4_scales
    scale_power_2 = round_to_power_2(x)
                    ^^^^^^^^^^^^^^^^^^^
  File "/home/uttest/miniforge3/envs/jenkins-key-model-llama/lib/python3.12/site-packages/compressed_tensors/quantization/utils/mxfp4_utils.py", line 69, in round_to_power_2
    assert scale_dtype is torch.float32
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
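The assertion in `round_to_power_2` fires when the tensor reaching the dynamic activation-scale computation is not `torch.float32` — plausibly a `bfloat16` hidden state from the reloaded model, though that is an assumption, not something confirmed by the logs. For context, MX formats use power-of-two block scales; a minimal pure-Python sketch of that rounding (rounding up here; the library's rounding direction may differ):

```python
import math

def round_to_power_2(x: float) -> float:
    # Round a positive value up to the nearest power of two,
    # mimicking the power-of-two block scales MX formats require.
    # (Illustrative sketch only; not the compressed-tensors implementation.)
    if x <= 0:
        return 0.0
    # x = mantissa * 2**exponent, with 0.5 <= mantissa < 1
    mantissa, exponent = math.frexp(x)
    # mantissa == 0.5 means x is already an exact power of two
    return math.ldexp(1.0, exponent if mantissa > 0.5 else exponent - 1)
```

For example, `round_to_power_2(3.0)` gives `4.0`, while `round_to_power_2(4.0)` stays `4.0`. The real helper operates on tensors and guards its input dtype with the failing `assert scale_dtype is torch.float32`, which is why a non-float32 activation tensor aborts inference here.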

Additional Context

No response
