From 4dfbbdb6f1a8c0875eadd01d8b6d2622829c464b Mon Sep 17 00:00:00 2001 From: Hrishith Thadicherla Date: Tue, 31 Mar 2026 14:37:36 +0530 Subject: [PATCH 1/3] Added ModelOpt graph surgery pass in Olive Signed-off-by: Hrishith Thadicherla --- olive/olive_config.json | 9 + olive/passes/onnx/nvmo_graph_surgery.py | 173 ++++++++ test/passes/onnx/test_nvmo_graph_surgery.py | 427 ++++++++++++++++++++ 3 files changed, 609 insertions(+) create mode 100644 olive/passes/onnx/nvmo_graph_surgery.py create mode 100644 test/passes/onnx/test_nvmo_graph_surgery.py diff --git a/olive/olive_config.json b/olive/olive_config.json index 66613a779..12f74404b 100644 --- a/olive/olive_config.json +++ b/olive/olive_config.json @@ -239,6 +239,15 @@ "supported_algorithms": [ ], "supported_quantization_encodings": [ ] }, + "NVModelOptGraphSurgery": { + "module_path": "olive.passes.onnx.nvmo_graph_surgery.NVModelOptGraphSurgery", + "supported_providers": [ "*" ], + "supported_accelerators": [ "*" ], + "supported_precisions": [ "*" ], + "supported_algorithms": [ ], + "supported_quantization_encodings": [ ], + "extra_dependencies": [ "nvmo" ] + }, "NVModelOptQuantization": { "module_path": "olive.passes.onnx.nvmo_quantization.NVModelOptQuantization", "supported_providers": [ "CUDAExecutionProvider" ], diff --git a/olive/passes/onnx/nvmo_graph_surgery.py b/olive/passes/onnx/nvmo_graph_surgery.py new file mode 100644 index 000000000..646c8dd84 --- /dev/null +++ b/olive/passes/onnx/nvmo_graph_surgery.py @@ -0,0 +1,173 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# ------------------------------------------------------------------------- +import logging +import os +import shutil +import tempfile +from pathlib import Path + +import onnx +from onnx.onnx_pb import ModelProto + +from olive.hardware.accelerator import AcceleratorSpec +from olive.model import ONNXModelHandler +from olive.model.utils import resolve_onnx_path +from olive.passes import Pass +from olive.passes.onnx.common import model_proto_to_olive_model +from olive.passes.pass_config import BasePassConfig, PassConfigParam + +logger = logging.getLogger(__name__) + + +class NVModelOptGraphSurgery(Pass): + """Perform graph surgeries on ONNX models using NVIDIA ModelOpt. + + This pass provides a scalable interface to all graph surgery operations + available in ModelOpt. It uses ModelOpt's run_graph_surgery dispatcher, + so any new surgery added to ModelOpt is automatically available here + without code changes. + + Use get_available_surgeries() from modelopt.onnx.graph_surgery to see + all available surgery types. + """ + + @classmethod + def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassConfigParam]: + return { + "surgery_type": PassConfigParam( + type_=str, + required=True, + description=( + "Name of the graph surgery to perform. " + "Examples: 'replace-gqa', 'add-cross-kv', 'convert-bf16', 'transpose-dq'. " + "Run modelopt.onnx.graph_surgery.get_available_surgeries() to see all options." + ), + ), + "surgery_params": PassConfigParam( + type_=dict, + default_value={}, + description=( + "Dictionary of surgery-specific parameters. " + "These are passed directly to the ModelOpt surgery function as keyword arguments. " + "Refer to ModelOpt documentation for each surgery's parameters." 
+ ), + ), + } + + @classmethod + def validate_config( + cls, + config: type[BasePassConfig], + accelerator_spec: AcceleratorSpec, + ) -> bool: + if not super().validate_config(config, accelerator_spec): + return False + + try: + from modelopt.onnx.graph_surgery import get_available_surgeries + except ImportError: + logger.exception("modelopt is not installed. Install with 'pip install nvidia_modelopt'.") + return False + + surgery_type = config.surgery_type + available = get_available_surgeries() + if surgery_type not in available: + logger.error("Unknown surgery type: '%s'. Available: %s", surgery_type, available) + return False + + return True + + def _run_for_config( + self, model: ONNXModelHandler, config: type[BasePassConfig], output_model_path: str + ) -> ONNXModelHandler: + """Run the graph surgery on the model.""" + try: + from modelopt.onnx.graph_surgery import run_graph_surgery + except ImportError: + raise ImportError("modelopt is not installed. Install with 'pip install nvidia_modelopt'.") from None + + surgery_type = config.surgery_type + surgery_params = dict(config.surgery_params or {}) + + logger.info("Starting ModelOpt graph surgery: %s", surgery_type) + logger.debug("Surgery parameters: %s", surgery_params) + + try: + with tempfile.TemporaryDirectory() as temp_dir: + temp_input_path = os.path.join(temp_dir, "input_model.onnx") + temp_output_path = os.path.join(temp_dir, "output_model.onnx") + + # Save input model to temp directory + model_proto = model.load_model() + onnx.save_model( + model_proto, + temp_input_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location="input_model.onnx_data", + size_threshold=1024, + ) + + # Call ModelOpt's unified dispatcher + result = run_graph_surgery( + surgery_name=surgery_type, + input_path=temp_input_path, + output_path=temp_output_path, + **surgery_params, + ) + + # Load modified model (without external data — we'll copy the file separately) + if isinstance(result, ModelProto): + 
modified_model_proto = result
+                else:
+                    # run_graph_surgery wrote its output to disk instead of
+                    # returning a proto; load only the graph here — the external
+                    # data file (if any) is copied to its final location below.
+                    modified_model_proto = onnx.load(temp_output_path, load_external_data=False)
+
+                # Check for external data file
+                temp_ext_data_file = os.path.join(temp_dir, "output_model.onnx_data")
+                has_external_data = os.path.exists(temp_ext_data_file)
+
+                # Resolve final output path
+                output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name)
+                output_dir = Path(output_model_path).parent
+                output_dir.mkdir(parents=True, exist_ok=True)
+                output_ext_data_name = f"{Path(output_model_path).name}.data"
+
+                if has_external_data:
+                    # Copy external data file while temp dir still exists
+                    output_ext_data_path = output_dir / output_ext_data_name
+                    logger.info("Copying external data file to %s", output_ext_data_path)
+                    shutil.copy2(temp_ext_data_file, str(output_ext_data_path))
+
+                    # Update model references and save
+                    from olive.passes.onnx.common import (
+                        add_version_metadata_to_model_proto,
+                        change_external_data_location,
+                    )
+
+                    change_external_data_location(modified_model_proto, output_ext_data_name)
+                    modified_model_proto = add_version_metadata_to_model_proto(modified_model_proto)
+                    onnx.save_model(modified_model_proto, str(output_model_path))
+
+                    from olive.resource_path import LocalFolder
+
+                    return ONNXModelHandler(
+                        model_path=LocalFolder({"path": output_dir}),
+                        onnx_file_name=Path(output_model_path).name,
+                    )
+                else:
+                    external_data_config = {
+                        "save_as_external_data": True,
+                        "all_tensors_to_one_file": True,
+                        "external_data_name": output_ext_data_name,
+                        "size_threshold": 1024,
+                    }
+                    return model_proto_to_olive_model(modified_model_proto, output_model_path, external_data_config)
+
+        except Exception:
+            logger.exception("An error occurred during graph surgery: %s", surgery_type)
+            raise
diff --git 
a/test/passes/onnx/test_nvmo_graph_surgery.py b/test/passes/onnx/test_nvmo_graph_surgery.py new file mode 100644 index 000000000..e7c17888f --- /dev/null +++ b/test/passes/onnx/test_nvmo_graph_surgery.py @@ -0,0 +1,427 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import numpy as np +import onnx +import pytest +from onnx import TensorProto, helper, numpy_helper + +from olive.model import ONNXModelHandler +from olive.passes.olive_pass import create_pass_from_dict +from olive.passes.onnx.nvmo_graph_surgery import NVModelOptGraphSurgery + +pytest.importorskip("modelopt", reason="nvidia-modelopt required for graph surgery tests") +pytest.importorskip("transformers", reason="transformers required for GQA surgery tests") + +MODEL_ID = "Qwen/Qwen2.5-0.5B" +VOCAB_SIZE = 64 + +_RNG = np.random.RandomState(42) + + +def _fp16(*shape): + return (_RNG.randn(*shape) * 0.02).astype(np.float16) + + +def _init(name, arr): + return numpy_helper.from_array(arr, name=name) + + +def _get_config(): + from transformers import AutoConfig + + cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=False) + hidden = cfg.hidden_size + heads = cfg.num_attention_heads + kv = getattr(cfg, "num_key_value_heads", heads) + hd = hidden // heads + return hidden, heads, kv, hd + + +# --------------------------------------------------------------------------- +# Dummy model builders +# --------------------------------------------------------------------------- + + +def _build_attention_model(hidden_size, num_heads, kv_heads, head_dim) -> onnx.ModelProto: + """Build a single-layer LLM with realistic Optimum-style attention. 
+ + Simplified from Model-Optimizer test_gqa_graph_surgery._build_toy_model: + Gather -> LayerNorm -> Q/K/V MatMul -> Reshape -> Transpose + -> KV cache concat -> (GQA repeat) -> K^T -> scale*Q@K^T + -> Add(attn_bias) -> Softmax -> Attn@V -> Transpose -> Reshape + -> o_proj MatMul -> residual Add -> lm_head -> logits + + Attention mask is wired into a causal+padding bias subgraph. + """ + q_dim = num_heads * head_dim + k_dim = kv_heads * head_dim + ap = "/model/layers.0/self_attn" + + nodes, inits = [], [] + + # --- Constants --- + inits.extend( + [ + _init("one_f16", np.array(1.0, dtype=np.float16)), + _init("neg_large_f16", np.array(-1e4, dtype=np.float16)), + _init("axes_0", np.array([0], dtype=np.int64)), + _init("axes_01", np.array([0, 1], dtype=np.int64)), + _init("axes_12", np.array([1, 2], dtype=np.int64)), + _init("trilu_k1", np.array(1, dtype=np.int64)), + ] + ) + + # --- Graph I/O --- + graph_inputs = [ + helper.make_tensor_value_info("input_ids", TensorProto.INT64, ["B", "S"]), + helper.make_tensor_value_info("attention_mask", TensorProto.INT64, ["B", "T"]), + helper.make_tensor_value_info("past_key_values.0.key", TensorProto.FLOAT16, ["B", kv_heads, "P", head_dim]), + helper.make_tensor_value_info("past_key_values.0.value", TensorProto.FLOAT16, ["B", kv_heads, "P", head_dim]), + ] + graph_outputs = [ + helper.make_tensor_value_info("logits", TensorProto.FLOAT16, ["B", "S", VOCAB_SIZE]), + helper.make_tensor_value_info("present.0.key", TensorProto.FLOAT16, ["B", kv_heads, "T", head_dim]), + helper.make_tensor_value_info("present.0.value", TensorProto.FLOAT16, ["B", kv_heads, "T", head_dim]), + ] + + # --- Embedding --- + inits.append(_init("model.embed_tokens.weight", _fp16(VOCAB_SIZE, hidden_size))) + nodes.append( + helper.make_node( + "Gather", + ["model.embed_tokens.weight", "input_ids"], + ["/model/embed_tokens/Gather_output_0"], + name="/model/embed_tokens/Gather", + axis=0, + ) + ) + hidden = "/model/embed_tokens/Gather_output_0" + + # --- Causal 
+ padding mask (from attention_mask) --- + nodes.append(helper.make_node("Shape", ["input_ids"], ["ids_shape"], name="/model/pos/Shape")) + nodes.append( + helper.make_node( + "Constant", + [], + ["/model/pos/C1_output_0"], + name="/model/pos/C1", + value=numpy_helper.from_array(np.array(1, dtype=np.int64), name=""), + ) + ) + nodes.append( + helper.make_node( + "Gather", + ["ids_shape", "/model/pos/C1_output_0"], + ["seq_len"], + name="/model/pos/seq_gather", + axis=0, + ) + ) + nodes.append(helper.make_node("Unsqueeze", ["seq_len", "axes_0"], ["seq_1d"], name="/model/causal/unsq")) + nodes.append(helper.make_node("Concat", ["seq_1d", "seq_1d"], ["causal_shape"], name="/model/causal/cat", axis=0)) + nodes.append( + helper.make_node( + "ConstantOfShape", + ["causal_shape"], + ["causal_ones"], + name="/model/causal/fill", + value=numpy_helper.from_array(np.array([1.0], dtype=np.float16), name=""), + ) + ) + nodes.append( + helper.make_node("Trilu", ["causal_ones", "trilu_k1"], ["upper_tri"], name="/model/causal/trilu", upper=1) + ) + nodes.append(helper.make_node("Mul", ["upper_tri", "neg_large_f16"], ["causal_4d_raw"], name="/model/causal/mul")) + nodes.append( + helper.make_node("Unsqueeze", ["causal_4d_raw", "axes_01"], ["causal_4d"], name="/model/causal/unsq4d") + ) + nodes.append(helper.make_node("Cast", ["attention_mask"], ["pad_f16"], name="/model/pad/cast", to=10)) + nodes.append(helper.make_node("Unsqueeze", ["pad_f16", "axes_12"], ["pad_4d"], name="/model/pad/unsq")) + nodes.append(helper.make_node("Sub", ["one_f16", "pad_4d"], ["inv_pad"], name="/model/pad/inv")) + nodes.append(helper.make_node("Mul", ["inv_pad", "neg_large_f16"], ["pad_bias"], name="/model/pad/mul")) + nodes.append(helper.make_node("Add", ["causal_4d", "pad_bias"], ["attn_bias"], name="/model/bias/add")) + + # --- LayerNorm --- + ln_w = "model.layers.0.input_layernorm.weight" + ln_b = "model.layers.0.input_layernorm.bias" + inits.append(_init(ln_w, np.ones(hidden_size, 
dtype=np.float16))) + inits.append(_init(ln_b, np.zeros(hidden_size, dtype=np.float16))) + ln_out = "/model/layers.0/input_layernorm/Mul_1_output_0" + nodes.append( + helper.make_node( + "LayerNormalization", + [hidden, ln_w, ln_b], + [ln_out], + name="/model/layers.0/input_layernorm/LayerNorm", + axis=-1, + epsilon=1e-5, + ) + ) + + # --- Q / K / V projections --- + inits.extend( + [ + _init("model.layers.0.self_attn.q_proj.weight", _fp16(hidden_size, q_dim)), + _init("model.layers.0.self_attn.k_proj.weight", _fp16(hidden_size, k_dim)), + _init("model.layers.0.self_attn.v_proj.weight", _fp16(hidden_size, k_dim)), + _init("model.layers.0.self_attn.o_proj.weight", _fp16(q_dim, hidden_size)), + ] + ) + for proj, _dim in [("q_proj", q_dim), ("k_proj", k_dim), ("v_proj", k_dim)]: + nodes.append( + helper.make_node( + "MatMul", + [ln_out, f"model.layers.0.self_attn.{proj}.weight"], + [f"{ap}/{proj}/MatMul_output_0"], + name=f"{ap}/{proj}/MatMul", + ) + ) + + # --- Reshape + Transpose to multi-head --- + inits.append(_init(f"{ap}/q_shape", np.array([0, 0, num_heads, head_dim], np.int64))) + inits.append(_init(f"{ap}/kv_shape", np.array([0, 0, kv_heads, head_dim], np.int64))) + for tag, proj, shape_name in [ + ("", "q_proj", "q_shape"), + ("_1", "k_proj", "kv_shape"), + ("_2", "v_proj", "kv_shape"), + ]: + nodes.append( + helper.make_node( + "Reshape", + [f"{ap}/{proj}/MatMul_output_0", f"{ap}/{shape_name}"], + [f"{ap}/Reshape{tag}_output_0"], + name=f"{ap}/Reshape{tag}", + ) + ) + nodes.append( + helper.make_node( + "Transpose", + [f"{ap}/Reshape{tag}_output_0"], + [f"{ap}/Transpose{tag}_output_0"], + name=f"{ap}/Transpose{tag}", + perm=[0, 2, 1, 3], + ) + ) + + qt = f"{ap}/Transpose_output_0" + kt = f"{ap}/Transpose_1_output_0" + vt = f"{ap}/Transpose_2_output_0" + + # --- KV cache concat --- + nodes.append( + helper.make_node("Concat", ["past_key_values.0.key", kt], ["present.0.key"], name=f"{ap}/Concat_5", axis=2) + ) + nodes.append( + helper.make_node("Concat", 
["past_key_values.0.value", vt], ["present.0.value"], name=f"{ap}/Concat_6", axis=2) + ) + + # --- GQA repeat KV if needed --- + if kv_heads != num_heads: + reps = num_heads // kv_heads + inits.extend( + [ + _init(f"{ap}/rk/exp", np.array([1, reps, 1, 1], np.int64)), + _init(f"{ap}/rk/ax", np.array([2], np.int64)), + _init(f"{ap}/rk/rs", np.array([0, num_heads, -1, head_dim], np.int64)), + ] + ) + for t, src in [("k", "present.0.key"), ("v", "present.0.value")]: + nodes.append( + helper.make_node( + "Unsqueeze", [src, f"{ap}/rk/ax"], [f"{ap}/rk/{t}u"], name=f"{ap}/repeat_kv/{t}_unsqueeze" + ) + ) + nodes.append( + helper.make_node( + "Expand", [f"{ap}/rk/{t}u", f"{ap}/rk/exp"], [f"{ap}/rk/{t}e"], name=f"{ap}/repeat_kv/{t}_expand" + ) + ) + nodes.append( + helper.make_node( + "Reshape", [f"{ap}/rk/{t}e", f"{ap}/rk/rs"], [f"{ap}/rk/{t}r"], name=f"{ap}/repeat_kv/{t}_reshape" + ) + ) + k_final, v_final = f"{ap}/rk/kr", f"{ap}/rk/vr" + else: + k_final, v_final = "present.0.key", "present.0.value" + + # --- Scaled dot-product attention --- + nodes.append( + helper.make_node( + "Transpose", [k_final], [f"{ap}/Transpose_3_output_0"], name=f"{ap}/Transpose_3", perm=[0, 1, 3, 2] + ) + ) + scale_val = float(np.float16(1.0 / head_dim**0.5)) + nodes.append( + helper.make_node( + "Constant", + [], + [f"{ap}/scale_output_0"], + name=f"{ap}/scale", + value=numpy_helper.from_array(np.array(scale_val, dtype=np.float16), name=""), + ) + ) + nodes.append(helper.make_node("Mul", [qt, f"{ap}/scale_output_0"], [f"{ap}/Mul_8_output_0"], name=f"{ap}/Mul_8")) + nodes.append( + helper.make_node( + "MatMul", + [f"{ap}/Mul_8_output_0", f"{ap}/Transpose_3_output_0"], + [f"{ap}/MatMul_output_0"], + name=f"{ap}/MatMul", + ) + ) + nodes.append( + helper.make_node("Add", [f"{ap}/MatMul_output_0", "attn_bias"], [f"{ap}/Add_2_output_0"], name=f"{ap}/Add_2") + ) + nodes.append( + helper.make_node("Softmax", [f"{ap}/Add_2_output_0"], [f"{ap}/Softmax_output_0"], name=f"{ap}/Softmax", axis=-1) + ) + 
nodes.append( + helper.make_node( + "MatMul", [f"{ap}/Softmax_output_0", v_final], [f"{ap}/MatMul_1_output_0"], name=f"{ap}/MatMul_1" + ) + ) + + # --- Transpose + Reshape back --- + nodes.append( + helper.make_node( + "Transpose", + [f"{ap}/MatMul_1_output_0"], + [f"{ap}/Transpose_4_output_0"], + name=f"{ap}/Transpose_4", + perm=[0, 2, 1, 3], + ) + ) + inits.append(_init(f"{ap}/out_rs", np.array([0, 0, hidden_size], np.int64))) + nodes.append( + helper.make_node( + "Reshape", + [f"{ap}/Transpose_4_output_0", f"{ap}/out_rs"], + [f"{ap}/Reshape_7_output_0"], + name=f"{ap}/Reshape_7", + ) + ) + + # --- o_proj + residual + lm_head --- + nodes.append( + helper.make_node( + "MatMul", + [f"{ap}/Reshape_7_output_0", "model.layers.0.self_attn.o_proj.weight"], + [f"{ap}/o_proj/MatMul_output_0"], + name=f"{ap}/o_proj/MatMul", + ) + ) + hidden_out = "/model/embed_tokens/Gather_output_0" + nodes.append( + helper.make_node( + "Add", + [hidden_out, f"{ap}/o_proj/MatMul_output_0"], + ["/model/layers.0/residual_output_0"], + name="/model/layers.0/residual_add", + ) + ) + inits.append(_init("lm_head.weight", _fp16(hidden_size, VOCAB_SIZE))) + nodes.append( + helper.make_node( + "MatMul", ["/model/layers.0/residual_output_0", "lm_head.weight"], ["logits"], name="/lm_head/MatMul" + ) + ) + + graph = helper.make_graph(nodes, "llm_attn", graph_inputs, graph_outputs, initializer=inits) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + model.ir_version = 8 + return model + + +def _build_quantized_model() -> onnx.ModelProto: + """Create a quantized model with DequantizeLinear(8x16) feeding into MatMul.""" + qweight = numpy_helper.from_array(np.random.randint(-128, 127, (8, 16), dtype=np.int8), "qweight") + scale = numpy_helper.from_array(np.array([0.01], dtype=np.float32), "scale") + x_input = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 8]) + y_output = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 16]) + nodes = [ + 
helper.make_node("DequantizeLinear", ["qweight", "scale"], ["dq_out"], name="dql_0"), + helper.make_node("MatMul", ["X", "dq_out"], ["Y"], name="matmul_0"), + ] + graph = helper.make_graph(nodes, "quant_model", [x_input], [y_output], initializer=[qweight, scale]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + model.ir_version = 8 + return model + + +# --------------------------------------------------------------------------- +# Tests -- call the REAL surgery through Olive pass +# --------------------------------------------------------------------------- + + +def test_replace_gqa(tmp_path): + """Run real replace-gqa surgery through Olive pass, verify output structure.""" + hidden, heads, kv, hd = _get_config() + before_proto = _build_attention_model(hidden, heads, kv, hd) + + model_path = str(tmp_path / "model.onnx") + onnx.save(before_proto, model_path) + ov_model = ONNXModelHandler(model_path=model_path) + + config = { + "surgery_type": "replace-gqa", + "surgery_params": { + "hf_model_id": MODEL_ID, + "max_seq_len": 128, + "io_dtype": "float16", + }, + } + + p = create_pass_from_dict(NVModelOptGraphSurgery, config, disable_search=True) + result = p.run(ov_model, str(tmp_path / "output_gqa")) + result_proto = result.load_model() + + op_types = [n.op_type for n in result_proto.graph.node] + node_names = [n.name for n in result_proto.graph.node] + input_names = [i.name for i in result_proto.graph.input] + output_names = [o.name for o in result_proto.graph.output] + + assert "GroupQueryAttention" in op_types, f"Expected GQA node, got ops: {op_types}" + assert any("o_proj" in n for n in node_names), "o_proj MatMul should be preserved" + assert any("past_key_values" in n for n in input_names), f"Expected KV cache input, got: {input_names}" + assert any("present" in n for n in output_names), f"Expected present output, got: {output_names}" + assert "Softmax" not in op_types, "Old Softmax should be removed" + assert any("qkv_proj" in n 
for n in node_names), "Expected fused QKV MatMul" + assert "ReduceSum" in op_types, "Attention mask subgraph should be present" + + gqa_node = next(n for n in result_proto.graph.node if n.op_type == "GroupQueryAttention") + attrs = {a.name: (a.i if a.type == 2 else a.f) for a in gqa_node.attribute} + assert attrs["num_heads"] == heads + assert attrs["kv_num_heads"] == kv + assert attrs["do_rotary"] == 1 + + +def test_transpose_dq(tmp_path): + """Run real transpose-dq surgery through Olive pass, verify transposed weights.""" + before_proto = _build_quantized_model() + model_path = str(tmp_path / "model_quant.onnx") + onnx.save(before_proto, model_path) + ov_model = ONNXModelHandler(model_path=model_path) + + config = { + "surgery_type": "transpose-dq", + "surgery_params": {}, + } + + p = create_pass_from_dict(NVModelOptGraphSurgery, config, disable_search=True) + result = p.run(ov_model, str(tmp_path / "output_dq")) + result_proto = result.load_model() + + op_types = [n.op_type for n in result_proto.graph.node] + node_names = [n.name for n in result_proto.graph.node] + + assert "DequantizeLinear" in op_types, f"DequantizeLinear should still exist, got: {op_types}" + assert "Transpose" in op_types, f"Expected Transpose node, got: {op_types}" + assert any("transpose_back" in n for n in node_names), f"Expected *_transpose_back, got: {node_names}" + assert "MatMul" in op_types, "MatMul should still exist" + + for init in result_proto.graph.initializer: + if "transposed" in init.name: + assert list(init.dims) == [16, 8], f"Expected transposed shape [16,8], got {list(init.dims)}" From 4adc8bbbe3dae90dca88bb209b57b3987ec4c216 Mon Sep 17 00:00:00 2001 From: Hrishith Thadicherla Date: Tue, 31 Mar 2026 14:42:59 +0530 Subject: [PATCH 2/3] Added ModelOpt graph surgery pass in Olive Signed-off-by: Hrishith Thadicherla --- olive/passes/onnx/nvmo_graph_surgery.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olive/passes/onnx/nvmo_graph_surgery.py 
b/olive/passes/onnx/nvmo_graph_surgery.py index 646c8dd84..bb05e8d0e 100644 --- a/olive/passes/onnx/nvmo_graph_surgery.py +++ b/olive/passes/onnx/nvmo_graph_surgery.py @@ -113,7 +113,7 @@ def _run_for_config( # Call ModelOpt's unified dispatcher result = run_graph_surgery( surgery_name=surgery_type, - input_path=temp_input_path, + model_path=temp_input_path, output_path=temp_output_path, **surgery_params, ) From 51a22c647fe57a2178d687b93a8379bccbc936cc Mon Sep 17 00:00:00 2001 From: Hrishith Thadicherla Date: Tue, 31 Mar 2026 15:31:05 +0530 Subject: [PATCH 3/3] Updated documentation with usage Signed-off-by: Hrishith Thadicherla --- docs/source/features/onnx-transformations.md | 75 ++++++++++++++++++++ docs/source/reference/pass.rst | 6 ++ 2 files changed, 81 insertions(+) diff --git a/docs/source/features/onnx-transformations.md b/docs/source/features/onnx-transformations.md index 9dfcd1c67..6eee77681 100644 --- a/docs/source/features/onnx-transformations.md +++ b/docs/source/features/onnx-transformations.md @@ -2100,6 +2100,81 @@ Two cases are supported: ``` +## NVIDIA ModelOpt Graph Surgeries + +`NVModelOptGraphSurgery` provides access to graph-level transformations from [NVIDIA ModelOpt](https://github.com/NVIDIA/TensorRT-Model-Optimizer). These surgeries are designed for optimizing LLM and encoder-decoder ONNX models for deployment with ONNX Runtime and TensorRT. + +Available surgery types: + +| Surgery | Description | +|---------|-------------| +| `replace-gqa` | Replace standard multi-head attention with ORT's GroupQueryAttention (GQA) operator | +| `transpose-dq` | Transpose DequantizeLinear weights for column-major storage optimization | +| `add-cross-kv` | Add cross-attention KV cache outputs to Whisper encoder models | +| `convert-bf16` | Convert FP16 model initializers and I/O to BF16 | + +Please refer to [NVModelOptGraphSurgery](../reference/pass.rst#nvmodelopt_graph_surgery) for more details about the pass and its config parameters. 
+ +### Replace Attention with GQA + +Replaces the native multi-head attention subgraph (Q/K/V projections, RoPE, KV cache, scaled dot-product attention) with ORT's fused `GroupQueryAttention` operator. Supports models exported via Optimum or similar tools. + +```json +{ + "type": "NVModelOptGraphSurgery", + "surgery_type": "replace-gqa", + "surgery_params": { + "hf_model_id": "meta-llama/Llama-2-7b-hf", + "max_seq_len": 4096, + "io_dtype": "float16" + } +} +``` + +Key `surgery_params`: + +- `hf_model_id`: HuggingFace model ID (used to compute RoPE caches and read model config). +- `max_seq_len`: Maximum sequence length for the KV cache. +- `io_dtype`: I/O data type. Use `"float16"` or `"bfloat16"`. If `"bfloat16"` is specified and the model has FP16 initializers, they are automatically converted to BF16. + +### Transpose DequantizeLinear Weights + +Transposes quantized weight initializers feeding `DequantizeLinear` nodes and inserts a `Transpose` node before `MatMul`. This enables column-major weight storage for improved memory access patterns. + +```json +{ + "type": "NVModelOptGraphSurgery", + "surgery_type": "transpose-dq", + "surgery_params": {} +} +``` + +### Add Cross-Attention KV to Encoder + +Adds cross-attention key/value cache outputs to a Whisper encoder model, making it compatible with ONNX Runtime GenAI pipelines. + +```json +{ + "type": "NVModelOptGraphSurgery", + "surgery_type": "add-cross-kv", + "surgery_params": { + "hf_model_id": "openai/whisper-large-v3-turbo" + } +} +``` + +### Convert FP16 to BF16 + +Standalone precision conversion from FP16 to BF16 for all model initializers and I/O tensors. 
+ +```json +{ + "type": "NVModelOptGraphSurgery", + "surgery_type": "convert-bf16", + "surgery_params": {} +} +``` + ## ORT Performance Tuning ONNX Runtime provides high performance across a range of hardware options through its Execution Providers interface for different execution diff --git a/docs/source/reference/pass.rst b/docs/source/reference/pass.rst index 4ce0e2c1b..c1cdb1565 100644 --- a/docs/source/reference/pass.rst +++ b/docs/source/reference/pass.rst @@ -116,6 +116,12 @@ GraphSurgeries -------------------- .. autoconfigclass:: olive.passes.GraphSurgeries +.. _nvmodelopt_graph_surgery: + +NVModelOptGraphSurgery +---------------------- +.. autoconfigclass:: olive.passes.NVModelOptGraphSurgery + .. _matmulnbits_to_qdq: MatMulNBitsToQDQ