From 445dcff7ba1f9b3f5c55439f0528e51ca14bd323 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Thu, 2 Apr 2026 15:57:04 -0700 Subject: [PATCH 1/5] Add ORT_UNIT_TEST_MAIN_DYNAMIC_PLUGIN_EP_CONFIG_JSON_FILE env var (#27945) ### Description Add support for specifying dynamic plugin EP configuration via a JSON file path in the ORT_UNIT_TEST_MAIN_DYNAMIC_PLUGIN_EP_CONFIG_JSON_FILE environment variable. This is mutually exclusive with the specifying inline JSON using the existing ORT_UNIT_TEST_MAIN_DYNAMIC_PLUGIN_EP_CONFIG_JSON environment variable. ### Motivation and Context Allow more flexibility in specifying configuration. It may be impractical to put everything in an environment variable. --- onnxruntime/test/unittest_main/test_main.cc | 31 ++++++++++++++++++--- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/onnxruntime/test/unittest_main/test_main.cc b/onnxruntime/test/unittest_main/test_main.cc index 4afe9dc51b9e5..d151955c7549c 100644 --- a/onnxruntime/test/unittest_main/test_main.cc +++ b/onnxruntime/test/unittest_main/test_main.cc @@ -3,12 +3,14 @@ #include #include +#include +#include +#include #include #include #include #include #ifdef _WIN32 -#include #include #endif @@ -51,6 +53,9 @@ constexpr const char* kLogLevel = "ORT_UNIT_TEST_MAIN_LOG_LEVEL"; // Specify dynamic plugin EP configuration JSON. // Refer to `onnxruntime::test::dynamic_plugin_ep_infra::ParseInitializationConfig()` for more information. constexpr const char* kDynamicPluginEpConfigJson = "ORT_UNIT_TEST_MAIN_DYNAMIC_PLUGIN_EP_CONFIG_JSON"; +// Specify a file path from which to read dynamic plugin EP configuration JSON. +// Mutually exclusive with kDynamicPluginEpConfigJson. +constexpr const char* kDynamicPluginEpConfigJsonFile = "ORT_UNIT_TEST_MAIN_DYNAMIC_PLUGIN_EP_CONFIG_JSON_FILE"; #endif // defined(TEST_MAIN_ENABLE_DYNAMIC_PLUGIN_EP_USAGE) } // namespace env_var_names @@ -79,9 +84,27 @@ extern "C" void ortenv_setup() { #if defined(TEST_MAIN_ENABLE_DYNAMIC_PLUGIN_EP_USAGE) { namespace dynamic_plugin_ep_infra = onnxruntime::test::dynamic_plugin_ep_infra; - if (auto dynamic_plugin_ep_config_json = onnxruntime::ParseEnvironmentVariable( - env_var_names::kDynamicPluginEpConfigJson); - dynamic_plugin_ep_config_json.has_value()) { + + auto dynamic_plugin_ep_config_json = onnxruntime::ParseEnvironmentVariable( + env_var_names::kDynamicPluginEpConfigJson); + auto dynamic_plugin_ep_config_json_file = onnxruntime::ParseEnvironmentVariable( + env_var_names::kDynamicPluginEpConfigJsonFile); + + ORT_ENFORCE(!dynamic_plugin_ep_config_json.has_value() || !dynamic_plugin_ep_config_json_file.has_value(), + "Only one of ", env_var_names::kDynamicPluginEpConfigJson, + " and ", env_var_names::kDynamicPluginEpConfigJsonFile, + " should be set, not both."); + + if (dynamic_plugin_ep_config_json_file.has_value()) { + const auto& config_file_path = *dynamic_plugin_ep_config_json_file; + std::cout << "Reading dynamic plugin EP configuration from file: " << config_file_path << "\n"; + std::ifstream config_file{config_file_path}; + ORT_ENFORCE(config_file, "Failed to open dynamic plugin EP configuration file: ", config_file_path); + dynamic_plugin_ep_config_json.emplace( + std::istreambuf_iterator{config_file}, std::istreambuf_iterator{}); + } + + if (dynamic_plugin_ep_config_json.has_value()) { std::cout << "Initializing dynamic plugin EP infrastructure with configuration:\n" << *dynamic_plugin_ep_config_json << "\n"; dynamic_plugin_ep_infra::InitializationConfig config{}; From a18e5b982f6c4b27bbc5ab3dcf55e23099d4d653 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Thu, 2 Apr 2026 19:07:05 -0700 Subject: [PATCH 2/5] Cleanup for op schema API tests for plugin EPs (#27921) ### Description Address some good leftover comments from PR that added EP APIs to retrieve operator schemas: https://github.com/microsoft/onnxruntime/pull/27713 ### Motivation and Context Clean up as promised --- .../core/session/onnxruntime_cxx_api.h | 17 +++++------------ .../test/framework/ep_plugin_provider_test.cc | 17 +++++------------ 2 files changed, 10 insertions(+), 24 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index a938688fcfd5a..e457a2a57065e 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -3613,12 +3613,6 @@ struct KernelRegistry : detail::Base { }; namespace detail { -/** \brief Non-owning wrapper around a `const OrtOpSchemaTypeConstraint*`. - * - * Holds a single type constraint from an operator schema, providing access to - * the constraint's name, allowed data types, and associated input/output indices. - * This is a non-owning view — the lifetime is tied to the parent OrtOpSchema. - */ template struct OpSchemaTypeConstraintImpl : Base { using B = Base; @@ -3639,15 +3633,11 @@ struct OpSchemaTypeConstraintImpl : Base { } // namespace detail /// Non-owning wrapper around a `const OrtOpSchemaTypeConstraint*`. +/// Holds a single type constraint from an operator schema, providing access to +/// the constraint's name, allowed data types, and associated input/output indices. using ConstOpSchemaTypeConstraint = detail::OpSchemaTypeConstraintImpl>; namespace detail { -/** \brief Owning wrapper around an `OrtOpSchema*`. - * - * Provides access to operator schema metadata such as version, input/output names, - * and type constraints. The underlying OrtOpSchema is owned by this wrapper and - * released automatically on destruction. - */ template struct OpSchemaImpl : Base { using B = Base; @@ -3685,6 +3675,9 @@ struct OpSchemaImpl : Base { } // namespace detail /// Owning wrapper around an `OrtOpSchema*`. +/// Provides access to operator schema metadata such as version, input/output names, +/// and type constraints. The underlying OrtOpSchema is owned by this wrapper and +/// released automatically on destruction. using OpSchema = detail::OpSchemaImpl; /// \brief Get an operator schema from the global schema registry. diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index da958ba6fc970..9640d94aebe58 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -876,11 +876,8 @@ TEST(OpSchemaTypeConstraintTest, Add_SingleConstraint) { // T should allow tensor(float) and tensor(double) among others auto allowed_types = tc.GetAllowedTypes(); EXPECT_GT(allowed_types.size(), 1u); - auto has_type = [&](const char* t) { - return std::find(allowed_types.begin(), allowed_types.end(), t) != allowed_types.end(); - }; - EXPECT_TRUE(has_type("tensor(float)")) << "Expected T to allow tensor(float)"; - EXPECT_TRUE(has_type("tensor(double)")) << "Expected T to allow tensor(double)"; + EXPECT_THAT(allowed_types, ::testing::Contains("tensor(float)")) << "Expected T to allow tensor(float)"; + EXPECT_THAT(allowed_types, ::testing::Contains("tensor(double)")) << "Expected T to allow tensor(double)"; // Both inputs use T auto input_indices = tc.GetInputIndices(); @@ -921,22 +918,18 @@ TEST(OpSchemaTypeConstraintTest, LSTM_MultipleConstraints) { ASSERT_NE(t_ptr, nullptr) << "Expected to find type constraint 'T'"; ASSERT_NE(t1_ptr, nullptr) << "Expected to find type constraint 'T1'"; - auto has_type = [](gsl::span types, const char* t) { - return std::find(types.begin(), types.end(), t) != types.end(); - }; - // T should include tensor(float) and tensor(double) auto t_types = t_tc.GetAllowedTypes(); EXPECT_GT(t_types.size(), 0u); - EXPECT_TRUE(has_type(t_types, "tensor(float)")) << "Expected T to allow tensor(float)"; - EXPECT_TRUE(has_type(t_types, "tensor(double)")) << "Expected T to allow tensor(double)"; + EXPECT_THAT(t_types, ::testing::Contains("tensor(float)")) << "Expected T to allow tensor(float)"; + EXPECT_THAT(t_types, ::testing::Contains("tensor(double)")) << "Expected T to allow tensor(double)"; // T1 should include tensor(int32) (sequence_lens is int32) auto t1_types = t1_tc.GetAllowedTypes(); EXPECT_GT(t1_types.size(), 0u); // T1 is for sequence_lens which is int32 - EXPECT_TRUE(has_type(t1_types, "tensor(int32)")) << "Expected T1 to allow tensor(int32)"; + EXPECT_THAT(t1_types, ::testing::Contains("tensor(int32)")) << "Expected T1 to allow tensor(int32)"; // T should map to inputs X (0), W (1), R (2), B (3), initial_h (5), initial_c (6), P (7) auto t_inputs = t_tc.GetInputIndices(); From 5b2c0da33d3d86fbf6e2aafaf5274ccee21b6b3b Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Fri, 3 Apr 2026 10:28:49 -0700 Subject: [PATCH 3/5] [WebGPU EP] Support Conv3D (#27917) ### Description This pull request adds support for Conv3D operations to the WebGPU execution provider in ONNX Runtime. The main changes include implementing a new naive Conv3D shader, updating the convolution logic to handle 3D convolutions, and enabling relevant tests for Conv3D on WebGPU. Grouped Conv3D is not yet supported. **Conv3D WebGPU support:** * Added a new `Conv3DNaiveProgram` class (`conv3d_naive.h`, `conv3d_naive.cc`) that implements a per-element Conv3D shader for WebGPU, supporting both "channels last" and "channels first" layouts, with optional bias and activation. * Updated the main convolution logic in `conv.cc` to detect 5D tensors (Conv3D), construct the appropriate shader program, and pass spatial/stride/dilation parameters as uniforms. Grouped Conv3D is explicitly disallowed for now. * Included the new `conv3d_naive.h` header in the main convolution implementation. **Test coverage:** * Enabled Conv3D tests for the WebGPU provider by removing it from the excluded execution providers in several Conv3D test cases (`conv_op_test.cc`). * Added a note to the Conv3D fp16 test indicating that enabling it for WebGPU will require additional infrastructure to conditionally skip based on device capabilities. ### Motivation and Context Support additional cases in WebGPU EP Conv kernel. --- onnxruntime/core/providers/webgpu/nn/conv.cc | 39 +++- .../core/providers/webgpu/nn/conv3d_naive.cc | 174 ++++++++++++++++++ .../core/providers/webgpu/nn/conv3d_naive.h | 34 ++++ .../test/providers/cpu/nn/conv_fp16_test.cc | 2 + .../test/providers/cpu/nn/conv_op_test.cc | 6 +- 5 files changed, 250 insertions(+), 5 deletions(-) create mode 100644 onnxruntime/core/providers/webgpu/nn/conv3d_naive.cc create mode 100644 onnxruntime/core/providers/webgpu/nn/conv3d_naive.h diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc index 697428e1ce140..c2a8896b84a7e 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/providers/webgpu/nn/conv.h" #include "core/providers/webgpu/nn/conv2d_mm.h" +#include "core/providers/webgpu/nn/conv3d_naive.h" #include "core/providers/webgpu/nn/im2col_matmul.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" @@ -80,8 +81,42 @@ Status Conv::ComputeInternal(ComputeContext& context std::transform(local_dilations.begin(), local_dilations.end(), std::back_inserter(dilations), transform_dim); auto rank = input_shape.NumDimensions(); const InlinedVector perm = {2, 3, 1, 0}; - if (rank > 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Only Conv1d and Conv2d are supported."); + if (rank > 5) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Only Conv1d, Conv2d, and Conv3d are supported."); + } else if (rank == 5) { + // Conv3D - use naive per-element shader (matching JS implementation) + if (conv_attrs_.group != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Conv3D does not support grouped convolution (group=", conv_attrs_.group, ")."); + } + const auto output_size = static_cast(output_shape.Size()); + const auto kernel_depth = static_cast(kernel_shape[2]); + const auto kernel_height = static_cast(kernel_shape[3]); + const auto kernel_width = static_cast(kernel_shape[4]); + // pads: head padding values for each spatial dim (front, top, left) + std::vector pads_3d{pads[0], pads[1], pads[2]}; + // Extract spatial dims and channels for explicit uniforms + const auto x_depth = static_cast(input_shape[is_channels_last ? 1 : 2]); + const auto x_height = static_cast(input_shape[is_channels_last ? 2 : 3]); + const auto x_width = static_cast(input_shape[is_channels_last ? 3 : 4]); + const auto x_channels = static_cast(input_shape[is_channels_last ? 4 : 1]); + Conv3DNaiveProgram program(activation_, has_bias, is_channels_last); + program.CacheHint(activation_.ToString(), std::to_string(is_channels_last)) + .AddInput({input, ProgramTensorMetadataDependency::TypeAndRank, input_shape, 1}) + .AddInput({kernel, ProgramTensorMetadataDependency::TypeAndRank, kernel_shape, 1}) + .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape, 1}) + .AddUniformVariables({{output_size}, + {std::vector{kernel_depth, kernel_height, kernel_width}}, + {pads_3d}, + {strides}, + {dilations}, + {std::vector{x_depth, x_height, x_width}}, + {x_channels}}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); + if (has_bias) { + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, bias->Shape(), 1}); + } + return context.RunProgram(program); } else if (rank == 4) { // Conv2D } else if (rank == 3) { diff --git a/onnxruntime/core/providers/webgpu/nn/conv3d_naive.cc b/onnxruntime/core/providers/webgpu/nn/conv3d_naive.cc new file mode 100644 index 0000000000000..76895e684eeab --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv3d_naive.cc @@ -0,0 +1,174 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/providers/webgpu/nn/conv3d_naive.h" +#include "core/providers/webgpu/nn/fuse_utils.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/shader_variable.h" + +namespace onnxruntime { +namespace webgpu { + +Status Conv3DNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | + ShaderUsage::UseIndicesTypeAlias | + ShaderUsage::UseValueTypeAlias | + ShaderUsage::UseElementTypeAlias); + const auto& w = shader.AddInput("w", ShaderUsage::UseUniform | + ShaderUsage::UseIndicesTypeAlias | + ShaderUsage::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | + ShaderUsage::UseIndicesTypeAlias | + ShaderUsage::UseValueTypeAlias | + ShaderUsage::UseElementTypeAlias); + + std::string apply_activation = GetActivationSnippet(activation_, "x_value_t", "x_element_t"); + + // Helper functions to access x and w by 5D indices + shader.AdditionalImplementation() + << "fn getX(d0 : u32, d1 : u32, d2 : u32, d3 : u32, d4 : u32) -> x_value_t {\n" + << " let aIndices = x_indices_t(d0, d1, d2, d3, d4);\n" + << " return " << x.GetByIndices("aIndices") << ";\n" + << "}\n" + << "fn getW(d0 : u32, d1 : u32, d2 : u32, d3 : u32, d4 : u32) -> x_value_t {\n" + << " let aIndices = w_indices_t(d0, d1, d2, d3, d4);\n" + << " return " << w.GetByIndices("aIndices") << ";\n" + << "}\n"; + + // Spatial dimensions and channels are passed as explicit uniforms + // to avoid rank-5 shape packing issues (array,2> vs vec4). + shader.MainFunctionBody() + << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << "let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << "let batch = output_indices[0];\n" + << "let d2 = " << output.IndicesGet("output_indices", is_channels_last_ ? "4" : "1") << ";\n" + << "let xFRCCorner = vec3(" << output.IndicesGet("output_indices", is_channels_last_ ? "1" : "2") << ", " + << output.IndicesGet("output_indices", is_channels_last_ ? "2" : "3") << ", " + << output.IndicesGet("output_indices", is_channels_last_ ? "3" : "4") << ") * uniforms.strides - uniforms.pads;\n" + << "let xFCorner = xFRCCorner.x;\n" + << "let xRCorner = xFRCCorner.y;\n" + << "let xCCorner = xFRCCorner.z;\n" + << "let xDepth = uniforms.x_spatial[0];\n" + << "let xHeight = uniforms.x_spatial[1];\n" + << "let xWidth = uniforms.x_spatial[2];\n" + << "let xChannels = uniforms.x_channels;\n" + << "let inputChannelsNearestVec4 = (xChannels / 4u) * 4u;\n" + << "let inputChannelsVec4Remainder = xChannels % 4u;\n" + << "\n" + << "var value = x_value_t(0);\n" + << "for (var wF = 0u; wF < uniforms.filter_dims[0]; wF++) {\n" + << " let xF = xFCorner + wF * uniforms.dilations[0];\n" + << " if (xF >= xDepth) {\n" + << " continue;\n" + << " }\n" + << " for (var wR = 0u; wR < uniforms.filter_dims[1]; wR++) {\n" + << " let xR = xRCorner + wR * uniforms.dilations[1];\n" + << " if (xR >= xHeight) {\n" + << " continue;\n" + << " }\n" + << " for (var wC = 0u; wC < uniforms.filter_dims[2]; wC++) {\n" + << " let xC = xCCorner + wC * uniforms.dilations[2];\n" + << " if (xC >= xWidth) {\n" + << " continue;\n" + << " }\n" + << " for (var d1 = 0u; d1 < inputChannelsNearestVec4; d1 += 4u) {\n"; + + // vec4 dot product accumulation over input channels + if (is_channels_last_) { + shader.MainFunctionBody() + << " let xValues = vec4(\n" + << " getX(batch, xF, xR, xC, d1),\n" + << " getX(batch, xF, xR, xC, d1 + 1u),\n" + << " getX(batch, xF, xR, xC, d1 + 2u),\n" + << " getX(batch, xF, xR, xC, d1 + 3u));\n"; + } else { + shader.MainFunctionBody() + << " let xValues = vec4(\n" + << " getX(batch, d1, xF, xR, xC),\n" + << " getX(batch, d1 + 1u, xF, xR, xC),\n" + << " getX(batch, d1 + 2u, xF, xR, xC),\n" + << " getX(batch, d1 + 3u, xF, xR, xC));\n"; + } + shader.MainFunctionBody() + << " let wValues = vec4(\n" + << " getW(d2, d1, wF, wR, wC),\n" + << " getW(d2, d1 + 1u, wF, wR, wC),\n" + << " getW(d2, d1 + 2u, wF, wR, wC),\n" + << " getW(d2, d1 + 3u, wF, wR, wC));\n" + << " value += x_value_t(dot(xValues, wValues));\n" + << " }\n"; + + // Handle remainder channels (1, 2, or 3) + shader.MainFunctionBody() + << " if (inputChannelsVec4Remainder == 1u) {\n"; + if (is_channels_last_) { + shader.MainFunctionBody() + << " value += getX(batch, xF, xR, xC, inputChannelsNearestVec4)\n" + << " * getW(d2, inputChannelsNearestVec4, wF, wR, wC);\n"; + } else { + shader.MainFunctionBody() + << " value += getX(batch, inputChannelsNearestVec4, xF, xR, xC)\n" + << " * getW(d2, inputChannelsNearestVec4, wF, wR, wC);\n"; + } + shader.MainFunctionBody() + << " } else if (inputChannelsVec4Remainder == 2u) {\n"; + if (is_channels_last_) { + shader.MainFunctionBody() + << " let xValues = vec2(\n" + << " getX(batch, xF, xR, xC, inputChannelsNearestVec4),\n" + << " getX(batch, xF, xR, xC, inputChannelsNearestVec4 + 1u));\n"; + } else { + shader.MainFunctionBody() + << " let xValues = vec2(\n" + << " getX(batch, inputChannelsNearestVec4, xF, xR, xC),\n" + << " getX(batch, inputChannelsNearestVec4 + 1u, xF, xR, xC));\n"; + } + shader.MainFunctionBody() + << " let wValues = vec2(\n" + << " getW(d2, inputChannelsNearestVec4, wF, wR, wC),\n" + << " getW(d2, inputChannelsNearestVec4 + 1u, wF, wR, wC));\n" + << " value += x_value_t(dot(xValues, wValues));\n" + << " } else if (inputChannelsVec4Remainder == 3u) {\n"; + if (is_channels_last_) { + shader.MainFunctionBody() + << " let xValues = vec3(\n" + << " getX(batch, xF, xR, xC, inputChannelsNearestVec4),\n" + << " getX(batch, xF, xR, xC, inputChannelsNearestVec4 + 1u),\n" + << " getX(batch, xF, xR, xC, inputChannelsNearestVec4 + 2u));\n"; + } else { + shader.MainFunctionBody() + << " let xValues = vec3(\n" + << " getX(batch, inputChannelsNearestVec4, xF, xR, xC),\n" + << " getX(batch, inputChannelsNearestVec4 + 1u, xF, xR, xC),\n" + << " getX(batch, inputChannelsNearestVec4 + 2u, xF, xR, xC));\n"; + } + shader.MainFunctionBody() + << " let wValues = vec3(\n" + << " getW(d2, inputChannelsNearestVec4, wF, wR, wC),\n" + << " getW(d2, inputChannelsNearestVec4 + 1u, wF, wR, wC),\n" + << " getW(d2, inputChannelsNearestVec4 + 2u, wF, wR, wC));\n" + << " value += x_value_t(dot(xValues, wValues));\n" + << " }\n" + << " }\n" + << " }\n" + << "}\n"; + + // Apply bias + if (has_bias_) { + const auto& b = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.MainFunctionBody() << "value = value + " << b.GetByIndices("d2") << ";\n"; + } + + // Apply activation + shader.MainFunctionBody() << apply_activation << "\n"; + + // Write output + shader.MainFunctionBody() << output.SetByOffset("global_idx", "value"); + + return Status::OK(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/conv3d_naive.h b/onnxruntime/core/providers/webgpu/nn/conv3d_naive.h new file mode 100644 index 0000000000000..25ae449a7d02c --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv3d_naive.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/nn/fuse_utils.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class Conv3DNaiveProgram final : public Program { + public: + Conv3DNaiveProgram(const Activation& activation, bool has_bias, bool is_channels_last) + : Program("Conv3DNaive"), activation_(activation), has_bias_(has_bias), is_channels_last_(is_channels_last) { + } + Status GenerateShaderCode(ShaderHelper& shader) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"output_size", ProgramUniformVariableDataType::Uint32}, + {"filter_dims", ProgramUniformVariableDataType::Uint32}, + {"pads", ProgramUniformVariableDataType::Uint32}, + {"strides", ProgramUniformVariableDataType::Uint32}, + {"dilations", ProgramUniformVariableDataType::Uint32}, + {"x_spatial", ProgramUniformVariableDataType::Uint32}, + {"x_channels", ProgramUniformVariableDataType::Uint32}); + + private: + const Activation& activation_; + bool has_bias_; + bool is_channels_last_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 6d6fedb3c9812..843d925ed6638 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -686,6 +686,8 @@ TEST(ConvFp16Test, Conv2D_AutoPad2) { TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); } +// TODO: Enable Conv3D fp16 tests for WebGPU when the test infrastructure supports +// conditionally skipping based on device capabilities (e.g., wgpu::FeatureName::ShaderF16). TEST(ConvFp16Test, Conv3D_1) { ConvOpAndTestAttributes attrs = { "", // auto_pad diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc index 060b61c61532a..f8e93c19dc8d3 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc @@ -812,7 +812,7 @@ TEST(ConvTest, Conv3D_1) { vector{1, 1, 1}, // kernel_shape vector{0, 0, 0, 0, 0, 0}, // pads vector{1, 1, 1}, // strides - {kWebGpuExecutionProvider} // excluded EPs + {} // excluded EPs }; vector X = {-0.43337246775627136f, -0.48385289311408997f, -0.30954962968826294f, @@ -849,7 +849,7 @@ TEST(ConvTest, Conv3D_2) { vector{1, 1, 1}, // kernel_shape vector{2, 2, 2, 2, 2, 2}, // pads vector{2, 2, 2}, // strides - {kWebGpuExecutionProvider} // excluded EPs + {} // excluded EPs }; vector X = {0.010772407054901123f, -0.43806642293930054f, 0.455391526222229f, -0.28657248616218567f, @@ -892,7 +892,7 @@ TEST(ConvTest, Conv3D_Bias) { vector{2, 2, 2}, // kernel_shape vector{2, 2, 2, 2, 2, 2}, // pads vector{2, 2, 2}, // strides - {kWebGpuExecutionProvider} // excluded EPs + {} // excluded EPs }; vector X = {0.46796226501464844f, -0.4613912105560303f, 0.33512794971466064f, -0.4010460674762726f, From a6592fc03bb3d341c63b967aae8f4523a97b9ec9 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Fri, 3 Apr 2026 10:52:54 -0700 Subject: [PATCH 4/5] Cleanup: Consolidate `OpKernel::UseSharePrePackedBuffers_V2` and `OpKernel::UseSharePrePackedBuffers` (#27924) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Consolidate `OpKernel::UseSharedPrePackedBuffers` and `OpKernel::UseSharedPrePackedBuffers_V2` into a single virtual method, resolving the TODO in `op_kernel.h`. #### Background The `OpKernel` class previously had two virtual methods for consuming shared pre-packed weight buffers: - **`UseSharedPrePackedBuffers`** (V1) — 3 params: `prepacked_buffers`, `input_idx`, `used_shared_buffers` - **`UseSharedPrePackedBuffers_V2`** — 4 params: added `prepacked_buffer_sizes` (a `gsl::span`) V2 was introduced to pass buffer sizes alongside the buffers. Its default implementation forwarded to V1 for backward compatibility. The framework (`session_state.cc`) only ever called V2. #### Changes Merged both methods into a single `UseSharedPrePackedBuffers` using the V2 signature: ```cpp virtual Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, gsl::span prepacked_buffer_sizes, int input_idx, /*out*/ bool& used_shared_buffers); ``` Updated **27 files** across the codebase: | Category | Files | Change | |----------|-------|--------| | Base class | `op_kernel.h` | Removed V1 + V2; single 4-param method | | Framework | `session_state.cc` | Renamed `_V2` call | | Plugin EP bridge | `ep_kernel_registration.cc` | Renamed override | | QMoECPU | `moe_quantization_cpu.h/.cc` | Renamed V2 override + template instantiations | | CPU provider (8 kernels) | `gemm`, `matmul`, `conv_transpose`, `fp16_conv`, `qlinearconv`, `matmul_integer_base`, `deep_cpu_lstm`, `deep_cpu_gru` | Added `prepacked_buffer_sizes` param | | ACL provider (2 kernels) | `acl/conv`, `acl/matmul` | Added param | | Contrib ops (4 kernels) | `matmul_nbits`, `dynamic_quantize_lstm`, `attention_quant`, `bert/attention` | Added param | | Tests | `session_state_test.cc` | Updated test kernel override | #### Notes - Existing V1 overrides add the new `prepacked_buffer_sizes` parameter as **unnamed/unused** (`/*prepacked_buffer_sizes*/`) — no logic changes in those kernels. - The C API (`SetSharedPrePackedWeight` in `onnxruntime_ep_c_api.h`) already passes buffer sizes, so **no C API changes** were needed. - Private helper functions (e.g., `UseSharedPrePackedBuffersImpl` in LSTM/GRU) are not virtual overrides and were **not modified**. ### Motivation and Context Addresses the TODO at `include/onnxruntime/core/framework/op_kernel.h:139`: > TODO: Consolidate UseSharedPrePackedBuffers and UseSharedPrePackedBuffers_V2 into a single function, which will require updating kernel-based provider-bridge EPs (cpu, cuda, webgpu). --- .../onnxruntime/core/framework/op_kernel.h | 25 +++---------------- onnxruntime/contrib_ops/cpu/bert/attention.cc | 2 ++ .../cpu/moe/moe_quantization_cpu.cc | 12 ++++----- .../cpu/moe/moe_quantization_cpu.h | 8 +++--- .../cpu/quantization/attention_quant.cc | 2 ++ .../cpu/quantization/dynamic_quantize_lstm.cc | 2 ++ .../cpu/quantization/matmul_nbits.cc | 8 ++++-- onnxruntime/core/framework/session_state.cc | 4 +-- onnxruntime/core/providers/acl/math/matmul.cc | 1 + onnxruntime/core/providers/acl/math/matmul.h | 1 + onnxruntime/core/providers/acl/nn/conv.cc | 1 + onnxruntime/core/providers/acl/nn/conv.h | 1 + .../core/providers/cpu/fp16/fp16_conv.cc | 2 ++ onnxruntime/core/providers/cpu/math/gemm.cc | 2 ++ onnxruntime/core/providers/cpu/math/gemm.h | 1 + onnxruntime/core/providers/cpu/math/matmul.cc | 1 + onnxruntime/core/providers/cpu/math/matmul.h | 4 ++- .../core/providers/cpu/nn/conv_transpose.cc | 2 ++ .../core/providers/cpu/nn/conv_transpose.h | 1 + .../cpu/quantization/matmul_integer_base.h | 1 + .../providers/cpu/quantization/qlinearconv.cc | 2 ++ .../core/providers/cpu/rnn/deep_cpu_gru.cc | 1 + .../core/providers/cpu/rnn/deep_cpu_gru.h | 1 + .../core/providers/cpu/rnn/deep_cpu_lstm.cc | 1 + .../core/providers/cpu/rnn/deep_cpu_lstm.h | 1 + .../plugin_ep/ep_kernel_registration.cc | 6 ++--- .../test/framework/session_state_test.cc | 1 + 27 files changed, 54 insertions(+), 40 deletions(-) diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index 8ec94c67cc0a4..42e8e9c5e3cbe 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -105,9 +105,9 @@ class OpKernel { return Status::OK(); } - // Note: New implementations should override OpKernel::UseSharedPrePackedBuffers_V2 instead. // Override this function to use provided pre-packed weight. // Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + // gsl::span prepacked_buffer_sizes, // int input_idx, // /*out*/ bool& used_shared_buffers) { // used_shared_buffers = true; @@ -121,37 +121,18 @@ class OpKernel { // and must use the same order for retrieval in UseSharedPrePackedBuffers(). Though each element // of this vector is a BufferUniquePtr, the deleter of the BufferUniquePtr is NULL. So actually they // are raw pointers. + // @param prepacked_buffer_sizes: The sizes (in bytes) of each buffer in prepacked_buffers. // @param input_idx: The input index of the tensor in this kernel // @param used_shared_buffers: Boolean flag set by the kernel implementation indicating // that the provided weight has been used by the kernel. virtual Status UseSharedPrePackedBuffers(std::vector& /*prepacked_buffers*/, + gsl::span /*prepacked_buffer_sizes*/, int /*input_idx*/, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; return Status::OK(); } - /// - /// Version 2 of OpKernel::UseSharedPrePackedBuffers() that additionally accepts the buffer sizes as a parameter. - /// The default implementation of this function just calls directly to OpKernel::UseSharedPrePackedBuffers() - /// to avoid the need to update all existing kernel-based provider-bridge EPs. - /// - /// TODO: Consolidate UseSharedPrePackedBuffers and UseSharedPrePackedBuffers_V2 into a single function, - /// which will require updating kernel-based provider-bridge EPs (cpu, cuda, webgpu). - /// - /// - /// - /// - /// - /// - /// - virtual Status UseSharedPrePackedBuffers_V2(std::vector& prepacked_buffers, - gsl::span /*prepacked_buffer_sizes*/, - int input_idx, - /*out*/ bool& used_shared_buffers) { - return UseSharedPrePackedBuffers(prepacked_buffers, input_idx, used_shared_buffers); - } - const OrtDevice GetDevice(OrtMemType mem_type) const; const OpKernelInfo& Info() const { return *op_kernel_info_; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 7268b32623b95..e1981fb5c2442 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -34,6 +34,7 @@ class Attention : public OpKernel, public AttentionCPUBase { /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) override; @@ -176,6 +177,7 @@ Status Attention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr template Status Attention::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { if (1 != input_idx) { diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index ca2c3ab001da6..a674d05b6daae 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -578,10 +578,10 @@ Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr all } template -Status QMoECPU::UseSharedPrePackedBuffers_V2(std::vector& prepacked_buffers, - gsl::span /*prepacked_buffer_sizes*/, - int input_idx, - /*out*/ bool& used_shared_buffers) { +Status QMoECPU::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, + int input_idx, + /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; if (expert_weight_bits_ != 4) { @@ -1577,11 +1577,11 @@ template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); template Status QMoECPU::Compute(OpKernelContext* context) const; template Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights); -template Status QMoECPU::UseSharedPrePackedBuffers_V2(std::vector& prepacked_buffers, gsl::span prepacked_buffer_sizes, int input_idx, bool& used_shared_buffers); +template Status QMoECPU::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, gsl::span prepacked_buffer_sizes, int input_idx, bool& used_shared_buffers); template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); template Status QMoECPU::Compute(OpKernelContext* context) const; template Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights); -template Status QMoECPU::UseSharedPrePackedBuffers_V2(std::vector& prepacked_buffers, gsl::span prepacked_buffer_sizes, int input_idx, bool& used_shared_buffers); +template Status QMoECPU::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, gsl::span prepacked_buffer_sizes, int input_idx, bool& used_shared_buffers); // Kernel Registration ONNX_OPERATOR_TYPED_KERNEL_EX( diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h index f678a27190c90..c5e6904ae48c2 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h @@ -32,10 +32,10 @@ class QMoECPU final : public OpKernel, public MoEBaseCPU { /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; - Status UseSharedPrePackedBuffers_V2(std::vector& prepacked_buffers, - gsl::span prepacked_buffer_sizes, - int input_idx, - /*out*/ bool& used_shared_buffers) override; + Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span prepacked_buffer_sizes, + int input_idx, + /*out*/ bool& used_shared_buffers) override; void ApplyActivationVectorized(float* data, int64_t size) const; diff --git a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc index b30fa1e5e618a..931677582d469 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc @@ -28,6 +28,7 @@ class QAttention : public OpKernel, public AttentionCPUBase { /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) override; @@ -117,6 +118,7 @@ Status QAttention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr template Status QAttention::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { if (1 != input_idx) { diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc index f55e66f9c5d81..2094af78f40b7 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc @@ -17,6 +17,7 @@ class DynamicQuantizeLSTM : public OpKernel, public LSTMBase { /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) override; @@ -117,6 +118,7 @@ Status DynamicQuantizeLSTM::PrePack(const Tensor& tensor, int input_idx, Allocat } Status DynamicQuantizeLSTM::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index d2996b122c5f7..3da0ee19d4cde 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -135,7 +135,9 @@ class MatMulNBits final : public OpKernel { /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; - Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, + Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, + int input_idx, /*out*/ bool& used_shared_buffers) override; private: @@ -557,7 +559,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*ou #endif // end !MLAS_F16VEC_INTRINSICS_SUPPORTED || !MLAS_TARGET_ARM64 template -Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, +Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, + int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 5c33a621cf514..84521af2d8532 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -436,8 +436,8 @@ static Status KernelUseSharedPrePackedBuffers(OpKernel& kernel, int input_idx, } bool used_shared_buffers = false; - ORT_RETURN_IF_ERROR(kernel.UseSharedPrePackedBuffers_V2(shared_prepacked_buffers, shared_prepacked_buffer_sizes, - input_idx, used_shared_buffers)); + ORT_RETURN_IF_ERROR(kernel.UseSharedPrePackedBuffers(shared_prepacked_buffers, shared_prepacked_buffer_sizes, + input_idx, used_shared_buffers)); // BUG CHECK: Ensure that the kernel used the provided shared buffers // Mostly a debug check to ensure that the kernel has an overridden implementation of the diff --git a/onnxruntime/core/providers/acl/math/matmul.cc b/onnxruntime/core/providers/acl/math/matmul.cc index 468b394471c13..029a9ebe2768a 100644 --- a/onnxruntime/core/providers/acl/math/matmul.cc +++ b/onnxruntime/core/providers/acl/math/matmul.cc @@ -269,6 +269,7 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, } Status MatMul::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; if (input_idx != 1) { diff --git a/onnxruntime/core/providers/acl/math/matmul.h b/onnxruntime/core/providers/acl/math/matmul.h index b137e33833de9..783e15585ebf5 100644 --- a/onnxruntime/core/providers/acl/math/matmul.h +++ b/onnxruntime/core/providers/acl/math/matmul.h @@ -34,6 +34,7 @@ class MatMul : public OpKernel { bool& is_packed, PrePackedWeights*) override; Status UseSharedPrePackedBuffers(std::vector&, + gsl::span, int, bool&) override; Status Compute(OpKernelContext* context) const override; diff --git a/onnxruntime/core/providers/acl/nn/conv.cc b/onnxruntime/core/providers/acl/nn/conv.cc index a62158f1c26ee..5cc10f7cfd2a8 100644 --- a/onnxruntime/core/providers/acl/nn/conv.cc +++ b/onnxruntime/core/providers/acl/nn/conv.cc @@ -370,6 +370,7 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, } Status Conv::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; if (isQuantized ? (input_idx != 3) : (input_idx != 1)) { diff --git a/onnxruntime/core/providers/acl/nn/conv.h b/onnxruntime/core/providers/acl/nn/conv.h index b05ba5363542f..7af086a410857 100644 --- a/onnxruntime/core/providers/acl/nn/conv.h +++ b/onnxruntime/core/providers/acl/nn/conv.h @@ -36,6 +36,7 @@ class Conv : public onnxruntime::OpKernel { bool& is_packed, PrePackedWeights*) override; Status UseSharedPrePackedBuffers(std::vector&, + gsl::span, int, bool&) override; Status Compute(OpKernelContext* context) const override; diff --git a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc index 790b1543bbd74..08dbc46213f65 100644 --- a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc +++ b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc @@ -54,6 +54,7 @@ class FusedConvFp16 final : public OpKernel { /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) override; @@ -211,6 +212,7 @@ Status FusedConvFp16::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr } Status FusedConvFp16::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { if (input_idx != 1) { diff --git a/onnxruntime/core/providers/cpu/math/gemm.cc b/onnxruntime/core/providers/cpu/math/gemm.cc index ac931c76ee3ae..c0da9aec1e1b1 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.cc +++ b/onnxruntime/core/providers/cpu/math/gemm.cc @@ -296,6 +296,7 @@ Status Gemm::PrePack(const Tensor& tensor, int input_idx, template Status Gemm::UseSharedPrePackedBuffers(std::vector& /*prepacked_buffers*/, + gsl::span /*prepacked_buffer_sizes*/, int /*input_idx*/, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; @@ -304,6 +305,7 @@ Status Gemm::UseSharedPrePackedBuffers(std::vector& /*prepac template <> Status Gemm::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; diff --git a/onnxruntime/core/providers/cpu/math/gemm.h b/onnxruntime/core/providers/cpu/math/gemm.h index c65f3eb96f62e..d9e66df4bee7c 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.h +++ b/onnxruntime/core/providers/cpu/math/gemm.h @@ -37,6 +37,7 @@ class Gemm : protected GemmBase, public OpKernel { /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) override; diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index 8a7795a81027d..8dea41e3488e2 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -220,6 +220,7 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc } Status MatMul::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; diff --git a/onnxruntime/core/providers/cpu/math/matmul.h b/onnxruntime/core/providers/cpu/math/matmul.h index 7f2d2ee400b63..9e6ef1a486235 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.h +++ b/onnxruntime/core/providers/cpu/math/matmul.h @@ -47,7 +47,9 @@ class MatMul final : public OpKernel { /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; - Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, + Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, + int input_idx, /*out*/ bool& used_shared_buffers) override; Status Compute(OpKernelContext* context) const override; diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc index 6ebd12a525371..bbb530d037cec 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc @@ -102,6 +102,7 @@ Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, Alloca template Status ConvTranspose::UseSharedPrePackedBuffers(std::vector& /*prepacked_buffers*/, + gsl::span /*prepacked_buffer_sizes*/, int /*input_idx*/, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; @@ -110,6 +111,7 @@ Status ConvTranspose::UseSharedPrePackedBuffers(std::vector& template <> Status ConvTranspose::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose.h b/onnxruntime/core/providers/cpu/nn/conv_transpose.h index fd6021e65670e..96e3ecf912f32 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.h @@ -35,6 +35,7 @@ class ConvTranspose : public OpKernel { /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) override; diff --git a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h index fb86e9731035c..9916c426a54fe 100644 --- a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h +++ b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h @@ -80,6 +80,7 @@ class MatMulIntegerBase : public OpKernel { } Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) override { used_shared_buffers = false; diff --git a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc index 24c8b0d57294e..a5e3d4b04a1e3 100644 --- a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc +++ b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc @@ -30,6 +30,7 @@ class QLinearConv : public OpKernel { /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) override; @@ -495,6 +496,7 @@ Status QLinearConv::PrePack(const Tensor& tensor, int input_idx, Alloca template Status QLinearConv::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { if (input_idx != 3) { diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc index d1ddd04a953ef..d5be6bd29592e 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc @@ -322,6 +322,7 @@ Status DeepCpuGruOp::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr a } Status DeepCpuGruOp::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h index 881adf9efb376..fa233cc6f9cde 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.h @@ -69,6 +69,7 @@ class DeepCpuGruOp final : public OpKernel { /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) override; diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc index 4b3ea672c0812..d2520804bb64c 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc @@ -260,6 +260,7 @@ Status DeepCpuLstmOp::PrePack(const Tensor& tensor, int input_idx, } Status DeepCpuLstmOp::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h index c949b62ce7186..487e2a3fb8129 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h @@ -24,6 +24,7 @@ class DeepCpuLstmOp final : public OpKernel, public LSTMBase { /*out*/ PrePackedWeights* prepacked_weights) override; Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) override; diff --git a/onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc b/onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc index 625645e71cfec..6f29361502a73 100644 --- a/onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc +++ b/onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc @@ -126,9 +126,9 @@ class PluginEpOpKernel final : public controlflow::IControlFlowKernel { return Status::OK(); } - Status UseSharedPrePackedBuffers_V2(std::vector& buffer_unique_ptrs, - gsl::span buffer_sizes, - int input_idx, /*out*/ bool& used_shared_buffers) override { + Status UseSharedPrePackedBuffers(std::vector& buffer_unique_ptrs, + gsl::span buffer_sizes, + int input_idx, /*out*/ bool& used_shared_buffers) override { assert(kernel_impl_ != nullptr); // Should be ensured by PluginEpOpKernel::Create(). if (kernel_impl_->ort_version_supported < 24 || kernel_impl_->SetSharedPrePackedWeight == nullptr) { diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 656b0ef86289d..418bb2a809259 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -662,6 +662,7 @@ class PrePackingTestOpKernel : public OpKernel { } Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, int input_idx, /*out*/ bool& used_shared_buffers) override { ORT_UNUSED_PARAMETER(input_idx); From c9726045b346a9005901a4f714e0824203a878dc Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Fri, 3 Apr 2026 11:40:42 -0700 Subject: [PATCH 5/5] [Core] MobileClip Attention Fusion (#27883) ### Description Update the Attention Fusion optimizer to help fuse the Attention subgraph pattern in MobileClip model. The perf gain from this itself is paltry (mostly from not having to launch many kernels) but the real gain will be AFTER this fusion (i.e.) tuning the performance of the MHA kernel for the problem shapes seen in this model. There are 2 Attention blocks found in the model and this update fuses both of them. ### Motivation and Context Improve performance of MobileClip model --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../core/optimizer/attention_fusion.cc | 666 +++++++++++++++++- .../test/optimizer/graph_transform_test.cc | 357 ++++++++++ 2 files changed, 1022 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/optimizer/attention_fusion.cc b/onnxruntime/core/optimizer/attention_fusion.cc index 9fd71b3b00cd0..7fe7c914fa796 100644 --- a/onnxruntime/core/optimizer/attention_fusion.cc +++ b/onnxruntime/core/optimizer/attention_fusion.cc @@ -2,15 +2,673 @@ // Licensed under the MIT License. #include "core/graph/graph_utils.h" +#include "core/common/safeint.h" +#include "core/framework/tensorprotoutils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/attention_fusion.h" #include "core/optimizer/utils.h" #include "core/optimizer/attention_fusion_helper.h" -#include "core/graph/graph_utils.h" #include +#include namespace onnxruntime { +static bool ValidateMatMulInitializer(const Graph& graph, const Node& matmul, int64_t hidden_size); + +namespace { + +static bool ValidateAddBiasInitializerEitherInput(const Graph& graph, const Node& add, int64_t hidden_size) { + if (add.InputDefs().size() < 2) { + return false; + } + + const NodeArg& input_0 = *(add.InputDefs()[0]); + const NodeArg& input_1 = *(add.InputDefs()[1]); + const bool input_0_is_bias = graph_utils::IsInitializer(graph, input_0.Name(), true) && + optimizer_utils::ValidateShape(input_0, {hidden_size}); + const bool input_1_is_bias = graph_utils::IsInitializer(graph, input_1.Name(), true) && + optimizer_utils::ValidateShape(input_1, {hidden_size}); + return input_0_is_bias || input_1_is_bias; +} + +static bool ValidateProjectionGemmInitializer(const Graph& graph, const Node& gemm, int64_t hidden_size) { + if (gemm.InputDefs().size() < 3) { + return false; + } + + if (const auto* alpha_attr = graph_utils::GetNodeAttribute(gemm, "alpha"); + alpha_attr && std::abs(alpha_attr->f() - 1.0f) > 1e-6f) { + return false; + } + + if (const auto* beta_attr = graph_utils::GetNodeAttribute(gemm, "beta"); + beta_attr && std::abs(beta_attr->f() - 1.0f) > 1e-6f) { + return false; + } + + if (const auto* trans_a_attr = graph_utils::GetNodeAttribute(gemm, "transA"); + trans_a_attr && trans_a_attr->i() != 0) { + return false; + } + + if (const auto* trans_b_attr = graph_utils::GetNodeAttribute(gemm, "transB"); + trans_b_attr && trans_b_attr->i() != 0) { + return false; + } + + const NodeArg& input_b = *(gemm.InputDefs()[1]); + const NodeArg& input_c = *(gemm.InputDefs()[2]); + if (!graph_utils::IsInitializer(graph, input_b.Name(), true) || + !graph_utils::IsInitializer(graph, input_c.Name(), true)) { + return false; + } + + return optimizer_utils::ValidateShape(input_b, {hidden_size, hidden_size}) && + optimizer_utils::ValidateShape(input_c, {hidden_size}); +} + +// Most attention fusions require all matched nodes to already be assigned to an execution provider +// that supports the fused op. MobileClipMHA is also matched before partitioning in graph-transform +// tests, so nodes may still be unassigned here. Accept nodes that are either unassigned or already +// assigned to a compatible provider, and preserve the original provider string on the fused nodes +// once the pattern is rewritten. +static bool IsSupportedOrUnassignedNode(const Node& node, + const InlinedHashSet& compatible_execution_providers) { + return node.GetExecutionProviderType().empty() || + graph_utils::IsSupportedProvider(node, compatible_execution_providers); +} + +static bool IsSupportedOrUnassignedNode(const Node& node, + std::string_view required_execution_provider) { + const auto& execution_provider = node.GetExecutionProviderType(); + return execution_provider.empty() || + execution_provider == required_execution_provider; +} + +static bool AreSupportedOrUnassignedNodes( + const Node& anchor_node, + const std::initializer_list& nodes, + const InlinedHashSet& compatible_execution_providers) { + if (!IsSupportedOrUnassignedNode(anchor_node, compatible_execution_providers)) { + return false; + } + + const auto& required_execution_provider = anchor_node.GetExecutionProviderType(); + for (const Node* node : nodes) { + if (node == nullptr) { + continue; + } + + if (!IsSupportedOrUnassignedNode(*node, required_execution_provider)) { + return false; + } + } + + return true; +} + +static bool HasExpectedPerm(const Node& node, const std::initializer_list& expected_perm) { + return optimizer_utils::IsAttributeWithExpectedValues(node, "perm", std::vector(expected_perm)); +} + +static bool HasExpectedAxesInput(const Graph& graph, const Node& node, const std::initializer_list& expected_axes) { + if (node.InputDefs().size() < 2) { + return false; + } + + InlinedVector axes; + if (!optimizer_utils::AppendTensorFromInitializer(graph, *node.InputDefs()[1], axes, true)) { + return false; + } + + return axes == InlinedVector(expected_axes.begin(), expected_axes.end()); +} + +static bool TryGetMobileClipQkvReshapeInfo(const Graph& graph, const Node& qkv_reshape, + int64_t& num_heads, int64_t& head_size, int64_t& hidden_size) { + if (qkv_reshape.InputDefs().size() < 2) { + return false; + } + + InlinedVector reshape_dims; + if (!optimizer_utils::AppendTensorFromInitializer(graph, *qkv_reshape.InputDefs()[1], reshape_dims, true)) { + return false; + } + + if (reshape_dims.size() != 5 || reshape_dims[2] != 3 || reshape_dims[3] <= 0 || reshape_dims[4] <= 0) { + return false; + } + + num_heads = reshape_dims[3]; + head_size = reshape_dims[4]; + + try { + hidden_size = SafeInt(num_heads) * head_size; + } catch (const OnnxRuntimeException&) { + return false; + } + + return hidden_size > 0; +} + +static std::optional TryCreateMobileClipMhaOutputType(const NodeArg& qkv_output, + int64_t hidden_size) { + const auto* qkv_output_type = qkv_output.TypeAsProto(); + if (qkv_output_type == nullptr || !qkv_output_type->has_tensor_type()) { + return std::nullopt; + } + + ONNX_NAMESPACE::TypeProto mha_output_type(*qkv_output_type); + auto* shape = mha_output_type.mutable_tensor_type()->mutable_shape(); + if (shape->dim_size() > 0) { + auto* last_dim = shape->mutable_dim(shape->dim_size() - 1); + last_dim->clear_dim_param(); + last_dim->set_dim_value(hidden_size); + } + + return mha_output_type; +} + +static Node* GetOnlyChildByOutputIndex(Graph& graph, const Node& node, size_t output_index, const char* child_op_type) { + const auto output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(node, output_index); + if (output_edges.size() != 1) { + return nullptr; + } + + Node* child = graph.GetNode(output_edges[0].dst_node); + if (child == nullptr || child->OpType() != child_op_type) { + return nullptr; + } + + return child; +} + +static bool TryCreateNormalizedProjectionGemm(Graph& graph, + NodeArg& projection_input, + const NodeArg& original_projection_input, + const NodeArg& proj_weight, + const NodeArg& proj_bias, + NodeArg& projection_output, + const std::string& base_name, + const std::string& provider_type) { + const auto* proj_input_shape = original_projection_input.Shape(); + const auto* proj_weight_shape = proj_weight.Shape(); + if (proj_input_shape == nullptr || proj_weight_shape == nullptr || proj_weight_shape->dim_size() != 2) { + return false; + } + + auto input_shape = utils::GetTensorShapeFromTensorShapeProto(*proj_input_shape); + if (input_shape.Size() == -1 || input_shape.NumDimensions() < 2) { + return false; + } + + const auto& dim_k = proj_weight_shape->dim(0); + const auto& dim_n = proj_weight_shape->dim(1); + if (!utils::HasDimValue(dim_k) || !utils::HasDimValue(dim_n)) { + return false; + } + + const int64_t m = input_shape.SizeToDimension(input_shape.NumDimensions() - 1); + if (m <= 0) { + return false; + } + + const int64_t k = dim_k.dim_value(); + const int64_t n = dim_n.dim_value(); + if (input_shape[input_shape.NumDimensions() - 1] != k) { + return false; + } + + const auto* bias_shape = proj_bias.Shape(); + if (bias_shape == nullptr || bias_shape->dim_size() != 1 || !utils::HasDimValue(bias_shape->dim(0)) || + bias_shape->dim(0).dim_value() != n) { + return false; + } + + const auto* input_type = original_projection_input.TypeAsProto(); + if (input_type == nullptr || !input_type->has_tensor_type()) { + return false; + } + + const auto element_type = static_cast(input_type->tensor_type().elem_type()); + + auto add_shape_initializer = [&](const std::string& name, const InlinedVector& shape) -> NodeArg& { + ONNX_NAMESPACE::TensorProto shape_initializer_proto; + shape_initializer_proto.set_name(graph.GenerateNodeArgName(name)); + shape_initializer_proto.add_dims(static_cast(shape.size())); + shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + const size_t shape_bytes = SafeInt(shape.size()) * sizeof(int64_t); + utils::SetRawDataInTensorProto(shape_initializer_proto, shape.data(), shape_bytes); + return graph_utils::AddInitializerWithOrtValue(graph, shape_initializer_proto); + }; + + auto make_tensor_arg = [&](const std::string& name, const InlinedVector& shape) -> NodeArg* { + ONNX_NAMESPACE::TypeProto type_proto; + type_proto.mutable_tensor_type()->set_elem_type(element_type); + for (int64_t dim_value : shape) { + type_proto.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim_value); + } + + return &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(name), &type_proto); + }; + + InlinedVector gemm_input_shape{m, k}; + InlinedVector gemm_output_shape{m, n}; + InlinedVector output_shape_values = input_shape.AsShapeVector(); + output_shape_values.back() = n; + + NodeArg* gemm_input_arg = make_tensor_arg("mobileclip_proj_gemm_input", gemm_input_shape); + NodeArg* gemm_output_arg = make_tensor_arg("mobileclip_proj_gemm_output", gemm_output_shape); + NodeArg& gemm_input_shape_arg = add_shape_initializer("mobileclip_proj_gemm_input_shape", gemm_input_shape); + NodeArg& gemm_output_shape_arg = add_shape_initializer("mobileclip_proj_gemm_output_shape", output_shape_values); + + Node& input_reshape = graph.AddNode( + graph.GenerateNodeName("MobileClipProjGemmInputReshape"), + "Reshape", + "Reshape MobileCLIP projection input for Gemm", + {&projection_input, &gemm_input_shape_arg}, + {gemm_input_arg}); + input_reshape.SetExecutionProviderType(provider_type); + + Node& gemm_node = graph.AddNode( + graph.GenerateNodeName(base_name + "/MobileClipProjectionGemm"), + "Gemm", + "Normalized MobileCLIP projection Gemm", + {gemm_input_arg, const_cast(&proj_weight), const_cast(&proj_bias)}, + {gemm_output_arg}); + gemm_node.SetExecutionProviderType(provider_type); + + Node& output_reshape = graph.AddNode( + graph.GenerateNodeName("MobileClipProjGemmOutputReshape"), + "Reshape", + "Restore MobileCLIP projection output shape after Gemm", + {gemm_output_arg, &gemm_output_shape_arg}, + {&projection_output}); + output_reshape.SetExecutionProviderType(provider_type); + + return true; +} + +static bool TryRewriteProjectionMatMulAddToGemm(Graph& graph, + NodeArg& projection_input, + Node& proj_matmul, + Node& proj_add) { + if (proj_matmul.InputDefs().size() < 2 || proj_add.InputDefs().size() < 2) { + return false; + } + + const int bias_idx = proj_matmul.OutputDefs()[0]->Name() == proj_add.InputDefs()[0]->Name() ? 1 : 0; + return TryCreateNormalizedProjectionGemm(graph, + projection_input, + *proj_matmul.InputDefs()[0], + *proj_matmul.InputDefs()[1], + *proj_add.InputDefs()[bias_idx], + *proj_add.MutableOutputDefs()[0], + proj_matmul.Name(), + proj_matmul.GetExecutionProviderType()); +} + +static bool TryRewriteProjectionGemm(Graph& graph, + NodeArg& projection_input, + Node& proj_gemm) { + if (proj_gemm.InputDefs().size() < 3 || proj_gemm.OutputDefs().empty()) { + return false; + } + + return TryCreateNormalizedProjectionGemm(graph, + projection_input, + *proj_gemm.InputDefs()[0], + *proj_gemm.InputDefs()[1], + *proj_gemm.InputDefs()[2], + *proj_gemm.MutableOutputDefs()[0], + proj_gemm.Name(), + proj_gemm.GetExecutionProviderType()); +} + +static bool TryFuseMobileClipMHA(Node& qkv_matmul, + Graph& graph, + const InlinedHashSet& compatible_execution_providers, + const logging::Logger& logger) { + const auto fail = [&](const char* message) { + LOGS(logger, VERBOSE) << "MobileClipMHA[" << qkv_matmul.Name() << "]: fusion skipped: " << message; + return false; + }; + + if (!graph_utils::IsSupportedOptypeVersionAndDomain(qkv_matmul, "MatMul", {1, 9, 13}, kOnnxDomain)) { + return false; + } + + if (!IsSupportedOrUnassignedNode(qkv_matmul, compatible_execution_providers)) { + return false; + } + + if (!optimizer_utils::CheckOutputEdges(graph, qkv_matmul, 1) || qkv_matmul.InputDefs().size() < 2 || + !graph_utils::IsInitializer(graph, qkv_matmul.InputDefs()[1]->Name(), true)) { + return fail("qkv MatMul output count or weight initializer check failed"); + } + + const Node* sequence_transpose = graph_utils::GetInputNode(qkv_matmul, 0); + if (sequence_transpose == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*sequence_transpose, "Transpose", {1, 13}, kOnnxDomain) || + !HasExpectedPerm(*sequence_transpose, {0, 2, 1}) || + !optimizer_utils::CheckOutputEdges(graph, *sequence_transpose, 1)) { + return false; + } + + const Node* input_reshape = graph_utils::GetInputNode(*sequence_transpose, 0); + if (input_reshape == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*input_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) || + !optimizer_utils::CheckOutputEdges(graph, *input_reshape, 1)) { + return fail("missing input Reshape before sequence transpose"); + } + + Node* qkv_reshape = GetOnlyChildByOutputIndex(graph, qkv_matmul, 0, "Reshape"); + if (qkv_reshape == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*qkv_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) || + !optimizer_utils::CheckOutputEdges(graph, *qkv_reshape, 1)) { + return fail("qkv Reshape after MatMul not matched"); + } + + Node* split = GetOnlyChildByOutputIndex(graph, *qkv_reshape, 0, "Split"); + if (split == nullptr || !graph_utils::IsSupportedOptypeVersionAndDomain(*split, "Split", {13, 18}, kOnnxDomain) || + split->OutputDefs().size() != 3 || !optimizer_utils::IsAttributeWithExpectedValue(*split, "axis", static_cast(2))) { + return fail("qkv Split(axis=2, outputs=3) not matched"); + } + + Node* q_transpose = GetOnlyChildByOutputIndex(graph, *split, 0, "Transpose"); + Node* k_squeeze = GetOnlyChildByOutputIndex(graph, *split, 1, "Squeeze"); + Node* v_transpose = GetOnlyChildByOutputIndex(graph, *split, 2, "Transpose"); + if (q_transpose == nullptr || k_squeeze == nullptr || v_transpose == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*q_transpose, "Transpose", {1, 13}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*k_squeeze, "Squeeze", {13}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*v_transpose, "Transpose", {1, 13}, kOnnxDomain) || + !HasExpectedPerm(*q_transpose, {2, 0, 3, 1, 4}) || + !HasExpectedPerm(*v_transpose, {2, 0, 3, 1, 4}) || + !HasExpectedAxesInput(graph, *k_squeeze, {2})) { + return fail("q/k/v branch entry pattern after Split not matched"); + } + + Node* q_squeeze = GetOnlyChildByOutputIndex(graph, *q_transpose, 0, "Squeeze"); + Node* v_squeeze = GetOnlyChildByOutputIndex(graph, *v_transpose, 0, "Squeeze"); + if (q_squeeze == nullptr || v_squeeze == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*q_squeeze, "Squeeze", {13}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*v_squeeze, "Squeeze", {13}, kOnnxDomain) || + !HasExpectedAxesInput(graph, *q_squeeze, {0}) || + !HasExpectedAxesInput(graph, *v_squeeze, {0})) { + return fail("q/v squeeze pattern not matched"); + } + + Node* q_scale_mul = GetOnlyChildByOutputIndex(graph, *q_squeeze, 0, "Mul"); + Node* k_transpose = GetOnlyChildByOutputIndex(graph, *k_squeeze, 0, "Transpose"); + if (q_scale_mul == nullptr || k_transpose == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*q_scale_mul, "Mul", {7, 13, 14}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*k_transpose, "Transpose", {1, 13}, kOnnxDomain) || + !HasExpectedPerm(*k_transpose, {0, 2, 3, 1})) { + return fail("q scale Mul or k Transpose(0,2,3,1) not matched"); + } + + float scale = 0.0f; + if (q_scale_mul->InputDefs().size() < 2) { + return fail("q scale constant not found"); + } + + const NodeArg* q_squeeze_output = q_squeeze->OutputDefs()[0]; + const NodeArg* mul_input_0 = q_scale_mul->InputDefs()[0]; + const NodeArg* mul_input_1 = q_scale_mul->InputDefs()[1]; + const bool input_0_is_q_squeeze = mul_input_0 != nullptr && q_squeeze_output != nullptr && + mul_input_0->Name() == q_squeeze_output->Name(); + const bool input_1_is_q_squeeze = mul_input_1 != nullptr && q_squeeze_output != nullptr && + mul_input_1->Name() == q_squeeze_output->Name(); + + const NodeArg* scale_input = nullptr; + if (input_0_is_q_squeeze && !input_1_is_q_squeeze) { + scale_input = mul_input_1; + } else if (input_1_is_q_squeeze && !input_0_is_q_squeeze) { + scale_input = mul_input_0; + } + + if (scale_input == nullptr || + !optimizer_utils::GetScalarInitializerValue(graph, *scale_input, scale, true)) { + return fail("q scale constant not found"); + } + + Node* qk_matmul = GetOnlyChildByOutputIndex(graph, *q_scale_mul, 0, "MatMul"); + if (qk_matmul == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*qk_matmul, "MatMul", {1, 9, 13}, kOnnxDomain) || + graph_utils::GetInputNode(*qk_matmul, 1) == nullptr || + graph_utils::GetInputNode(*qk_matmul, 1)->Index() != k_transpose->Index() || + !optimizer_utils::CheckOutputEdges(graph, *qk_matmul, 1)) { + return fail("qk MatMul not matched"); + } + + Node* softmax = GetOnlyChildByOutputIndex(graph, *qk_matmul, 0, "Softmax"); + if (softmax == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*softmax, "Softmax", {1, 11, 13}, kOnnxDomain) || + !optimizer_utils::IsAttributeWithExpectedValue(*softmax, "axis", static_cast(-1)) || + !optimizer_utils::CheckOutputEdges(graph, *softmax, 1)) { + return fail("Softmax(axis=-1) not matched"); + } + + Node* qkv_matmul_1 = GetOnlyChildByOutputIndex(graph, *softmax, 0, "MatMul"); + if (qkv_matmul_1 == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*qkv_matmul_1, "MatMul", {1, 9, 13}, kOnnxDomain) || + graph_utils::GetInputNode(*qkv_matmul_1, 1) == nullptr || + graph_utils::GetInputNode(*qkv_matmul_1, 1)->Index() != v_squeeze->Index() || + !optimizer_utils::CheckOutputEdges(graph, *qkv_matmul_1, 1)) { + return fail("attention-value MatMul not matched"); + } + + Node* transpose_3 = GetOnlyChildByOutputIndex(graph, *qkv_matmul_1, 0, "Transpose"); + if (transpose_3 == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*transpose_3, "Transpose", {1, 13}, kOnnxDomain) || + !HasExpectedPerm(*transpose_3, {0, 2, 1, 3}) || + !optimizer_utils::CheckOutputEdges(graph, *transpose_3, 1)) { + return fail("output Transpose(0,2,1,3) not matched"); + } + + Node* reshape_2 = GetOnlyChildByOutputIndex(graph, *transpose_3, 0, "Reshape"); + if (reshape_2 == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*reshape_2, "Reshape", {5, 13, 14}, kOnnxDomain) || + !optimizer_utils::CheckOutputEdges(graph, *reshape_2, 1)) { + return fail("output Reshape not matched"); + } + + Node* proj_matmul = GetOnlyChildByOutputIndex(graph, *reshape_2, 0, "MatMul"); + Node* proj_gemm = proj_matmul == nullptr ? GetOnlyChildByOutputIndex(graph, *reshape_2, 0, "Gemm") : nullptr; + Node* proj_gemm_input_reshape = nullptr; + Node* proj_gemm_output_reshape = nullptr; + Node* proj_add = nullptr; + + if (proj_matmul != nullptr) { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(*proj_matmul, "MatMul", {1, 9, 13}, kOnnxDomain) || + proj_matmul->InputDefs().size() < 2 || + !graph_utils::IsInitializer(graph, proj_matmul->InputDefs()[1]->Name(), true) || + !optimizer_utils::CheckOutputEdges(graph, *proj_matmul, 1)) { + return fail("projection MatMul not matched"); + } + + proj_add = GetOnlyChildByOutputIndex(graph, *proj_matmul, 0, "Add"); + if (proj_add == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*proj_add, "Add", {7, 13, 14}, kOnnxDomain) || + !optimizer_utils::CheckOutputEdges(graph, *proj_add, 1)) { + return fail("projection Add not matched"); + } + } else { + if (proj_gemm == nullptr) { + proj_gemm_input_reshape = GetOnlyChildByOutputIndex(graph, *reshape_2, 0, "Reshape"); + if (proj_gemm_input_reshape == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm_input_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) || + !optimizer_utils::CheckOutputEdges(graph, *proj_gemm_input_reshape, 1)) { + return fail("projection MatMul/Gemm not matched"); + } + + proj_gemm = GetOnlyChildByOutputIndex(graph, *proj_gemm_input_reshape, 0, "Gemm"); + if (proj_gemm == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm, "Gemm", {7, 9, 11, 13}, kOnnxDomain) || + !optimizer_utils::CheckOutputEdges(graph, *proj_gemm, 1)) { + return fail("projection MatMul/Gemm not matched"); + } + + proj_gemm_output_reshape = GetOnlyChildByOutputIndex(graph, *proj_gemm, 0, "Reshape"); + if (proj_gemm_output_reshape == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm_output_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) || + !optimizer_utils::CheckOutputEdges(graph, *proj_gemm_output_reshape, 1)) { + return fail("normalized projection Gemm output Reshape not matched"); + } + } else if (!graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm, "Gemm", {7, 9, 11, 13}, kOnnxDomain) || + !optimizer_utils::CheckOutputEdges(graph, *proj_gemm, 1)) { + return fail("projection MatMul/Gemm not matched"); + } + } + + int64_t num_heads = 0; + int64_t head_size = 0; + int64_t hidden_size = 0; + if (!TryGetMobileClipQkvReshapeInfo(graph, *qkv_reshape, num_heads, head_size, hidden_size)) { + return fail("unable to derive num_heads/head_size from qkv reshape initializer"); + } + + if (proj_matmul != nullptr) { + if (!ValidateMatMulInitializer(graph, *proj_matmul, hidden_size) || + !ValidateAddBiasInitializerEitherInput(graph, *proj_add, hidden_size)) { + return fail("projection weight/bias shape validation failed"); + } + } else { + if (!ValidateProjectionGemmInitializer(graph, *proj_gemm, hidden_size)) { + return fail("projection Gemm weight/bias shape validation failed"); + } + } + + const NodeArg& qkv_weight = *qkv_matmul.InputDefs()[1]; + if (!optimizer_utils::ValidateShape(qkv_weight, {hidden_size, 3 * hidden_size})) { + return fail("qkv weight shape is not [hidden, 3*hidden]"); + } + + if (!AreSupportedOrUnassignedNodes( + qkv_matmul, + {sequence_transpose, + input_reshape, + qkv_reshape, + split, + q_transpose, + k_squeeze, + v_transpose, + q_squeeze, + v_squeeze, + q_scale_mul, + k_transpose, + qk_matmul, + softmax, + qkv_matmul_1, + transpose_3, + reshape_2, + proj_matmul, + proj_add, + proj_gemm_input_reshape, + proj_gemm, + proj_gemm_output_reshape}, + compatible_execution_providers)) { + return fail("matched nodes are assigned to incompatible execution providers"); + } + + auto mha_output_type = TryCreateMobileClipMhaOutputType(*qkv_matmul.OutputDefs()[0], hidden_size); + auto* mha_output = &graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("mobileclip_mha_output"), + mha_output_type ? &*mha_output_type : nullptr); + + if (proj_matmul != nullptr) { + if (!TryRewriteProjectionMatMulAddToGemm(graph, *mha_output, *proj_matmul, *proj_add)) { + return fail("projection MatMul/Add could not be rewritten to Gemm"); + } + } else if (proj_gemm_input_reshape == nullptr) { + if (!TryRewriteProjectionGemm(graph, *mha_output, *proj_gemm)) { + return fail("projection Gemm could not be normalized"); + } + } + + ONNX_NAMESPACE::TensorProto split_sizes_tensor; + split_sizes_tensor.set_name(graph.GenerateNodeArgName("mobileclip_mha_split_sizes")); + split_sizes_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + split_sizes_tensor.add_dims(3); + const std::array split_sizes{hidden_size, hidden_size, hidden_size}; + utils::SetRawDataInTensorProto(split_sizes_tensor, split_sizes.data(), split_sizes.size() * sizeof(int64_t)); + NodeArg& split_sizes_arg = graph_utils::AddInitializerWithOrtValue(graph, split_sizes_tensor); + + auto* mha_q = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("mobileclip_mha_q"), nullptr); + auto* mha_k = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("mobileclip_mha_k"), nullptr); + auto* mha_v = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("mobileclip_mha_v"), nullptr); + + Node& split_for_mha = graph.AddNode( + graph.GenerateNodeName("MobileClipSplitForMHA"), + "Split", + "Split packed MobileCLIP QKV for MultiHeadAttention", + {qkv_matmul.MutableOutputDefs()[0], &split_sizes_arg}, + {mha_q, mha_k, mha_v}, + nullptr, + kOnnxDomain); + split_for_mha.AddAttribute("axis", static_cast(2)); + + Node& mha_node = graph.AddNode( + graph.GenerateNodeName("MobileClipMultiHeadAttention"), + "MultiHeadAttention", + "Fused MobileCLIP attention subgraph", + {mha_q, mha_k, mha_v}, + {mha_output}, + nullptr, + kMSDomain); + mha_node.AddAttribute("num_heads", num_heads); + mha_node.AddAttribute("scale", scale); + + const auto& provider = qkv_matmul.GetExecutionProviderType(); + split_for_mha.SetExecutionProviderType(provider); + mha_node.SetExecutionProviderType(provider); + + if (proj_gemm_input_reshape != nullptr) { + graph_utils::ReplaceDownstreamNodeInput(graph, *reshape_2, 0, mha_node, 0); + } + + std::vector nodes_to_remove{ + qkv_reshape->Index(), + split->Index(), + q_transpose->Index(), + q_squeeze->Index(), + q_scale_mul->Index(), + k_squeeze->Index(), + k_transpose->Index(), + qk_matmul->Index(), + softmax->Index(), + v_transpose->Index(), + v_squeeze->Index(), + qkv_matmul_1->Index(), + transpose_3->Index(), + reshape_2->Index(), + }; + + if (proj_matmul != nullptr) { + nodes_to_remove.push_back(proj_matmul->Index()); + nodes_to_remove.push_back(proj_add->Index()); + } else if (proj_gemm_input_reshape == nullptr) { + nodes_to_remove.push_back(proj_gemm->Index()); + } + + for (const auto& node_index : nodes_to_remove) { + Node* node = graph.GetNode(node_index); + if (node == nullptr) { + continue; + } + + graph_utils::RemoveNodeOutputEdges(graph, *node); + graph.RemoveNode(node_index); + } + + LOGS(logger, VERBOSE) << "MobileClipMHA[" << qkv_matmul.Name() + << "]: fused MobileCLIP attention subgraph to MultiHeadAttention"; + + return true; +} + +} // namespace + static bool ValidateMatMulInitializer(const Graph& graph, const Node& matmul, int64_t hidden_size) { const NodeArg& input_b = *(matmul.InputDefs()[1]); if (!graph_utils::IsInitializer(graph, input_b.Name(), true)) { @@ -179,6 +837,12 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, Node& node = *p_node; ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + if (TryFuseMobileClipMHA(node, graph, GetCompatibleExecutionProviders(), logger)) { + fused_count++; + modified = true; + continue; + } + // Add node.GetOutputEdgesCount() == 5/6 for distilbert if ((node.GetOutputEdgesCount() >= 2 && node.GetOutputEdgesCount() <= 6) && graph_utils::IsSupportedOptypeVersionAndDomain(node, "LayerNormalization", {1, 17}, kOnnxDomain) && diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 18933e45b8922..75ba3b802f9ae 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -5826,6 +5826,363 @@ TEST_F(GraphTransformationTests, AttentionFusionDistilBertTest) { EXPECT_EQ(op_to_count["Shape"], 0); } +enum class MobileClipProjectionType { + MatMulAdd, + GemmWithReshapes, +}; + +struct MobileClipAttentionShapeConfig { + int64_t input_channels = 512; + int64_t hidden_size = 512; + int64_t num_heads = 16; + int64_t head_size = 32; + int64_t qkv_weight_input_dim = 512; +}; + +static void BuildMobileClipAttentionTestCase(ModelTestBuilder& builder, + MobileClipProjectionType projection_type, + const MobileClipAttentionShapeConfig& shape_config = {}, + bool use_non_default_projection_gemm_attributes = false, + bool use_runtime_projection_shape_input = false) { + const int64_t input_channels = shape_config.input_channels; + const int64_t hidden_size = shape_config.hidden_size; + const int64_t num_heads = shape_config.num_heads; + const int64_t head_size = shape_config.head_size; + const int64_t qkv_weight_input_dim = shape_config.qkv_weight_input_dim; + const int64_t qkv_hidden_size = num_heads * head_size; + const int64_t qkv_output_size = 3 * qkv_hidden_size; + + auto* input_x = builder.MakeInput({1, input_channels, 8, 8}, -1.0f, 1.0f); + auto* input_skip = builder.MakeInput({1, hidden_size, 8, 8}, -1.0f, 1.0f); + + auto* reshape0_shape = builder.Make1DInitializer({1, input_channels, 64}); + auto* qkv_weight = builder.MakeInitializer({qkv_weight_input_dim, qkv_output_size}, -0.05f, 0.05f); + auto* qkv_reshape_shape = builder.Make1DInitializer({1, 64, 3, num_heads, head_size}); + auto* split_sizes = builder.Make1DInitializer({1, 1, 1}); + auto* squeeze_axis_0 = builder.Make1DInitializer({0}); + auto* squeeze_axis_2 = builder.Make1DInitializer({2}); + auto* scale = builder.MakeScalarInitializer(1.0f / std::sqrt(static_cast(head_size))); + auto* reshape2_shape = use_runtime_projection_shape_input + ? builder.MakeInput({3}, {1, 64, hidden_size}) + : builder.Make1DInitializer({1, 64, hidden_size}); + auto* proj_gemm_input_shape = builder.Make1DInitializer({64, hidden_size}); + auto* proj_weight = builder.MakeInitializer({hidden_size, hidden_size}, -0.05f, 0.05f); + auto* proj_bias = builder.MakeInitializer({hidden_size}, -0.02f, 0.02f); + auto* proj_gemm_output_shape = builder.Make1DInitializer({1, 64, hidden_size}); + auto* reshape3_shape = builder.Make1DInitializer({1, hidden_size, 8, 8}); + auto* layer_scale = builder.MakeInitializer({hidden_size, 1, 1}, 0.9f, 1.1f); + + auto* reshape0_out = builder.MakeIntermediate(std::vector{1, input_channels, 64}); + auto* transpose0_out = builder.MakeIntermediate(std::vector{1, 64, input_channels}); + auto* qkv_out = builder.MakeIntermediate(std::vector{1, 64, qkv_output_size}); + auto* qkv_reshape_out = builder.MakeIntermediate(std::vector{1, 64, 3, num_heads, head_size}); + auto* split_q = builder.MakeIntermediate(std::vector{1, 64, 1, num_heads, head_size}); + auto* split_k = builder.MakeIntermediate(std::vector{1, 64, 1, num_heads, head_size}); + auto* split_v = builder.MakeIntermediate(std::vector{1, 64, 1, num_heads, head_size}); + auto* q_transpose_out = builder.MakeIntermediate(std::vector{1, 1, num_heads, 64, head_size}); + auto* q_squeeze_out = builder.MakeIntermediate(std::vector{1, num_heads, 64, head_size}); + auto* k_squeeze_out = builder.MakeIntermediate(std::vector{1, 64, num_heads, head_size}); + auto* k_transpose_out = builder.MakeIntermediate(std::vector{1, num_heads, head_size, 64}); + auto* q_scaled_out = builder.MakeIntermediate(std::vector{1, num_heads, 64, head_size}); + auto* qk_out = builder.MakeIntermediate(std::vector{1, num_heads, 64, 64}); + auto* softmax_out = builder.MakeIntermediate(std::vector{1, num_heads, 64, 64}); + auto* v_transpose_out = builder.MakeIntermediate(std::vector{1, 1, num_heads, 64, head_size}); + auto* v_squeeze_out = builder.MakeIntermediate(std::vector{1, num_heads, 64, head_size}); + auto* attn_out = builder.MakeIntermediate(std::vector{1, num_heads, 64, head_size}); + auto* transpose3_out = builder.MakeIntermediate(std::vector{1, 64, num_heads, head_size}); + auto* reshape2_out = use_runtime_projection_shape_input + ? builder.MakeIntermediate(std::nullopt) + : builder.MakeIntermediate(std::vector{1, 64, hidden_size}); + auto* proj_gemm_input_out = builder.MakeIntermediate(std::vector{64, hidden_size}); + auto* proj_gemm_out = builder.MakeIntermediate(std::vector{64, hidden_size}); + auto* proj_linear_out = builder.MakeIntermediate(std::vector{1, 64, hidden_size}); + auto* transpose4_out = builder.MakeIntermediate(std::vector{1, hidden_size, 64}); + auto* reshape3_out = builder.MakeIntermediate(std::vector{1, hidden_size, 8, 8}); + auto* layer_scale_out = builder.MakeIntermediate(std::vector{1, hidden_size, 8, 8}); + auto* output = builder.MakeOutput(std::vector{1, hidden_size, 8, 8}); + + auto& reshape0 = builder.AddNode("Reshape", std::vector{input_x, reshape0_shape}, std::vector{reshape0_out}); + reshape0.AddAttribute("allowzero", static_cast(0)); + + auto& transpose0 = builder.AddNode("Transpose", std::vector{reshape0_out}, std::vector{transpose0_out}); + transpose0.AddAttribute("perm", std::vector{0, 2, 1}); + + builder.AddNode("MatMul", std::vector{transpose0_out, qkv_weight}, std::vector{qkv_out}); + + auto& qkv_reshape = builder.AddNode("Reshape", std::vector{qkv_out, qkv_reshape_shape}, std::vector{qkv_reshape_out}); + qkv_reshape.AddAttribute("allowzero", static_cast(0)); + + auto& split = builder.AddNode("Split", std::vector{qkv_reshape_out, split_sizes}, std::vector{split_q, split_k, split_v}); + split.AddAttribute("axis", static_cast(2)); + + auto& q_transpose = builder.AddNode("Transpose", std::vector{split_q}, std::vector{q_transpose_out}); + q_transpose.AddAttribute("perm", std::vector{2, 0, 3, 1, 4}); + + builder.AddNode("Squeeze", std::vector{q_transpose_out, squeeze_axis_0}, std::vector{q_squeeze_out}); + builder.AddNode("Squeeze", std::vector{split_k, squeeze_axis_2}, std::vector{k_squeeze_out}); + + auto& k_transpose = builder.AddNode("Transpose", std::vector{k_squeeze_out}, std::vector{k_transpose_out}); + k_transpose.AddAttribute("perm", std::vector{0, 2, 3, 1}); + + builder.AddNode("Mul", std::vector{q_squeeze_out, scale}, std::vector{q_scaled_out}); + builder.AddNode("MatMul", std::vector{q_scaled_out, k_transpose_out}, std::vector{qk_out}); + + auto& softmax = builder.AddNode("Softmax", std::vector{qk_out}, std::vector{softmax_out}); + softmax.AddAttribute("axis", static_cast(-1)); + + auto& v_transpose = builder.AddNode("Transpose", std::vector{split_v}, std::vector{v_transpose_out}); + v_transpose.AddAttribute("perm", std::vector{2, 0, 3, 1, 4}); + + builder.AddNode("Squeeze", std::vector{v_transpose_out, squeeze_axis_0}, std::vector{v_squeeze_out}); + builder.AddNode("MatMul", std::vector{softmax_out, v_squeeze_out}, std::vector{attn_out}); + + auto& transpose3 = builder.AddNode("Transpose", std::vector{attn_out}, std::vector{transpose3_out}); + transpose3.AddAttribute("perm", std::vector{0, 2, 1, 3}); + + auto& reshape2 = builder.AddNode("Reshape", std::vector{transpose3_out, reshape2_shape}, std::vector{reshape2_out}); + reshape2.AddAttribute("allowzero", static_cast(0)); + + if (projection_type == MobileClipProjectionType::GemmWithReshapes) { + auto& proj_gemm_input = builder.AddNode("Reshape", std::vector{reshape2_out, proj_gemm_input_shape}, + std::vector{proj_gemm_input_out}); + proj_gemm_input.AddAttribute("allowzero", static_cast(0)); + + auto& proj_gemm = builder.AddNode("Gemm", std::vector{proj_gemm_input_out, proj_weight, proj_bias}, + std::vector{proj_gemm_out}); + if (use_non_default_projection_gemm_attributes) { + proj_gemm.AddAttribute("transB", static_cast(1)); + } + + auto& proj_gemm_output = builder.AddNode("Reshape", std::vector{proj_gemm_out, proj_gemm_output_shape}, + std::vector{proj_linear_out}); + proj_gemm_output.AddAttribute("allowzero", static_cast(0)); + } else { + auto* proj_matmul_out = builder.MakeIntermediate(std::vector{1, 64, hidden_size}); + builder.AddNode("MatMul", std::vector{reshape2_out, proj_weight}, std::vector{proj_matmul_out}); + builder.AddNode("Add", std::vector{proj_bias, proj_matmul_out}, std::vector{proj_linear_out}); + } + + auto& transpose4 = builder.AddNode("Transpose", std::vector{proj_linear_out}, std::vector{transpose4_out}); + transpose4.AddAttribute("perm", std::vector{0, 2, 1}); + + auto& reshape3 = builder.AddNode("Reshape", std::vector{transpose4_out, reshape3_shape}, std::vector{reshape3_out}); + reshape3.AddAttribute("allowzero", static_cast(0)); + + builder.AddNode("Mul", std::vector{layer_scale, reshape3_out}, std::vector{layer_scale_out}); + builder.AddNode("Add", std::vector{input_skip, layer_scale_out}, std::vector{output}); +} + +static Status CheckMobileClipAttentionFusedGraph(Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["com.microsoft.MultiHeadAttention"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Gemm"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Softmax"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Squeeze"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Split"] == 1); + TEST_RETURN_IF_NOT(op_to_count["MatMul"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Transpose"] == 2); + TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 4); + TEST_RETURN_IF_NOT(op_to_count["Mul"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Add"] == 1); + + int mha_nodes = 0; + int gemm_nodes = 0; + int split_nodes = 0; + for (Node& node : graph.Nodes()) { + if (node.OpType() == "MultiHeadAttention" && node.Domain() == kMSDomain) { + ++mha_nodes; + TEST_RETURN_IF_NOT(node.GetAttributes().at("num_heads").i() == 16); + TEST_RETURN_IF_NOT(std::abs(node.GetAttributes().at("scale").f() - (1.0f / std::sqrt(32.0f))) < 1e-6f); + TEST_RETURN_IF_NOT(node.OutputDefs().size() == 1); + TEST_RETURN_IF_NOT(node.OutputDefs()[0]->Shape() != nullptr); + TEST_RETURN_IF_NOT(node.OutputDefs()[0]->Shape()->dim_size() == 3); + } else if (node.OpType() == "Split") { + ++split_nodes; + } else if (node.OpType() == "Gemm") { + ++gemm_nodes; + TEST_RETURN_IF_NOT(node.InputDefs().size() == 3); + TEST_RETURN_IF_NOT(node.OutputDefs().size() == 1); + TEST_RETURN_IF_NOT(node.InputDefs()[0]->Shape() != nullptr); + TEST_RETURN_IF_NOT(node.InputDefs()[0]->Shape()->dim_size() == 2); + TEST_RETURN_IF_NOT(node.OutputDefs()[0]->Shape() != nullptr); + TEST_RETURN_IF_NOT(node.OutputDefs()[0]->Shape()->dim_size() == 2); + + const Node* gemm_input_node = graph_utils::GetInputNode(node, 0); + TEST_RETURN_IF_NOT(gemm_input_node != nullptr); + TEST_RETURN_IF_NOT(gemm_input_node->OpType() == "Reshape"); + + bool has_output_reshape = false; + for (const Node& consumer : graph.Nodes()) { + for (const NodeArg* input_def : consumer.InputDefs()) { + if (input_def != nullptr && input_def->Name() == node.OutputDefs()[0]->Name()) { + has_output_reshape = consumer.OpType() == "Reshape"; + break; + } + } + + if (has_output_reshape) { + break; + } + } + + TEST_RETURN_IF_NOT(has_output_reshape); + } + } + + TEST_RETURN_IF_NOT(mha_nodes == 1); + TEST_RETURN_IF_NOT(gemm_nodes == 1); + TEST_RETURN_IF_NOT(split_nodes == 1); + return Status::OK(); +} + +static Status CheckMobileClipAttentionFusedGraphOnProvider(Graph& graph, const char* provider) { + ORT_RETURN_IF_ERROR(CheckMobileClipAttentionFusedGraph(graph)); + + for (Node& node : graph.Nodes()) { + TEST_RETURN_IF_NOT(node.GetExecutionProviderType() == provider); + } + + return Status::OK(); +} + +static Status CheckMobileClipAttentionUnfusedProjectionGemmGraph(Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["com.microsoft.MultiHeadAttention"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Gemm"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Softmax"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Squeeze"] == 3); + TEST_RETURN_IF_NOT(op_to_count["Split"] == 1); + TEST_RETURN_IF_NOT(op_to_count["MatMul"] == 3); + TEST_RETURN_IF_NOT(op_to_count["Transpose"] == 6); + TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 6); + TEST_RETURN_IF_NOT(op_to_count["Mul"] == 2); + TEST_RETURN_IF_NOT(op_to_count["Add"] == 1); + + int gemm_nodes = 0; + for (Node& node : graph.Nodes()) { + if (node.OpType() != "Gemm") { + continue; + } + + ++gemm_nodes; + const auto& attrs = node.GetAttributes(); + auto trans_b_attr = attrs.find("transB"); + TEST_RETURN_IF_NOT(trans_b_attr != attrs.end()); + TEST_RETURN_IF_NOT(trans_b_attr->second.i() == 1); + } + + TEST_RETURN_IF_NOT(gemm_nodes == 1); + return Status::OK(); +} + +static Status CheckMobileClipAttentionUnfusedMatMulGraph(Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["com.microsoft.MultiHeadAttention"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Gemm"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Softmax"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Squeeze"] == 3); + TEST_RETURN_IF_NOT(op_to_count["Split"] == 1); + TEST_RETURN_IF_NOT(op_to_count["MatMul"] == 4); + TEST_RETURN_IF_NOT(op_to_count["Transpose"] == 6); + TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 4); + TEST_RETURN_IF_NOT(op_to_count["Mul"] == 2); + TEST_RETURN_IF_NOT(op_to_count["Add"] == 2); + return Status::OK(); +} + +TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::MatMulAdd); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::make_unique(), + TransformerLevel::Level2, 1, nullptr, CheckMobileClipAttentionFusedGraph)); +} + +TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaProjectionGemmTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::GemmWithReshapes); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::make_unique(), + TransformerLevel::Level2, 1, nullptr, CheckMobileClipAttentionFusedGraph)); +} + +TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaCudaEpTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::MatMulAdd); + }; + + auto pre_graph_checker = [](Graph& graph) { + for (Node& node : graph.Nodes()) { + node.SetExecutionProviderType(kCudaExecutionProvider); + } + + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) { + return CheckMobileClipAttentionFusedGraphOnProvider(graph, kCudaExecutionProvider); + }; + + ASSERT_STATUS_OK(TestGraphTransformer( + build_test_case, 14, *logger_, std::make_unique(InlinedHashSet{kCudaExecutionProvider}), + TransformerLevel::Level2, 1, pre_graph_checker, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaProjectionGemmCudaEpTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::GemmWithReshapes); + }; + + auto pre_graph_checker = [](Graph& graph) { + for (Node& node : graph.Nodes()) { + node.SetExecutionProviderType(kCudaExecutionProvider); + } + + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) { + return CheckMobileClipAttentionFusedGraphOnProvider(graph, kCudaExecutionProvider); + }; + + ASSERT_STATUS_OK(TestGraphTransformer( + build_test_case, 14, *logger_, std::make_unique(InlinedHashSet{kCudaExecutionProvider}), + TransformerLevel::Level2, 1, pre_graph_checker, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaInvalidQkvWeightShapeTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildMobileClipAttentionTestCase(builder, + MobileClipProjectionType::MatMulAdd, + MobileClipAttentionShapeConfig{512, 510, 15, 34, 512}); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::make_unique(), + TransformerLevel::Level2, 1, nullptr, CheckMobileClipAttentionUnfusedMatMulGraph)); +} + +TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaProjectionGemmNonDefaultAttributesTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::GemmWithReshapes, {}, true); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::make_unique(), + TransformerLevel::Level2, 1, nullptr, + CheckMobileClipAttentionUnfusedProjectionGemmGraph)); +} + +TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaProjectionRewriteFailureLeavesGraphUnfusedTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::MatMulAdd, {}, false, true); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::make_unique(), + TransformerLevel::Level2, 1, nullptr, + CheckMobileClipAttentionUnfusedMatMulGraph)); +} + TEST_F(GraphTransformationTests, GeluFusionTest) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/gelu.onnx"; std::shared_ptr p_model;