From a54dbb6994e4f70dbb1c678b97777062f107f914 Mon Sep 17 00:00:00 2001 From: liujij Date: Mon, 16 Mar 2026 01:51:03 -0500 Subject: [PATCH 01/10] [feat] add vitis_generate_model_sd.py. --- .../onnx/vitis_ai/vitis_generate_model_sd.py | 169 ++++++++++++++++++ 1 file changed, 169 insertions(+) create mode 100644 olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py diff --git a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py new file mode 100644 index 0000000000..c47ab7ec0a --- /dev/null +++ b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py @@ -0,0 +1,169 @@ +# +# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +# + +"""Olive Pass for Vitis NPU Stable Diffusion submodel generation (UNet / VAE decoder). +Accepts ONNX input only; run OnnxConversion (e.g. from PyTorchModel + olive user_script) first, +then this pass runs generate_sd_model for preprocess + partition. +""" + +from __future__ import annotations + +import logging +import shutil +from pathlib import Path +from typing import Optional + +from olive.model import ONNXModelHandler +from olive.passes import Pass +from olive.passes.pass_config import BasePassConfig, PassConfigParam + +logger = logging.getLogger(__name__) + + +def _get_sd_registry(): + """Import registry from npu_model_gen to keep model_type choices in sync.""" + from model_generate import _SD_CONFIG_REGISTRY + return _SD_CONFIG_REGISTRY + + +def _build_fixed_shapes(dim_param: Optional[list], dim_value: Optional[list]) -> Optional[list[str]]: + """Build --fixed-shapes style list (e.g. ['batch=1', 'height=64']) from dim_param and dim_value.""" + if not dim_param or not dim_value: + return None + if len(dim_param) != len(dim_value): + raise ValueError("dim_param and dim_value must have the same length.") + return [f"{p}={v}" for p, v in zip(dim_param, dim_value)] + + +class VitisGenerateModelSD(Pass): + """Generate Vitis NPU-ready SD submodel (unet or vae_decoder) from ONNX input. + Use OnnxConversion (PyTorchModel + olive user_script) upstream to produce ONNX. + Optional dim_param / dim_value override the default fixed shapes used in preprocess (like DynamicToFixedShape). + """ + + @classmethod + def _default_config(cls, accelerator_spec): + registry = _get_sd_registry() + return { + "model_type": PassConfigParam( + type_=str, + required=True, + description="SD submodel type, must be a key from SD config registry (e.g. sd_unet, sd_vae_decoder, sd_vae_encoder).", + ), + "fixed_shapes_dim_param": PassConfigParam( + type_=list, + default_value=None, + required=False, + description=( + "Symbolic dimension names for fixed shapes (e.g. ['batch','channels','height','width']). " + ), + ), + "fixed_shapes_dim_value": PassConfigParam( + type_=list, + default_value=None, + required=False, + description=( + "Defines the values for dimensions listed in fixed_shapes_dim_param (e.g., [1, 4, 64, 64]). " + "Use 'x' to preserve a dynamic dimension (e.g., [1, 4, 'x', 'x']). " + "The length must match fixed_shapes_dim_param if specified." + ), + ), + } + + @staticmethod + def _validate_model_type(model_type: str) -> None: + registry = _get_sd_registry() + if model_type not in registry: + raise ValueError( + f"model_type must be one of {list(registry.keys())}, got {model_type!r}" + ) + + def _run_for_config( + self, + model: ONNXModelHandler, + config: BasePassConfig, + output_model_path: str, + ) -> ONNXModelHandler: + if not isinstance(model, ONNXModelHandler): + raise TypeError( + "VitisGenerateModelSD requires ONNXModelHandler (run OnnxConversion first). " + f"Got {type(model).__name__}" + ) + model_type = config.model_type + self._validate_model_type(model_type) + + output_dir = Path(output_model_path) + if output_dir.suffix == ".onnx": + output_dir = output_dir.parent + output_dir.mkdir(parents=True, exist_ok=True) + + logger.info( + "[VitisGenerateModelSD] output_dir=%s, model_type=%s", + output_dir, + model_type, + ) + + onnx_input_path = self._resolve_onnx_input_path(model) + logger.info("[VitisGenerateModelSD] ONNX input path: %s", onnx_input_path) + + fixed_shapes = _build_fixed_shapes( + getattr(config, "fixed_shapes_dim_param", None), getattr(config, "fixed_shapes_dim_value", None) + ) + if fixed_shapes: + logger.info( + "[VitisGenerateModelSD] Overriding fixed shapes: %s", + fixed_shapes, + ) + + from model_generate import generate_sd_model + + generate_sd_model( + input_model=str(onnx_input_path), + output_dir=str(output_dir), + model_type=model_type, + fixed_shapes=fixed_shapes, + ) + + self._ensure_model_onnx(output_dir) + + return ONNXModelHandler( + model_path=str(output_dir), + onnx_file_name="model.onnx", + ) + + def _resolve_onnx_input_path(self, model: ONNXModelHandler) -> Path: + p = Path(model.model_path) + if p.is_file(): + return p + if p.is_dir(): + name = getattr(model, "onnx_file_name", None) + if name: + f = p / name + if f.exists(): + return f + onnx_files = list(p.glob("*.onnx")) + if onnx_files: + return onnx_files[0] + raise FileNotFoundError(f"No .onnx file found under {p}") + raise FileNotFoundError(f"Model path does not exist: {p}") + + def _ensure_model_onnx(self, output_dir: Path) -> None: + """Copy actual generate_sd_model output to output_dir/model.onnx if needed.""" + model_onnx = output_dir / "model.onnx" + if model_onnx.exists(): + return + optimized = output_dir / "optimized.onnx" + dd_replaced = output_dir / "dd" / "replaced.onnx" + if dd_replaced.exists(): + shutil.copy2(dd_replaced, model_onnx) + logger.info("[VitisGenerateModelSD] Wrote model.onnx from dd/replaced.onnx") + elif optimized.exists(): + shutil.copy2(optimized, model_onnx) + logger.info("[VitisGenerateModelSD] Wrote model.onnx from optimized.onnx") + else: + logger.warning( + "[VitisGenerateModelSD] No optimized.onnx or dd/replaced.onnx found under %s", + output_dir, + ) From 0a98c0dc12daf868f4029c531d483631d8b5adaa Mon Sep 17 00:00:00 2001 From: liujij Date: Tue, 17 Mar 2026 03:08:34 -0500 Subject: [PATCH 02/10] up olive_config.json. --- olive/olive_config.json | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/olive/olive_config.json b/olive/olive_config.json index 2748c39101..dcdf4af072 100644 --- a/olive/olive_config.json +++ b/olive/olive_config.json @@ -626,6 +626,15 @@ "supported_algorithms": [ ], "supported_quantization_encodings": [ ], "run_on_target": true + }, + "VitisGenerateModelSD": { + "module_path": "olive.passes.onnx.vitis_ai.vitis_generate_model_sd.VitisGenerateModelSD", + "supported_providers": [ "CPUExecutionProvider" ], + "supported_accelerators": [ "cpu" ], + "supported_precisions": [ "bf16", "bfp16" ], + "supported_algorithms": [ ], + "supported_quantization_encodings": [ ], + "run_on_target": true } }, "extra_dependencies": { From 6f675d4f82d1796c1abf22cf6bbbeb4bc8c2fe38 Mon Sep 17 00:00:00 2001 From: liujij Date: Tue, 24 Mar 2026 04:54:05 -0500 Subject: [PATCH 03/10] update codes. --- olive/olive_config.json | 2 +- .../onnx/vitis_ai/vitis_generate_model_sd.py | 46 +++++-------------- 2 files changed, 12 insertions(+), 36 deletions(-) diff --git a/olive/olive_config.json b/olive/olive_config.json index dcdf4af072..4168e0b2d3 100644 --- a/olive/olive_config.json +++ b/olive/olive_config.json @@ -631,7 +631,7 @@ "module_path": "olive.passes.onnx.vitis_ai.vitis_generate_model_sd.VitisGenerateModelSD", "supported_providers": [ "CPUExecutionProvider" ], "supported_accelerators": [ "cpu" ], - "supported_precisions": [ "bf16", "bfp16" ], + "supported_precisions": [ "int8" ], "supported_algorithms": [ ], "supported_quantization_encodings": [ ], "run_on_target": true diff --git a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py index c47ab7ec0a..e5bd88cd11 100644 --- a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py +++ b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py @@ -27,20 +27,10 @@ def _get_sd_registry(): from model_generate import _SD_CONFIG_REGISTRY return _SD_CONFIG_REGISTRY - -def _build_fixed_shapes(dim_param: Optional[list], dim_value: Optional[list]) -> Optional[list[str]]: - """Build --fixed-shapes style list (e.g. ['batch=1', 'height=64']) from dim_param and dim_value.""" - if not dim_param or not dim_value: - return None - if len(dim_param) != len(dim_value): - raise ValueError("dim_param and dim_value must have the same length.") - return [f"{p}={v}" for p, v in zip(dim_param, dim_value)] - - class VitisGenerateModelSD(Pass): """Generate Vitis NPU-ready SD submodel (unet or vae_decoder) from ONNX input. Use OnnxConversion (PyTorchModel + olive user_script) upstream to produce ONNX. - Optional dim_param / dim_value override the default fixed shapes used in preprocess (like DynamicToFixedShape). + Optional resolutions override the default fixed shapes used in preprocess. Default is [512x512]. """ @classmethod @@ -50,25 +40,13 @@ def _default_config(cls, accelerator_spec): "model_type": PassConfigParam( type_=str, required=True, - description="SD submodel type, must be a key from SD config registry (e.g. sd_unet, sd_vae_decoder, sd_vae_encoder).", + description=f"SD submodel type, must be a key from SD config registry (e.g. {list(registry.keys())}).", ), - "fixed_shapes_dim_param": PassConfigParam( - type_=list, - default_value=None, + "resolutions": PassConfigParam( + type_=list[str], + default_value=["512x512"], required=False, - description=( - "Symbolic dimension names for fixed shapes (e.g. ['batch','channels','height','width']). " - ), - ), - "fixed_shapes_dim_value": PassConfigParam( - type_=list, - default_value=None, - required=False, - description=( - "Defines the values for dimensions listed in fixed_shapes_dim_param (e.g., [1, 4, 64, 64]). " - "Use 'x' to preserve a dynamic dimension (e.g., [1, 4, 'x', 'x']). " - "The length must match fixed_shapes_dim_param if specified." - ), + description="List of resolutions (e.g. ['512x512', '1024x1024']) Default is [512x512].", ), } @@ -108,13 +86,11 @@ def _run_for_config( onnx_input_path = self._resolve_onnx_input_path(model) logger.info("[VitisGenerateModelSD] ONNX input path: %s", onnx_input_path) - fixed_shapes = _build_fixed_shapes( - getattr(config, "fixed_shapes_dim_param", None), getattr(config, "fixed_shapes_dim_value", None) - ) - if fixed_shapes: + resolutions = getattr(config, "resolutions", None) + if resolutions: logger.info( - "[VitisGenerateModelSD] Overriding fixed shapes: %s", - fixed_shapes, + "[VitisGenerateModelSD] Using resolutions: %s", + resolutions, ) from model_generate import generate_sd_model @@ -123,7 +99,7 @@ def _run_for_config( input_model=str(onnx_input_path), output_dir=str(output_dir), model_type=model_type, - fixed_shapes=fixed_shapes, + resolutions=resolutions, ) self._ensure_model_onnx(output_dir) From 342c43da2c8cc9ec2afd44c1bbf0649bdfb26dc8 Mon Sep 17 00:00:00 2001 From: liujij Date: Wed, 25 Mar 2026 04:24:05 -0500 Subject: [PATCH 04/10] up. --- olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py index e5bd88cd11..1d17806a4b 100644 --- a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py +++ b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py @@ -24,8 +24,8 @@ def _get_sd_registry(): """Import registry from npu_model_gen to keep model_type choices in sync.""" - from model_generate import _SD_CONFIG_REGISTRY - return _SD_CONFIG_REGISTRY + import model_generate + return model_generate.SUPPORTED_SD_MODEL_TYPES class VitisGenerateModelSD(Pass): """Generate Vitis NPU-ready SD submodel (unet or vae_decoder) from ONNX input. @@ -40,7 +40,7 @@ def _default_config(cls, accelerator_spec): "model_type": PassConfigParam( type_=str, required=True, - description=f"SD submodel type, must be a key from SD config registry (e.g. {list(registry.keys())}).", + description=f"SD submodel type, must be one of {', '.join(registry)}.", ), "resolutions": PassConfigParam( type_=list[str], From f8b9e49d076487171185a5831a3d2fb2f77d8fdf Mon Sep 17 00:00:00 2001 From: liujij Date: Wed, 25 Mar 2026 04:35:44 -0500 Subject: [PATCH 05/10] lint. --- .../onnx/vitis_ai/vitis_generate_model_sd.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py index 1d17806a4b..22112e321a 100644 --- a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py +++ b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py @@ -4,6 +4,7 @@ # """Olive Pass for Vitis NPU Stable Diffusion submodel generation (UNet / VAE decoder). + Accepts ONNX input only; run OnnxConversion (e.g. from PyTorchModel + olive user_script) first, then this pass runs generate_sd_model for preprocess + partition. """ @@ -13,7 +14,6 @@ import logging import shutil from pathlib import Path -from typing import Optional from olive.model import ONNXModelHandler from olive.passes import Pass @@ -25,12 +25,15 @@ def _get_sd_registry(): """Import registry from npu_model_gen to keep model_type choices in sync.""" import model_generate + return model_generate.SUPPORTED_SD_MODEL_TYPES + class VitisGenerateModelSD(Pass): """Generate Vitis NPU-ready SD submodel (unet or vae_decoder) from ONNX input. - Use OnnxConversion (PyTorchModel + olive user_script) upstream to produce ONNX. - Optional resolutions override the default fixed shapes used in preprocess. Default is [512x512]. + + Use OnnxConversion to produce ONNX input model. + Optional resolutions to generate NPU-ready models. Default is [512x512]. """ @classmethod @@ -54,9 +57,7 @@ def _default_config(cls, accelerator_spec): def _validate_model_type(model_type: str) -> None: registry = _get_sd_registry() if model_type not in registry: - raise ValueError( - f"model_type must be one of {list(registry.keys())}, got {model_type!r}" - ) + raise ValueError(f"model_type must be one of {list(registry.keys())}, got {model_type!r}") def _run_for_config( self, @@ -66,8 +67,7 @@ def _run_for_config( ) -> ONNXModelHandler: if not isinstance(model, ONNXModelHandler): raise TypeError( - "VitisGenerateModelSD requires ONNXModelHandler (run OnnxConversion first). " - f"Got {type(model).__name__}" + f"VitisGenerateModelSD requires ONNXModelHandler (run OnnxConversion first). Got {type(model).__name__}" ) model_type = config.model_type self._validate_model_type(model_type) From caa483cae0edc6f8a4f44370779493d758179810 Mon Sep 17 00:00:00 2001 From: liujij Date: Thu, 26 Mar 2026 03:12:06 -0500 Subject: [PATCH 06/10] up. --- olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py index 22112e321a..9fe864171e 100644 --- a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py +++ b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py @@ -1,12 +1,12 @@ # -# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT # """Olive Pass for Vitis NPU Stable Diffusion submodel generation (UNet / VAE decoder). -Accepts ONNX input only; run OnnxConversion (e.g. from PyTorchModel + olive user_script) first, -then this pass runs generate_sd_model for preprocess + partition. +Accepts ONNX input only; run OnnxConversion to produce ONNX input model first, +then this pass runs generate_sd_model to generate NPU-ready models. """ from __future__ import annotations @@ -57,7 +57,7 @@ def _default_config(cls, accelerator_spec): def _validate_model_type(model_type: str) -> None: registry = _get_sd_registry() if model_type not in registry: - raise ValueError(f"model_type must be one of {list(registry.keys())}, got {model_type!r}") + raise ValueError(f"model_type must be one of {', '.join(registry)}, got {model_type!r}") def _run_for_config( self, @@ -67,7 +67,8 @@ def _run_for_config( ) -> ONNXModelHandler: if not isinstance(model, ONNXModelHandler): raise TypeError( - f"VitisGenerateModelSD requires ONNXModelHandler (run OnnxConversion first). Got {type(model).__name__}" + "VitisGenerateModelSD requires ONNXModelHandler (run OnnxConversion first). " + f"Got {type(model).__name__}" ) model_type = config.model_type self._validate_model_type(model_type) From e3a73fb8f8cffd6f5755f254529657324bf3d306 Mon Sep 17 00:00:00 2001 From: liujij Date: Thu, 26 Mar 2026 03:17:15 -0500 Subject: [PATCH 07/10] lint. --- olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py index 9fe864171e..ca9a3216f8 100644 --- a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py +++ b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py @@ -38,12 +38,11 @@ class VitisGenerateModelSD(Pass): @classmethod def _default_config(cls, accelerator_spec): - registry = _get_sd_registry() return { "model_type": PassConfigParam( type_=str, required=True, - description=f"SD submodel type, must be one of {', '.join(registry)}.", + description=f"SD submodel type, must be one of {', '.join(_get_sd_registry())}.", ), "resolutions": PassConfigParam( type_=list[str], @@ -67,8 +66,7 @@ def _run_for_config( ) -> ONNXModelHandler: if not isinstance(model, ONNXModelHandler): raise TypeError( - "VitisGenerateModelSD requires ONNXModelHandler (run OnnxConversion first). " - f"Got {type(model).__name__}" + f"VitisGenerateModelSD requires ONNXModelHandler (run OnnxConversion first). Got {type(model).__name__}" ) model_type = config.model_type self._validate_model_type(model_type) From e0966e556b331f4d058282e6e9febee9e7fb9c75 Mon Sep 17 00:00:00 2001 From: liujij Date: Thu, 26 Mar 2026 22:52:52 -0500 Subject: [PATCH 08/10] ruff. --- .../onnx/vitis_ai/vitis_generate_model_sd.py | 22 ++++++------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py index ca9a3216f8..ab1a7517a6 100644 --- a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py +++ b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: MIT # -"""Olive Pass for Vitis NPU Stable Diffusion submodel generation (UNet / VAE decoder). +"""Olive Pass for Vitis NPU Stable Diffusion submodel generation. Accepts ONNX input only; run OnnxConversion to produce ONNX input model first, then this pass runs generate_sd_model to generate NPU-ready models. @@ -15,6 +15,8 @@ import shutil from pathlib import Path +from model_generate import SUPPORTED_SD_MODEL_TYPES, generate_sd_model + from olive.model import ONNXModelHandler from olive.passes import Pass from olive.passes.pass_config import BasePassConfig, PassConfigParam @@ -22,15 +24,8 @@ logger = logging.getLogger(__name__) -def _get_sd_registry(): - """Import registry from npu_model_gen to keep model_type choices in sync.""" - import model_generate - - return model_generate.SUPPORTED_SD_MODEL_TYPES - - class VitisGenerateModelSD(Pass): - """Generate Vitis NPU-ready SD submodel (unet or vae_decoder) from ONNX input. + """Generate Vitis NPU-ready SD submodel from ONNX input. Use OnnxConversion to produce ONNX input model. Optional resolutions to generate NPU-ready models. Default is [512x512]. @@ -42,7 +37,7 @@ def _default_config(cls, accelerator_spec): "model_type": PassConfigParam( type_=str, required=True, - description=f"SD submodel type, must be one of {', '.join(_get_sd_registry())}.", + description=f"SD submodel type, must be one of {', '.join(SUPPORTED_SD_MODEL_TYPES)}.", ), "resolutions": PassConfigParam( type_=list[str], @@ -54,9 +49,8 @@ def _default_config(cls, accelerator_spec): @staticmethod def _validate_model_type(model_type: str) -> None: - registry = _get_sd_registry() - if model_type not in registry: - raise ValueError(f"model_type must be one of {', '.join(registry)}, got {model_type!r}") + if model_type not in SUPPORTED_SD_MODEL_TYPES: + raise ValueError(f"model_type must be one of {', '.join(SUPPORTED_SD_MODEL_TYPES)}, got {model_type!r}") def _run_for_config( self, @@ -92,8 +86,6 @@ def _run_for_config( resolutions, ) - from model_generate import generate_sd_model - generate_sd_model( input_model=str(onnx_input_path), output_dir=str(output_dir), From c168335a3e83b9a2bbe737fbda6d8ef90713d091 Mon Sep 17 00:00:00 2001 From: liujij Date: Thu, 26 Mar 2026 23:11:23 -0500 Subject: [PATCH 09/10] ruff. --- olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py index ab1a7517a6..4b4773fb4a 100644 --- a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py +++ b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py @@ -1,7 +1,7 @@ -# +# ------------------------------------------------------------------------- # Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT -# +# ------------------------------------------------------------------------- """Olive Pass for Vitis NPU Stable Diffusion submodel generation. From 70cf7d2a168647ec199111c337622908b6946674 Mon Sep 17 00:00:00 2001 From: liujij Date: Mon, 30 Mar 2026 04:28:05 -0500 Subject: [PATCH 10/10] update olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py with new model_generate. --- olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py index 4b4773fb4a..9fc5059aec 100644 --- a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py +++ b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py @@ -15,13 +15,15 @@ import shutil from pathlib import Path -from model_generate import SUPPORTED_SD_MODEL_TYPES, generate_sd_model +from model_generate import generate_model +from model_generate.recipes import get_supported_sd_model_types from olive.model import ONNXModelHandler from olive.passes import Pass from olive.passes.pass_config import BasePassConfig, PassConfigParam logger = logging.getLogger(__name__) +SUPPORTED_SD_MODEL_TYPES = get_supported_sd_model_types() class VitisGenerateModelSD(Pass): @@ -86,11 +88,11 @@ def _run_for_config( resolutions, ) - generate_sd_model( + generate_model( + mode="sd", input_model=str(onnx_input_path), output_dir=str(output_dir), - model_type=model_type, - resolutions=resolutions, + extra_options={"model_type": model_type, "resolutions": ",".join(resolutions)}, ) self._ensure_model_onnx(output_dir)