Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ NVIDIA Model Optimizer Changelog
- Add N:M sparse softmax support to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_sparsity/attention_sparsity>`_ for usage.
- Add skip-softmax support to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_sparsity/attention_sparsity>`_ for usage.
- Enable PTQ workflow for the Step3.5-Flash MoE model with NVFP4 W4A4 + FP8 KV cache quantization. See `modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml>`_ for more details.
- Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/vllm_serve#load-qatptq-model-and-serve-in-vllm-wip>`_ for more details.

**Bug Fixes**

Expand Down Expand Up @@ -48,7 +49,6 @@ NVIDIA Model Optimizer Changelog
- Add support for Nemotron-3 (NemotronHForCausalLM) model quantization and support for NemotronH MoE expert support in ``auto_quantize`` grouping and scoring rules.
- Add support for block-granular RHT for non-power-of-2 dimensions.
- Replace modelopt FP8 QDQ nodes with native ONNX QDQ nodes.
- Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/vllm_serve#load-qatptq-model-and-serve-in-vllm-wip>`_ for more details.

**Deprecations**

Expand Down
20 changes: 18 additions & 2 deletions modelopt/torch/export/plugins/vllm_fakequant_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import torch.nn as nn

import modelopt.torch.opt as mto
from modelopt.torch.quantization.config import RotateConfig
from modelopt.torch.quantization.conversion import quantizer_state
from modelopt.torch.quantization.nn import QuantModule, TensorQuantizer
from modelopt.torch.quantization.utils import get_quantizer_state_dict
Expand All @@ -28,6 +29,15 @@
__all__ = ["export_hf_vllm_fq_checkpoint"]


def disable_rotate(quantizer: TensorQuantizer):
    """Build a disabled replacement for ``quantizer._rotate``, keeping its original type.

    Returns a ``RotateConfig`` with ``enable=False`` when the field is a
    ``RotateConfig``, a copy of the dict with ``enable`` forced to ``False``
    when the field is a plain dict, and ``False`` otherwise.
    """
    rotate = quantizer._rotate
    if isinstance(rotate, RotateConfig):
        return RotateConfig(enable=False)
    # backward compat: old checkpoints stored a plain dict instead of RotateConfig
    if isinstance(rotate, dict):
        return {**rotate, "enable": False}
    return False


def export_hf_vllm_fq_checkpoint(
model: nn.Module,
export_dir: Path | str,
Expand Down Expand Up @@ -104,6 +114,8 @@ def export_hf_vllm_fq_checkpoint(
# dict, then re-enable. The _disabled=True flag is captured in modelopt_state
# so that on vLLM reload weight quantizers stay off while input/output/
# attention quantizers remain active.
# Rotation is also cleared: the weight was already folded with rotation applied,
# so if fold_weight is called on reload it must not re-rotate the exported weight.
wqs_to_restore = []
for _, module in model.named_modules():
if isinstance(module, QuantModule):
Expand All @@ -114,7 +126,10 @@ def export_hf_vllm_fq_checkpoint(
and quantizer.is_enabled
):
quantizer.disable()
wqs_to_restore.append(quantizer)
orig_rotate = quantizer._rotate
if quantizer.rotate_is_enabled:
quantizer._rotate = disable_rotate(quantizer)
wqs_to_restore.append((quantizer, orig_rotate))

quantizer_state_dict = get_quantizer_state_dict(model)
for key in list(quantizer_state_dict):
Expand Down Expand Up @@ -149,5 +164,6 @@ def export_hf_vllm_fq_checkpoint(
# Step 3: Save HF weights using the pre-built folded state dict.
model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False)

for wq in wqs_to_restore:
for wq, orig_rotate in wqs_to_restore:
wq.enable()
wq._rotate = orig_rotate
Loading