Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 3 additions & 22 deletions include/onnxruntime/core/framework/op_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ class OpKernel {
return Status::OK();
}

// Note: New implementations should override OpKernel::UseSharedPrePackedBuffers_V2 instead.
// Override this function to use provided pre-packed weight.
// Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
// gsl::span<const size_t> prepacked_buffer_sizes,
// int input_idx,
// /*out*/ bool& used_shared_buffers) {
// used_shared_buffers = true;
Expand All @@ -121,37 +121,18 @@ class OpKernel {
// and must use the same order for retrieval in UseSharedPrePackedBuffers(). Though each element
// of this vector is a BufferUniquePtr, the deleter of the BufferUniquePtr is NULL. So actually they
// are raw pointers.
// @param prepacked_buffer_sizes: The sizes (in bytes) of each buffer in prepacked_buffers.
// @param input_idx: The input index of the tensor in this kernel
// @param used_shared_buffers: Boolean flag set by the kernel implementation indicating
// that the provided weight has been used by the kernel.
// Default implementation: signals that none of the provided shared pre-packed
// buffers were consumed (used_shared_buffers = false) and returns OK, so
// kernels that do not pre-pack any weights need no override. Kernels that do
// override PrePack() should also override this to adopt the shared buffers.
virtual Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& /*prepacked_buffers*/,
                                         gsl::span<const size_t> /*prepacked_buffer_sizes*/,
                                         int /*input_idx*/,
                                         /*out*/ bool& used_shared_buffers) {
  used_shared_buffers = false;
  return Status::OK();
}

/// <summary>
/// Version 2 of OpKernel::UseSharedPrePackedBuffers() that additionally accepts the buffer sizes as a parameter.
/// The default implementation of this function just calls directly to OpKernel::UseSharedPrePackedBuffers()
/// to avoid the need to update all existing kernel-based provider-bridge EPs.
///
/// TODO: Consolidate UseSharedPrePackedBuffers and UseSharedPrePackedBuffers_V2 into a single function,
/// which will require updating kernel-based provider-bridge EPs (cpu, cuda, webgpu).
///
/// </summary>
/// <param name="prepacked_buffers"></param>
/// <param name="prepacked_buffer_sizes"></param>
/// <param name="input_idx"></param>
/// <param name="used_shared_buffers"></param>
/// <returns></returns>
// NOTE(review): this appears to be the pre-consolidation "_V2" overload that
// this change removes (callers now invoke UseSharedPrePackedBuffers with the
// buffer sizes directly) — confirm against the final file state.
// Default implementation forwards to the legacy 3-argument overload,
// intentionally dropping prepacked_buffer_sizes, so existing kernel-based
// provider-bridge EPs keep working without modification.
virtual Status UseSharedPrePackedBuffers_V2(std::vector<BufferUniquePtr>& prepacked_buffers,
                                            gsl::span<const size_t> /*prepacked_buffer_sizes*/,
                                            int input_idx,
                                            /*out*/ bool& used_shared_buffers) {
  return UseSharedPrePackedBuffers(prepacked_buffers, input_idx, used_shared_buffers);
}

