From e3c3b14d967ef00f37629ba98cafc68e27a78236 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Mon, 30 Mar 2026 19:53:00 +0000 Subject: [PATCH 1/3] Bug fix: rotation disabled during export for folded weights Signed-off-by: Kinjal Patel --- modelopt/torch/export/plugins/vllm_fakequant_hf.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/export/plugins/vllm_fakequant_hf.py b/modelopt/torch/export/plugins/vllm_fakequant_hf.py index 5d6655d1fc..2d259051df 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_hf.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_hf.py @@ -104,6 +104,8 @@ def export_hf_vllm_fq_checkpoint( # dict, then re-enable. The _disabled=True flag is captured in modelopt_state # so that on vLLM reload weight quantizers stay off while input/output/ # attention quantizers remain active. + # Rotation is also cleared: the weight was already folded with rotation applied, + # so if fold_weight is called on reload it must not re-rotate the exported weight. wqs_to_restore = [] for _, module in model.named_modules(): if isinstance(module, QuantModule): @@ -114,7 +116,10 @@ def export_hf_vllm_fq_checkpoint( and quantizer.is_enabled ): quantizer.disable() - wqs_to_restore.append(quantizer) + orig_rotate = quantizer._rotate + if quantizer.rotate_is_enabled: + quantizer._rotate = False + wqs_to_restore.append((quantizer, orig_rotate)) quantizer_state_dict = get_quantizer_state_dict(model) for key in list(quantizer_state_dict): @@ -149,5 +154,6 @@ def export_hf_vllm_fq_checkpoint( # Step 3: Save HF weights using the pre-built folded state dict. model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False) - for wq in wqs_to_restore: + for wq, orig_rotate in wqs_to_restore: wq.enable() + wq._rotate = orig_rotate From a755a5470d64966a4366273f624e5fa410a845d6 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Tue, 31 Mar 2026 20:46:38 +0000 Subject: [PATCH 2/3] minor Signed-off-by: Kinjal Patel --- modelopt/torch/export/plugins/vllm_fakequant_hf.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/export/plugins/vllm_fakequant_hf.py b/modelopt/torch/export/plugins/vllm_fakequant_hf.py index 2d259051df..1908354a0a 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_hf.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_hf.py @@ -20,6 +20,7 @@ import torch.nn as nn import modelopt.torch.opt as mto +from modelopt.torch.quantization.config import RotateConfig from modelopt.torch.quantization.conversion import quantizer_state from modelopt.torch.quantization.nn import QuantModule, TensorQuantizer from modelopt.torch.quantization.utils import get_quantizer_state_dict @@ -28,6 +29,15 @@ __all__ = ["export_hf_vllm_fq_checkpoint"] +def disable_rotate(quantizer: TensorQuantizer): + """Return a disabled copy of the quantizer's ``_rotate`` field, preserving its type.""" + if isinstance(quantizer._rotate, RotateConfig): + return RotateConfig(enable=False) + if isinstance(quantizer._rotate, dict): # backward compat: old checkpoints stored a dict + return dict(quantizer._rotate, enable=False) + return False + + def export_hf_vllm_fq_checkpoint( model: nn.Module, export_dir: Path | str, @@ -118,7 +128,7 @@ def export_hf_vllm_fq_checkpoint( quantizer.disable() orig_rotate = quantizer._rotate if quantizer.rotate_is_enabled: - quantizer._rotate = False + quantizer._rotate = disable_rotate(quantizer) wqs_to_restore.append((quantizer, orig_rotate)) quantizer_state_dict = get_quantizer_state_dict(model) From a4b8232e799a0afc154d2eeb2f57f340dee0b929 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Wed, 1 Apr 2026 17:59:35 +0000 Subject: [PATCH 3/3] minor Signed-off-by: Kinjal Patel --- CHANGELOG.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d52ad0c2ad..d007f5e28c 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -10,6 +10,7 @@ NVIDIA Model Optimizer Changelog - Add N:M sparse softmax support to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md `_ for usage. - Add skip-softmax skipping to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md `_ for usage. - Enable PTQ workflow for the Step3.5-Flash MoE model with NVFP4 W4A4 + FP8 KV cache quantization. See `modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml `_ for more details. +- Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md `_ for more details. **Bug Fixes** @@ -48,7 +49,6 @@ NVIDIA Model Optimizer Changelog - Add support for Nemotron-3 (NemotronHForCausalLM) model quantization and support for NemotronH MoE expert support in ``auto_quantize`` grouping and scoring rules. - Add support for block-granular RHT for non-power-of-2 dimensions. - Replace modelopt FP8 QDQ nodes with native ONNX QDQ nodes. -- Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md `_ for more details. **Deprecations**