Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions olive/olive_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,15 @@
"supported_algorithms": [ ],
"supported_quantization_encodings": [ ],
"run_on_target": true
},
"VitisGenerateModelSD": {
"module_path": "olive.passes.onnx.vitis_ai.vitis_generate_model_sd.VitisGenerateModelSD",
"supported_providers": [ "CPUExecutionProvider" ],
"supported_accelerators": [ "cpu" ],
"supported_precisions": [ "int8" ],
"supported_algorithms": [ ],
"supported_quantization_encodings": [ ],
"run_on_target": true
}
},
"extra_dependencies": {
Expand Down
138 changes: 138 additions & 0 deletions olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# -------------------------------------------------------------------------
# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
# -------------------------------------------------------------------------

"""Olive Pass for Vitis NPU Stable Diffusion submodel generation.

Accepts ONNX input only: run OnnxConversion first to produce the ONNX input model.
This pass then calls model_generate.generate_model in "sd" mode to generate NPU-ready models.
"""

from __future__ import annotations

import logging
import shutil
from pathlib import Path

from model_generate import generate_model
from model_generate.recipes import get_supported_sd_model_types

from olive.model import ONNXModelHandler
from olive.passes import Pass
from olive.passes.pass_config import BasePassConfig, PassConfigParam

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# SD submodel types reported by model_generate; queried once, at import time.
# Used below for the `model_type` parameter description and for validation.
SUPPORTED_SD_MODEL_TYPES = get_supported_sd_model_types()


class VitisGenerateModelSD(Pass):
    """Generate a Vitis NPU-ready Stable Diffusion submodel from an ONNX input model.

    The input must already be an ``ONNXModelHandler`` (run OnnxConversion first).
    The pass calls ``generate_model(mode="sd", ...)`` and returns a handler that
    points at ``<output_dir>/model.onnx``.

    Config parameters:
        model_type: Required SD submodel type; must be one of ``SUPPORTED_SD_MODEL_TYPES``.
        resolutions: Optional list of resolution strings such as ``"512x512"``.
            Defaults to ``["512x512"]``.
    """

    # Fallback used when the config carries no (or an empty) ``resolutions`` value;
    # also the default declared in ``_default_config``.
    _DEFAULT_RESOLUTIONS = ("512x512",)

    @classmethod
    def _default_config(cls, accelerator_spec):
        """Return the configuration schema for this pass.

        ``accelerator_spec`` is accepted for interface compatibility and is not used here.
        """
        return {
            "model_type": PassConfigParam(
                type_=str,
                required=True,
                description=f"SD submodel type, must be one of {', '.join(SUPPORTED_SD_MODEL_TYPES)}.",
            ),
            "resolutions": PassConfigParam(
                type_=list[str],
                default_value=list(cls._DEFAULT_RESOLUTIONS),
                required=False,
                description="List of resolutions (e.g. ['512x512', '1024x1024']) Default is [512x512].",
            ),
        }

    @staticmethod
    def _validate_model_type(model_type: str) -> None:
        """Raise ``ValueError`` unless ``model_type`` is in ``SUPPORTED_SD_MODEL_TYPES``."""
        if model_type not in SUPPORTED_SD_MODEL_TYPES:
            raise ValueError(f"model_type must be one of {', '.join(SUPPORTED_SD_MODEL_TYPES)}, got {model_type!r}")

    @classmethod
    def _resolutions_option(cls, config: BasePassConfig) -> str:
        """Return the comma-separated resolutions string passed to ``generate_model``.

        Falls back to the default resolutions when the config value is missing,
        ``None`` or empty, so the join below never receives ``None``.
        """
        resolutions = getattr(config, "resolutions", None) or list(cls._DEFAULT_RESOLUTIONS)
        return ",".join(str(resolution) for resolution in resolutions)

    def _run_for_config(
        self,
        model: ONNXModelHandler,
        config: BasePassConfig,
        output_model_path: str,
    ) -> ONNXModelHandler:
        """Run SD model generation and return a handler for the generated model.

        Args:
            model: ONNX model to convert.
            config: Pass configuration providing ``model_type`` and, optionally, ``resolutions``.
            output_model_path: Output location. If it ends in ``.onnx`` its parent
                directory is used as the output directory.

        Returns:
            An ``ONNXModelHandler`` for ``model.onnx`` inside the output directory.

        Raises:
            TypeError: If ``model`` is not an ``ONNXModelHandler``.
            ValueError: If ``config.model_type`` is not a supported SD model type.
            FileNotFoundError: If no ONNX input file can be located for ``model``.
        """
        if not isinstance(model, ONNXModelHandler):
            raise TypeError(
                f"VitisGenerateModelSD requires ONNXModelHandler (run OnnxConversion first). Got {type(model).__name__}"
            )
        model_type = config.model_type
        self._validate_model_type(model_type)

        # The generator writes into a directory; strip a trailing "*.onnx" file name.
        output_dir = Path(output_model_path)
        if output_dir.suffix == ".onnx":
            output_dir = output_dir.parent
        output_dir.mkdir(parents=True, exist_ok=True)

        logger.info(
            "[VitisGenerateModelSD] output_dir=%s, model_type=%s",
            output_dir,
            model_type,
        )

        onnx_input_path = self._resolve_onnx_input_path(model)
        logger.info("[VitisGenerateModelSD] ONNX input path: %s", onnx_input_path)

        resolutions = self._resolutions_option(config)
        logger.info(
            "[VitisGenerateModelSD] Using resolutions: %s",
            resolutions,
        )

        generate_model(
            mode="sd",
            input_model=str(onnx_input_path),
            output_dir=str(output_dir),
            extra_options={"model_type": model_type, "resolutions": resolutions},
        )

        self._ensure_model_onnx(output_dir)

        return ONNXModelHandler(
            model_path=str(output_dir),
            onnx_file_name="model.onnx",
        )

    def _resolve_onnx_input_path(self, model: ONNXModelHandler) -> Path:
        """Locate the ``.onnx`` file backing ``model``.

        Resolution order: ``model.model_path`` itself when it is a file; otherwise,
        when it is a directory, ``model.onnx_file_name`` inside it if that exists,
        else the first ``*.onnx`` file in the directory (sorted by path so the
        choice is deterministic).

        Raises:
            FileNotFoundError: If the path does not exist, or is a directory with no ``.onnx`` file.
        """
        model_path = Path(model.model_path)
        if model_path.is_file():
            return model_path
        if not model_path.is_dir():
            raise FileNotFoundError(f"Model path does not exist: {model_path}")

        file_name = getattr(model, "onnx_file_name", None)
        if file_name:
            named_file = model_path / file_name
            if named_file.exists():
                return named_file

        # glob() yields entries in arbitrary order; sort for a reproducible pick.
        onnx_files = sorted(model_path.glob("*.onnx"))
        if onnx_files:
            return onnx_files[0]
        raise FileNotFoundError(f"No .onnx file found under {model_path}")

    def _ensure_model_onnx(self, output_dir: Path) -> None:
        """Make sure ``output_dir/model.onnx`` exists, copying a generated model if needed.

        If ``model.onnx`` is absent, the first existing candidate is copied to it,
        preferring ``dd/replaced.onnx`` over ``optimized.onnx``. If neither exists
        only a warning is logged, so ``model.onnx`` may still be missing afterwards.
        """
        model_onnx = output_dir / "model.onnx"
        if model_onnx.exists():
            return

        # Candidates in priority order: (path relative to output_dir, label used in the log).
        candidates = (
            (Path("dd") / "replaced.onnx", "dd/replaced.onnx"),
            (Path("optimized.onnx"), "optimized.onnx"),
        )
        for relative_path, label in candidates:
            source = output_dir / relative_path
            if source.exists():
                shutil.copy2(source, model_onnx)
                logger.info("[VitisGenerateModelSD] Wrote model.onnx from %s", label)
                return

        logger.warning(
            "[VitisGenerateModelSD] No optimized.onnx or dd/replaced.onnx found under %s",
            output_dir,
        )
Loading