const OrtDevice GetDevice(OrtMemType mem_type) const;
const OpKernelInfo& Info() const {
return *op_kernel_info_;
Expand Down
17 changes: 5 additions & 12 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -3613,12 +3613,6 @@ struct KernelRegistry : detail::Base<OrtKernelRegistry> {
};

namespace detail {
/** \brief Non-owning wrapper around a `const OrtOpSchemaTypeConstraint*`.
*
* Holds a single type constraint from an operator schema, providing access to
* the constraint's name, allowed data types, and associated input/output indices.
* This is a non-owning view — the lifetime is tied to the parent OrtOpSchema.
*/
template <typename T>
struct OpSchemaTypeConstraintImpl : Base<T> {
using B = Base<T>;
Expand All @@ -3639,15 +3633,11 @@ struct OpSchemaTypeConstraintImpl : Base<T> {
} // namespace detail

/// Non-owning wrapper around a `const OrtOpSchemaTypeConstraint*`.
/// Holds a single type constraint from an operator schema, providing access to
/// the constraint's name, allowed data types, and associated input/output indices.
using ConstOpSchemaTypeConstraint = detail::OpSchemaTypeConstraintImpl<detail::Unowned<const OrtOpSchemaTypeConstraint>>;

namespace detail {
/** \brief Owning wrapper around an `OrtOpSchema*`.
*
* Provides access to operator schema metadata such as version, input/output names,
* and type constraints. The underlying OrtOpSchema is owned by this wrapper and
* released automatically on destruction.
*/
template <typename T>
struct OpSchemaImpl : Base<T> {
using B = Base<T>;
Expand Down Expand Up @@ -3685,6 +3675,9 @@ struct OpSchemaImpl : Base<T> {
} // namespace detail

/// Owning wrapper around an `OrtOpSchema*`.
/// Provides access to operator schema metadata such as version, input/output names,
/// and type constraints. The underlying OrtOpSchema is owned by this wrapper and
/// released automatically on destruction.
using OpSchema = detail::OpSchemaImpl<OrtOpSchema>;

/// \brief Get an operator schema from the global schema registry.
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class Attention : public OpKernel, public AttentionCPUBase {
/*out*/ PrePackedWeights* prepacked_weights) override;

Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
gsl::span<const size_t> /*prepacked_buffer_sizes*/,
int input_idx,
/*out*/ bool& used_shared_buffers) override;

Expand Down Expand Up @@ -176,6 +177,7 @@ Status Attention<T>::PrePack(const Tensor& weights, int input_idx, AllocatorPtr

template <typename T>
Status Attention<T>::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
gsl::span<const size_t> /*prepacked_buffer_sizes*/,
int input_idx,
/*out*/ bool& used_shared_buffers) {
if (1 != input_idx) {
Expand Down
12 changes: 6 additions & 6 deletions onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -578,10 +578,10 @@ Status QMoECPU<T>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr all
}

template <typename T>
Status QMoECPU<T>::UseSharedPrePackedBuffers_V2(std::vector<BufferUniquePtr>& prepacked_buffers,
gsl::span<const size_t> /*prepacked_buffer_sizes*/,
int input_idx,
/*out*/ bool& used_shared_buffers) {
Status QMoECPU<T>::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
gsl::span<const size_t> /*prepacked_buffer_sizes*/,
int input_idx,
/*out*/ bool& used_shared_buffers) {
used_shared_buffers = false;

if (expert_weight_bits_ != 4) {
Expand Down Expand Up @@ -1577,11 +1577,11 @@ template QMoECPU<float>::QMoECPU(const OpKernelInfo& op_kernel_info);

template Status QMoECPU<float>::Compute(OpKernelContext* context) const;
template Status QMoECPU<float>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights);
template Status QMoECPU<float>::UseSharedPrePackedBuffers_V2(std::vector<BufferUniquePtr>& prepacked_buffers, gsl::span<const size_t> prepacked_buffer_sizes, int input_idx, bool& used_shared_buffers);
template Status QMoECPU<float>::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, gsl::span<const size_t> prepacked_buffer_sizes, int input_idx, bool& used_shared_buffers);
template QMoECPU<MLFloat16>::QMoECPU(const OpKernelInfo& op_kernel_info);
template Status QMoECPU<MLFloat16>::Compute(OpKernelContext* context) const;
template Status QMoECPU<MLFloat16>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights);
template Status QMoECPU<MLFloat16>::UseSharedPrePackedBuffers_V2(std::vector<BufferUniquePtr>& prepacked_buffers, gsl::span<const size_t> prepacked_buffer_sizes, int input_idx, bool& used_shared_buffers);
template Status QMoECPU<MLFloat16>::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, gsl::span<const size_t> prepacked_buffer_sizes, int input_idx, bool& used_shared_buffers);

// Kernel Registration
ONNX_OPERATOR_TYPED_KERNEL_EX(
Expand Down
8 changes: 4 additions & 4 deletions onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ class QMoECPU final : public OpKernel, public MoEBaseCPU {
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;

Status UseSharedPrePackedBuffers_V2(std::vector<BufferUniquePtr>& prepacked_buffers,
gsl::span<const size_t> prepacked_buffer_sizes,
int input_idx,
/*out*/ bool& used_shared_buffers) override;
Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
gsl::span<const size_t> prepacked_buffer_sizes,
int input_idx,
/*out*/ bool& used_shared_buffers) override;

void ApplyActivationVectorized(float* data, int64_t size) const;

Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class QAttention : public OpKernel, public AttentionCPUBase {
/*out*/ PrePackedWeights* prepacked_weights) override;

Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
gsl::span<const size_t> /*prepacked_buffer_sizes*/,
int input_idx,
/*out*/ bool& used_shared_buffers) override;

Expand Down Expand Up @@ -117,6 +118,7 @@ Status QAttention<T>::PrePack(const Tensor& weights, int input_idx, AllocatorPtr

template <typename T>
Status QAttention<T>::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
gsl::span<const size_t> /*prepacked_buffer_sizes*/,
int input_idx,
/*out*/ bool& used_shared_buffers) {
if (1 != input_idx) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class DynamicQuantizeLSTM : public OpKernel, public LSTMBase {
/*out*/ PrePackedWeights* prepacked_weights) override;

Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
gsl::span<const size_t> /*prepacked_buffer_sizes*/,
int input_idx,
/*out*/ bool& used_shared_buffers) override;

Expand Down Expand Up @@ -117,6 +118,7 @@ Status DynamicQuantizeLSTM::PrePack(const Tensor& tensor, int input_idx, Allocat
}

Status DynamicQuantizeLSTM::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
gsl::span<const size_t> /*prepacked_buffer_sizes*/,
int input_idx,
/*out*/ bool& used_shared_buffers) {
used_shared_buffers = false;
Expand Down
8 changes: 6 additions & 2 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ class MatMulNBits final : public OpKernel {
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;

Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx,
Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
gsl::span<const size_t> /*prepacked_buffer_sizes*/,
int input_idx,
/*out*/ bool& used_shared_buffers) override;

private:
Expand Down Expand Up @@ -557,7 +559,9 @@ Status MatMulNBits<MLFloat16>::PrePack(const Tensor& tensor, int input_idx, /*ou
#endif // end !MLAS_F16VEC_INTRINSICS_SUPPORTED || !MLAS_TARGET_ARM64

template <typename T1>
Status MatMulNBits<T1>::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx,
Status MatMulNBits<T1>::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
gsl::span<const size_t> /*prepacked_buffer_sizes*/,
int input_idx,
/*out*/ bool& used_shared_buffers) {
used_shared_buffers = false;

Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -436,8 +436,8 @@ static Status KernelUseSharedPrePackedBuffers(OpKernel& kernel, int input_idx,
}

bool used_shared_buffers = false;
ORT_RETURN_IF_ERROR(kernel.UseSharedPrePackedBuffers_V2(shared_prepacked_buffers, shared_prepacked_buffer_sizes,
input_idx, used_shared_buffers));
ORT_RETURN_IF_ERROR(kernel.UseSharedPrePackedBuffers(shared_prepacked_buffers, shared_prepacked_buffer_sizes,
input_idx, used_shared_buffers));

// BUG CHECK: Ensure that the kernel used the provided shared buffers
// Mostly a debug check to ensure that the kernel has an overridden implementation of the
Expand Down
Loading
Loading