Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions ggml/src/ggml-openvino/ggml-decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,16 +166,6 @@ int GgmlOvDecoder::compute_op_case(const ggml_tensor * node) const {
}
break;
}
case GGML_OP_CONT: {
if (node->src[0]->op == GGML_OP_PERMUTE) {
op_case = 1;
} else if (node->src[0]->op == GGML_OP_TRANSPOSE) {
op_case = 2;
} else if (node->src[0]->op == GGML_OP_VIEW) {
op_case = 3;
}
break;
}
case GGML_OP_PERMUTE: {
if (node->src[0]->op != GGML_OP_VIEW) {
op_case = 1;
Expand All @@ -195,9 +185,7 @@ int GgmlOvDecoder::compute_op_case(const ggml_tensor * node) const {
break;
}
case GGML_OP_MUL_MAT: {
if (node->src[0]->op == GGML_OP_CONT && node->src[0]->src[0]->op == GGML_OP_TRANSPOSE) {
op_case = 2;
} else if (node->src[0]->op == GGML_OP_VIEW && node->src[1]->op == GGML_OP_VIEW) {
if (node->src[0]->op == GGML_OP_VIEW && node->src[1]->op == GGML_OP_VIEW) {
op_case = 3;
}
break;
Expand Down Expand Up @@ -314,6 +302,14 @@ std::pair<ModelParams, ComputeParams> GgmlOvDecoder::compute_llm_params(ggml_cgr
}
break;
}
// If the node is a TRANSPOSE whose input is a PERMUTE of a VIEW, take the attention size from the TRANSPOSE node's ne[0] (used when the graph has no GGML_OP_FLASH_ATTN_EXT).
if (node->op == GGML_OP_TRANSPOSE && node->src[0]->op == GGML_OP_PERMUTE &&
node->src[0]->src[0]->op == GGML_OP_VIEW) {
compute_params.attention_size = node->ne[0];
if (is_static) {
compute_params.attention_size = model_params.ctx_per_seq;
}
}
if (node->op == GGML_OP_ROPE) {
memcpy(model_params.rope_params, node->op_params, sizeof(int32_t) * 15);
}
Expand Down Expand Up @@ -880,6 +876,11 @@ ov::element::Type GgmlOvDecoder::get_output_type(const int node_idx) const {
return get_ov_type(m_node_info_list[node_idx].node);
}

// Returns the stride vector that get_stride() computes for the ggml tensor
// of node `node_idx`.
std::vector<size_t> GgmlOvDecoder::get_output_stride(int node_idx) const {
    return get_stride(m_node_info_list[node_idx].node);
}

std::vector<std::string> GgmlOvDecoder::get_output_names(int node_idx) const {
return {m_node_info_list[node_idx].node_output_name};
}
Expand All @@ -889,6 +890,14 @@ const std::string & GgmlOvDecoder::get_op_name() const {
return unknown_name;
}

// Looks up the dynamic dimension recorded in m_node_dynamic_dims for the
// tensor of node `node_idx`. Returns -1 when the node has no entry.
int32_t GgmlOvDecoder::get_op_dynamic_dim(int node_idx) const {
    const auto entry = m_node_dynamic_dims.find(m_node_info_list[node_idx].node);
    if (entry != m_node_dynamic_dims.end()) {
        return entry->second;
    }
    return -1;
}

// Returns a reference to the name stored for node `node_idx` in
// m_node_info_list.
const std::string & GgmlOvDecoder::get_op_name(int node_idx) const {
    const auto & info = m_node_info_list[node_idx];
    return info.node_name;
}
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-openvino/ggml-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {

virtual ov::element::Type get_output_type(int node_idx) const override;

virtual std::vector<size_t> get_output_stride(int node_idx) const override;

virtual int32_t * get_input_op_params(int node_idx, const std::string & name) const override;

virtual int32_t * get_output_op_params(int node_idx) const override;
Expand All @@ -121,6 +123,8 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {

virtual const std::string & get_op_name(int node_idx) const override;

virtual int32_t get_op_dynamic_dim(int node_idx) const override;

virtual void visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>, int node_idx)> node_visitor) const override;

// Returns the ggml tensor registered under `name` in m_inputs.
// NOTE(review): the lookup uses m_inputs.at(); presumably m_inputs is a
// standard map, in which case an unknown name throws std::out_of_range — confirm.
ggml_tensor * get_input_ggml_tensor(const std::string & name) const {
    return m_inputs.at(name);
}
Expand Down
10 changes: 9 additions & 1 deletion ggml/src/ggml-openvino/ggml-openvino.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,14 @@ static bool is_op_unsupported_case(const ggml_tensor * op) {
}
break;
}
case GGML_OP_TRANSPOSE: {
// if the type is bf16, will return true
if (op->type == GGML_TYPE_BF16) {
// GGML_LOG_WARN("OpenVINO backend does not support TRANSPOSE with BF16 type\n");
return true;
}
break;
}
default:
break;
}
Expand All @@ -933,7 +941,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con
GGML_TYPE_Q5_K, GGML_TYPE_Q8_0, GGML_TYPE_Q6_K};

