diff --git a/onnx/src/model.rs b/onnx/src/model.rs index 668b82b9b4..93d227175c 100644 --- a/onnx/src/model.rs +++ b/onnx/src/model.rs @@ -1,5 +1,5 @@ use std::path::PathBuf; -use std::{fs, path}; +use std::{ fs, path}; use std::collections::HashMap; @@ -8,7 +8,7 @@ use tract_hir::prelude::tract_itertools::Itertools; use crate::data_resolver::{self, ModelDataResolver}; use crate::pb::type_proto::Value; -use crate::pb::{self, TensorProto, TypeProto}; +use crate::pb::{self, GraphProto, OperatorSetIdProto, TensorProto, TypeProto}; use crate::tensor::{load_tensor, translate_inference_fact}; use prost::Message; @@ -213,6 +213,60 @@ impl OnnxOpRegister { } } +#[derive(Debug, Clone)] +pub struct OnnxMetadata { + pub ir_version: i64, + pub opset_import: Vec, + pub producer_name: String, + pub producer_version: String, + pub domain: String, + pub model_version: i64, + pub doc_string: String, + pub graph: Option, + pub metadata_props: HashMap +} + + +impl OnnxMetadata { + pub fn get_metadata(model_proto: &pb::ModelProto ) -> TractResult { + let parse_metadata_props: HashMap = model_proto.to_owned().metadata_props + .into_iter().map(|entry| (entry.key, entry.value)).collect(); + Ok(OnnxMetadata { + ir_version: model_proto.ir_version, + opset_import: model_proto.clone().opset_import, + producer_name: model_proto.clone().producer_name, + producer_version: model_proto.clone().producer_version, + domain: model_proto.clone().domain, + model_version: model_proto.model_version, + doc_string: model_proto.clone().doc_string, + graph: match model_proto.graph { + Some(_) => model_proto.clone().graph , + None => None + }, + metadata_props: parse_metadata_props + }) + } +} + +impl Default for OnnxMetadata { + fn default() -> Self { + let _opset = OperatorSetIdProto::default(); + let graph = GraphProto::default(); + OnnxMetadata { + ir_version: 0, + opset_import: vec![_opset], + producer_name: String::from(""), + producer_version: String::from(""), + domain: String::from(""), + model_version: 0, + doc_string: String::from(""), + graph: Some(graph), + metadata_props: HashMap::new() + } + } +} + + #[derive(Clone)] pub struct Onnx { pub op_register: OnnxOpRegister, @@ -236,6 +290,23 @@ impl Onnx { pub fn parse(&self, proto: &pb::ModelProto, path: Option<&str>) -> TractResult { self.parse_with_template(proto, path, Default::default()) } + + pub fn load_model_with_metadata(&mut self, model_path: impl AsRef) -> TractResult<(InferenceModel, OnnxMetadata)>{ + let mut path = PathBuf::new(); + path.push(&model_path); + let proto = self.proto_model_for_path(&model_path)?; + let mut dir: Option<&str> = None; + if let Some(dir_opt) = path.parent() { + dir = dir_opt.to_str(); + } + let ParseResult { model, unresolved_inputs, .. } = self.parse(&proto, dir)?; + if unresolved_inputs.len() > 0 { + bail!("Could not resolve inputs at top-level: {:?}", unresolved_inputs) + } + let _metadata = OnnxMetadata::get_metadata(&proto)?; + Ok((model, _metadata)) + } + pub fn parse_with_template( &self, proto: &pb::ModelProto, @@ -248,8 +319,8 @@ impl Onnx { .find(|import| import.domain.is_empty() || import.domain == "ai.onnx") .map(|op| op.version) .unwrap_or(0); - let graph = - proto.graph.as_ref().ok_or_else(|| anyhow!("model proto does not contain a graph"))?; + // self.metadata = OnnxMetadata::get_metadata(&proto)?; + let graph = proto.graph.as_ref().ok_or_else(|| anyhow!("model proto does not contain a graph"))?; debug!("ONNX operator set version: {:?}", onnx_operator_set_version); if onnx_operator_set_version != 0 && !(9..19).contains(&onnx_operator_set_version) { warn!("ONNX operator for your model is {}, tract is only tested against \ @@ -267,7 +338,6 @@ impl Onnx { trace!("created ParsingContext"); ctx.parse_graph(graph) } - pub fn with_ignore_output_shapes(self, ignore: bool) -> Onnx { Self { use_output_shapes: !ignore, ..self } }