diff --git a/olive/olive_config.json b/olive/olive_config.json index 2748c39101..4168e0b2d3 100644 --- a/olive/olive_config.json +++ b/olive/olive_config.json @@ -626,6 +626,15 @@ "supported_algorithms": [ ], "supported_quantization_encodings": [ ], "run_on_target": true + }, + "VitisGenerateModelSD": { + "module_path": "olive.passes.onnx.vitis_ai.vitis_generate_model_sd.VitisGenerateModelSD", + "supported_providers": [ "CPUExecutionProvider" ], + "supported_accelerators": [ "cpu" ], + "supported_precisions": [ "int8" ], + "supported_algorithms": [ ], + "supported_quantization_encodings": [ ], + "run_on_target": true } }, "extra_dependencies": { diff --git a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py new file mode 100644 index 0000000000..9fc5059aec --- /dev/null +++ b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py @@ -0,0 +1,138 @@ +# ------------------------------------------------------------------------- +# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +# ------------------------------------------------------------------------- + +"""Olive Pass for Vitis NPU Stable Diffusion submodel generation. + +Accepts ONNX input only; run OnnxConversion to produce ONNX input model first, +then this pass runs generate_sd_model to generate NPU-ready models. +""" + +from __future__ import annotations + +import logging +import shutil +from pathlib import Path + +from model_generate import generate_model +from model_generate.recipes import get_supported_sd_model_types + +from olive.model import ONNXModelHandler +from olive.passes import Pass +from olive.passes.pass_config import BasePassConfig, PassConfigParam + +logger = logging.getLogger(__name__) +SUPPORTED_SD_MODEL_TYPES = get_supported_sd_model_types() + + +class VitisGenerateModelSD(Pass): + """Generate Vitis NPU-ready SD submodel from ONNX input. + + Use OnnxConversion to produce ONNX input model. + Optional resolutions to generate NPU-ready models. Default is [512x512]. + """ + + @classmethod + def _default_config(cls, accelerator_spec): + return { + "model_type": PassConfigParam( + type_=str, + required=True, + description=f"SD submodel type, must be one of {', '.join(SUPPORTED_SD_MODEL_TYPES)}.", + ), + "resolutions": PassConfigParam( + type_=list[str], + default_value=["512x512"], + required=False, + description="List of resolutions (e.g. ['512x512', '1024x1024']) Default is [512x512].", + ), + } + + @staticmethod + def _validate_model_type(model_type: str) -> None: + if model_type not in SUPPORTED_SD_MODEL_TYPES: + raise ValueError(f"model_type must be one of {', '.join(SUPPORTED_SD_MODEL_TYPES)}, got {model_type!r}") + + def _run_for_config( + self, + model: ONNXModelHandler, + config: BasePassConfig, + output_model_path: str, + ) -> ONNXModelHandler: + if not isinstance(model, ONNXModelHandler): + raise TypeError( + f"VitisGenerateModelSD requires ONNXModelHandler (run OnnxConversion first). Got {type(model).__name__}" + ) + model_type = config.model_type + self._validate_model_type(model_type) + + output_dir = Path(output_model_path) + if output_dir.suffix == ".onnx": + output_dir = output_dir.parent + output_dir.mkdir(parents=True, exist_ok=True) + + logger.info( + "[VitisGenerateModelSD] output_dir=%s, model_type=%s", + output_dir, + model_type, + ) + + onnx_input_path = self._resolve_onnx_input_path(model) + logger.info("[VitisGenerateModelSD] ONNX input path: %s", onnx_input_path) + + resolutions = getattr(config, "resolutions", None) + if resolutions: + logger.info( + "[VitisGenerateModelSD] Using resolutions: %s", + resolutions, + ) + + generate_model( + mode="sd", + input_model=str(onnx_input_path), + output_dir=str(output_dir), + extra_options={"model_type": model_type, "resolutions": ",".join(resolutions)}, + ) + + self._ensure_model_onnx(output_dir) + + return ONNXModelHandler( + model_path=str(output_dir), + onnx_file_name="model.onnx", + ) + + def _resolve_onnx_input_path(self, model: ONNXModelHandler) -> Path: + p = Path(model.model_path) + if p.is_file(): + return p + if p.is_dir(): + name = getattr(model, "onnx_file_name", None) + if name: + f = p / name + if f.exists(): + return f + onnx_files = list(p.glob("*.onnx")) + if onnx_files: + return onnx_files[0] + raise FileNotFoundError(f"No .onnx file found under {p}") + raise FileNotFoundError(f"Model path does not exist: {p}") + + def _ensure_model_onnx(self, output_dir: Path) -> None: + """Copy actual generate_sd_model output to output_dir/model.onnx if needed.""" + model_onnx = output_dir / "model.onnx" + if model_onnx.exists(): + return + optimized = output_dir / "optimized.onnx" + dd_replaced = output_dir / "dd" / "replaced.onnx" + if dd_replaced.exists(): + shutil.copy2(dd_replaced, model_onnx) + logger.info("[VitisGenerateModelSD] Wrote model.onnx from dd/replaced.onnx") + elif optimized.exists(): + shutil.copy2(optimized, model_onnx) + logger.info("[VitisGenerateModelSD] Wrote model.onnx from optimized.onnx") + else: + logger.warning( + "[VitisGenerateModelSD] No optimized.onnx or dd/replaced.onnx found under %s", + output_dir, + )