static const std::set<ggml_op> supported_ops{GGML_OP_NONE, GGML_OP_ADD, GGML_OP_MUL, GGML_OP_MUL_MAT, GGML_OP_VIEW,
/*GGML_OP_CONT,*/ GGML_OP_RESHAPE, GGML_OP_PERMUTE, GGML_OP_TRANSPOSE,
GGML_OP_CONT, GGML_OP_RESHAPE, GGML_OP_PERMUTE, GGML_OP_TRANSPOSE,
GGML_OP_GET_ROWS, GGML_OP_ROPE, GGML_OP_RMS_NORM, GGML_OP_SCALE,
// softmax is not updated due to replaced by flash_attn_ext
// GGML_OP_SOFT_MAX,
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-openvino/openvino/decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class GgmlDecoder : public DecoderBase {

virtual element::Type get_output_type(const int node_idx) const = 0;

// Returns the stride vector of the output tensor of node `node_idx`.
// NOTE(review): units (bytes vs. elements) and dimension order are not
// visible from this declaration — confirm against the implementation.
virtual std::vector<size_t> get_output_stride(int node_idx) const = 0;

virtual int32_t* get_input_op_params(int node_idx, const std::string& name) const = 0;

virtual int32_t * get_output_op_params(int node_idx) const = 0;
Expand Down Expand Up @@ -69,6 +71,8 @@ class GgmlDecoder : public DecoderBase {
virtual bool is_splited_model() const = 0;

virtual int is_swa_layer(int layer) const = 0;

// Returns the dynamic dimension index associated with node `node_idx`.
// Callers compare the result against -1 to detect "no dynamic dimension".
virtual int32_t get_op_dynamic_dim(int node_idx) const = 0;
};

} // namespace ggml
Expand Down
8 changes: 8 additions & 0 deletions ggml/src/ggml-openvino/openvino/node_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,20 @@ class NodeContext : public frontend::NodeContext {
return m_decoder->get_input_op_params(m_node_idx, m_input_names[index]);
}

// Forwards to the decoder's get_op_dynamic_dim() for this context's node.
int32_t get_op_dynamic_dim() const { return m_decoder->get_op_dynamic_dim(m_node_idx); }

// Forwards to the decoder's get_output_op_params() for this context's node.
int32_t * get_output_op_params() const {
    return m_decoder->get_output_op_params(m_node_idx);
}

// Forwards to the decoder's get_output_type() for this context's node.
ov::element::Type get_output_type() const { return m_decoder->get_output_type(m_node_idx); }

// Forwards to the decoder's get_output_stride() for this context's node.
std::vector<size_t> get_output_stride() const { return m_decoder->get_output_stride(m_node_idx); }

// Resolves the idx-th input name of this node and returns the matching
// entry from the tensor map (looked up with at()).
Output<Node> get_input(int idx) const override {
    const auto & input_name = m_input_names[idx];
    return m_tensor_map->at(input_name);
}
Expand Down
22 changes: 6 additions & 16 deletions ggml/src/ggml-openvino/openvino/op/cont.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,17 @@ namespace op {
// Translates a GGML CONT node into OpenVINO ops and returns the renamed outputs.
//
// NOTE(review): as listed here the body mixes two variants of the function
// (an op_case-based dispatch and a single Reshape path): `res` is declared
// twice and the `else` opened below is never closed. Confirm which lines are
// current before relying on this listing.
OutputVector translate_cont(const NodeContext & context) {
// Check the node's input count against the bounds (1, 1).
num_inputs_check(context, 1, 1);

// Variant A: dispatch on the op case computed by the decoder; only 1, 2 and 3 are accepted.
int op_case = context.get_op_case();
FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3, "Unsupported CONT case");

// NOTE(review): src_shape is not referenced in the lines visible here.
auto src_shape = context.get_input_shape(0).to_shape();
auto dst_shape = context.get_output_shape().to_shape();
ov::Output<Node> res;

if (op_case == 1) {
// The input comes from a PERMUTE
// The throw below makes the remaining statements of this branch unreachable.
throw std::runtime_error("Code of this case might be outdated");
dst_shape[1] = -1;
res = std::make_shared<ov::op::v1::Reshape>(
context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {dst_shape.size()}, dst_shape), false);
} else if (op_case == 2) {
// The input comes from a TRANSPOSE
// The input is returned as-is, without going through rename_outputs_with_suffix.
return {context.get_input(0)};
} else {
// The input comes from a VIEW
res = process_view_input(context, 0);
// Variant B: when the decoder reports a dynamic dimension (anything other than -1),
// put -1 at index (3 - dim) of the target shape so that Reshape infers that dimension.
// NOTE(review): "3 - dim" presumably maps a ggml dimension index onto the reversed
// 4-D OpenVINO order — confirm. If to_shape() yields an unsigned element type, the -1
// relies on wraparound when the pattern is converted to i64 — confirm.
if (context.get_op_dynamic_dim() != -1) {
dst_shape[3 - context.get_op_dynamic_dim()] = -1;
}

ov::Output<Node> res;
// Reshape the input to dst_shape; the trailing `false` is Reshape's special_zero argument.
res = std::make_shared<ov::op::v1::Reshape>(
context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {dst_shape.size()}, dst_shape), false);

return rename_outputs_with_suffix({res}, context.get_name());
}

