diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index 9871c99f8ee..69c1f9919d4 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -166,16 +166,6 @@ int GgmlOvDecoder::compute_op_case(const ggml_tensor * node) const { } break; } - case GGML_OP_CONT: { - if (node->src[0]->op == GGML_OP_PERMUTE) { - op_case = 1; - } else if (node->src[0]->op == GGML_OP_TRANSPOSE) { - op_case = 2; - } else if (node->src[0]->op == GGML_OP_VIEW) { - op_case = 3; - } - break; - } case GGML_OP_PERMUTE: { if (node->src[0]->op != GGML_OP_VIEW) { op_case = 1; @@ -195,9 +185,7 @@ int GgmlOvDecoder::compute_op_case(const ggml_tensor * node) const { break; } case GGML_OP_MUL_MAT: { - if (node->src[0]->op == GGML_OP_CONT && node->src[0]->src[0]->op == GGML_OP_TRANSPOSE) { - op_case = 2; - } else if (node->src[0]->op == GGML_OP_VIEW && node->src[1]->op == GGML_OP_VIEW) { + if (node->src[0]->op == GGML_OP_VIEW && node->src[1]->op == GGML_OP_VIEW) { op_case = 3; } break; @@ -314,6 +302,14 @@ std::pair GgmlOvDecoder::compute_llm_params(ggml_cgr } break; } + // if the node op is TRANSPOSE and its input is PERMUTE and the source of the PERMUTE is VIEW, then get the attention size with the TRANSPOSE node ne[0] (in case no GGML_OP_FLASH_ATTN_EXT) + if (node->op == GGML_OP_TRANSPOSE && node->src[0]->op == GGML_OP_PERMUTE && + node->src[0]->src[0]->op == GGML_OP_VIEW) { + compute_params.attention_size = node->ne[0]; + if (is_static) { + compute_params.attention_size = model_params.ctx_per_seq; + } + } if (node->op == GGML_OP_ROPE) { memcpy(model_params.rope_params, node->op_params, sizeof(int32_t) * 15); } @@ -880,6 +876,11 @@ ov::element::Type GgmlOvDecoder::get_output_type(const int node_idx) const { return get_ov_type(m_node_info_list[node_idx].node); } +std::vector GgmlOvDecoder::get_output_stride(int node_idx) const { + auto * ggml_tensor = m_node_info_list[node_idx].node; + return get_stride(ggml_tensor); +} + std::vector GgmlOvDecoder::get_output_names(int node_idx) const { return {m_node_info_list[node_idx].node_output_name}; } @@ -889,6 +890,14 @@ const std::string & GgmlOvDecoder::get_op_name() const { return unknown_name; } +int32_t GgmlOvDecoder::get_op_dynamic_dim(int node_idx) const { + auto it = m_node_dynamic_dims.find(m_node_info_list[node_idx].node); + if (it == m_node_dynamic_dims.end()) { + return -1; + } + return it->second; +} + const std::string & GgmlOvDecoder::get_op_name(int node_idx) const { return m_node_info_list[node_idx].node_name; } diff --git a/ggml/src/ggml-openvino/ggml-decoder.h b/ggml/src/ggml-openvino/ggml-decoder.h index c793c3d6ae7..ef185dbd324 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.h +++ b/ggml/src/ggml-openvino/ggml-decoder.h @@ -107,6 +107,8 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { virtual ov::element::Type get_output_type(int node_idx) const override; + virtual std::vector get_output_stride(int node_idx) const override; + virtual int32_t * get_input_op_params(int node_idx, const std::string & name) const override; virtual int32_t * get_output_op_params(int node_idx) const override; @@ -121,6 +123,8 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { virtual const std::string & get_op_name(int node_idx) const override; + virtual int32_t get_op_dynamic_dim(int node_idx) const override; + virtual void visit_subgraph(std::function, int node_idx)> node_visitor) const override; ggml_tensor * get_input_ggml_tensor(const std::string & name) const { return m_inputs.at(name); } diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index 149c16cde92..b613058206c 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -912,6 +912,14 @@ static bool is_op_unsupported_case(const ggml_tensor * op) { } break; } + case GGML_OP_TRANSPOSE: { + // if the type is bf16, will return true + if (op->type == GGML_TYPE_BF16) { + // GGML_LOG_WARN("OpenVINO backend does not support CONT with BF16 type\n"); + return true; + } + break; + } default: break; } @@ -933,7 +941,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con GGML_TYPE_Q5_K, GGML_TYPE_Q8_0, GGML_TYPE_Q6_K}; static const std::set supported_ops{GGML_OP_NONE, GGML_OP_ADD, GGML_OP_MUL, GGML_OP_MUL_MAT, GGML_OP_VIEW, - /*GGML_OP_CONT,*/ GGML_OP_RESHAPE, GGML_OP_PERMUTE, GGML_OP_TRANSPOSE, + GGML_OP_CONT, GGML_OP_RESHAPE, GGML_OP_PERMUTE, GGML_OP_TRANSPOSE, GGML_OP_GET_ROWS, GGML_OP_ROPE, GGML_OP_RMS_NORM, GGML_OP_SCALE, // softmax is not updated due to replaced by flash_attn_ext // GGML_OP_SOFT_MAX, diff --git a/ggml/src/ggml-openvino/openvino/decoder.h b/ggml/src/ggml-openvino/openvino/decoder.h index ed6ff7c0aba..764a269ec7a 100644 --- a/ggml/src/ggml-openvino/openvino/decoder.h +++ b/ggml/src/ggml-openvino/openvino/decoder.h @@ -35,6 +35,8 @@ class GgmlDecoder : public DecoderBase { virtual element::Type get_output_type(const int node_idx) const = 0; + virtual std::vector get_output_stride(int node_idx) const = 0; + virtual int32_t* get_input_op_params(int node_idx, const std::string& name) const = 0; virtual int32_t * get_output_op_params(int node_idx) const = 0; @@ -69,6 +71,8 @@ class GgmlDecoder : public DecoderBase { virtual bool is_splited_model() const = 0; virtual int is_swa_layer(int layer) const = 0; + + virtual int32_t get_op_dynamic_dim(int node_idx) const = 0; }; } // namespace ggml diff --git a/ggml/src/ggml-openvino/openvino/node_context.h b/ggml/src/ggml-openvino/openvino/node_context.h index aa484128a95..70d6c02e8e1 100644 --- a/ggml/src/ggml-openvino/openvino/node_context.h +++ b/ggml/src/ggml-openvino/openvino/node_context.h @@ -59,12 +59,20 @@ class NodeContext : public frontend::NodeContext { return m_decoder->get_input_op_params(m_node_idx, m_input_names[index]); } + int32_t get_op_dynamic_dim() const { + return m_decoder->get_op_dynamic_dim(m_node_idx); + } + int32_t * get_output_op_params() const { return m_decoder->get_output_op_params(m_node_idx); } ov::element::Type get_output_type() const { return m_decoder->get_output_type(m_node_idx); } + std::vector get_output_stride() const { + return m_decoder->get_output_stride(m_node_idx); + } + Output get_input(int idx) const override { return m_tensor_map->at(m_input_names[idx]); } diff --git a/ggml/src/ggml-openvino/openvino/op/cont.cpp b/ggml/src/ggml-openvino/openvino/op/cont.cpp index 6160dd74444..243e236f166 100644 --- a/ggml/src/ggml-openvino/openvino/op/cont.cpp +++ b/ggml/src/ggml-openvino/openvino/op/cont.cpp @@ -18,27 +18,17 @@ namespace op { OutputVector translate_cont(const NodeContext & context) { num_inputs_check(context, 1, 1); - int op_case = context.get_op_case(); - FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3, "Unsupported CONT case"); - auto src_shape = context.get_input_shape(0).to_shape(); auto dst_shape = context.get_output_shape().to_shape(); - ov::Output res; - if (op_case == 1) { - // The input comes from a PERMUTE - throw std::runtime_error("Code of this case might be outdated"); - dst_shape[1] = -1; - res = std::make_shared( - context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {dst_shape.size()}, dst_shape), false); - } else if (op_case == 2) { - // The input comes from a TRANSPOSE - return {context.get_input(0)}; - } else { - // The input comes from a VIEW - res = process_view_input(context, 0); + if (context.get_op_dynamic_dim() != -1) { + dst_shape[3 - context.get_op_dynamic_dim()] = -1; } + ov::Output res; + res = std::make_shared( + context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {dst_shape.size()}, dst_shape), false); + return rename_outputs_with_suffix({res}, context.get_name()); } diff --git a/ggml/src/ggml-openvino/openvino/op/transpose.cpp b/ggml/src/ggml-openvino/openvino/op/transpose.cpp index 8e62e83c0d7..b3b4614e440 100644 --- a/ggml/src/ggml-openvino/openvino/op/transpose.cpp +++ b/ggml/src/ggml-openvino/openvino/op/transpose.cpp @@ -12,8 +12,37 @@ namespace op { OutputVector translate_transpose(const NodeContext & context) { num_inputs_check(context, 1, 1); + // Compute permute order from input/output shape and stride information + // so it adapts to different input and output layouts. + auto input_shape = context.get_input_shape(0).to_shape(); + auto input_stride = context.get_input_stride(0); + auto output_shape = context.get_output_shape().to_shape(); + auto output_stride = context.get_output_stride(); + + // Compute permute order by matching output and input stride rankings. + // Build pairs. + std::vector> output_stride_dims; + std::vector> input_stride_dims; + + for (int i = 0; i < 4; ++i) { + output_stride_dims.push_back({output_stride[i], i}); + input_stride_dims.push_back({input_stride[i], i}); + } + + // Sort by stride in descending order. + std::sort(output_stride_dims.rbegin(), output_stride_dims.rend()); + std::sort(input_stride_dims.rbegin(), input_stride_dims.rend()); + + // Build permute order. + std::vector permute_order(4); + for (int i = 0; i < 4; ++i) { + int output_dim = output_stride_dims[i].second; + int input_dim = input_stride_dims[i].second; + permute_order[output_dim] = input_dim; + } + auto res = std::make_shared( - context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 1, 3, 2})); + context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {4}, permute_order)); return rename_outputs_with_suffix({res}, context.get_name()); } diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index 5c832fbe459..e7bd8820df0 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -502,6 +502,9 @@ bool is_model_splitted(ggml_cgraph * cgraph) { if ((cgraph->n_nodes <= 1 && use_count==0) || (cgraph->n_nodes <= 1 && node->op == GGML_OP_VIEW && use_count == 1 && node->src[0] != nullptr && node->src[0]->op == GGML_OP_NONE)) { return false; } + if (cgraph->n_nodes == 1 && (cgraph->nodes[0]->op == GGML_OP_TRANSPOSE || cgraph->nodes[0]->op == GGML_OP_PERMUTE)) { + return false; + } int input_use_count = 0; for (int j = 0; j < cgraph->n_nodes; j++) { ggml_tensor * other_node = cgraph->nodes[j];