Expand Down
31 changes: 30 additions & 1 deletion ggml/src/ggml-openvino/openvino/op/transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,37 @@ namespace op {
// Translates a GGML TRANSPOSE node into an OpenVINO Transpose and returns the
// renamed output. The axis order is derived by ranking the input and output
// strides rather than being fixed.
OutputVector translate_transpose(const NodeContext & context) {
// Check the node's input count against the bounds (1, 1).
num_inputs_check(context, 1, 1);

// Compute permute order from input/output shape and stride information
// so it adapts to different input and output layouts.
// NOTE(review): input_shape and output_shape are not referenced in the lines
// visible here; only the strides are used.
auto input_shape = context.get_input_shape(0).to_shape();
auto input_stride = context.get_input_stride(0);
auto output_shape = context.get_output_shape().to_shape();
auto output_stride = context.get_output_stride();

// Compute permute order by matching output and input stride rankings.
// Build <stride, dim_index> pairs.
std::vector<std::pair<size_t, int>> output_stride_dims;
std::vector<std::pair<size_t, int>> input_stride_dims;

// NOTE(review): the fixed bound of 4 assumes both stride vectors hold at least
// four entries — confirm for all callers.
for (int i = 0; i < 4; ++i) {
output_stride_dims.push_back({output_stride[i], i});
input_stride_dims.push_back({input_stride[i], i});
}

// Sort by stride in descending order.
// Sorting through reverse iterators leaves each vector in descending order;
// pairs with equal strides are ordered by their dim index, also descending.
std::sort(output_stride_dims.rbegin(), output_stride_dims.rend());
std::sort(input_stride_dims.rbegin(), input_stride_dims.rend());

// Build permute order.
// The output dim and input dim that share the same stride rank are paired:
// permute_order[output_dim] names the input dim that output_dim takes its data from.
std::vector<int64_t> permute_order(4);
for (int i = 0; i < 4; ++i) {
int output_dim = output_stride_dims[i].second;
int input_dim = input_stride_dims[i].second;
permute_order[output_dim] = input_dim;
}

// NOTE(review): the two argument lines below look like the fixed-order ({0, 1, 3, 2})
// and computed-order (permute_order) versions of the same call; only one of them can
// be present in compilable code.
auto res = std::make_shared<ov::op::v1::Transpose>(
context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 1, 3, 2}));
context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {4}, permute_order));
return rename_outputs_with_suffix({res}, context.get_name());
}

Expand Down
3 changes: 3 additions & 0 deletions ggml/src/ggml-openvino/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,9 @@ bool is_model_splitted(ggml_cgraph * cgraph) {
if ((cgraph->n_nodes <= 1 && use_count==0) || (cgraph->n_nodes <= 1 && node->op == GGML_OP_VIEW && use_count == 1 && node->src[0] != nullptr && node->src[0]->op == GGML_OP_NONE)) {
return false;
}
if (cgraph->n_nodes == 1 && (cgraph->nodes[0]->op == GGML_OP_TRANSPOSE || cgraph->nodes[0]->op == GGML_OP_PERMUTE)) {
return false;
}
int input_use_count = 0;
for (int j = 0; j < cgraph->n_nodes; j++) {
ggml_tensor * other_node = cgraph->nodes[j];
Expand Down