From 156d15624cf0851d1b7d17ee77baa19be2f2a3e1 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 11 Nov 2025 09:55:40 +0100 Subject: [PATCH 1/2] fixes for new dft syntax --- onnx/src/ops/fft.rs | 167 ++++++++++++++++++++++++++++++++------------ 1 file changed, 121 insertions(+), 46 deletions(-) diff --git a/onnx/src/ops/fft.rs b/onnx/src/ops/fft.rs index 674bef5034..eaecac0208 100644 --- a/onnx/src/ops/fft.rs +++ b/onnx/src/ops/fft.rs @@ -1,4 +1,4 @@ -use crate::model::{OnnxOpRegister, ParsingContext}; +use crate::model::{optional_inputs, OnnxOpRegister, ParsingContext}; use crate::pb::NodeProto; use tract_hir::internal::*; use tract_hir::ops::array::Pad; @@ -14,17 +14,21 @@ pub fn register_all_ops(reg: &mut OnnxOpRegister) { reg.insert("HannWindow", window); } -fn dft( - _ctx: &ParsingContext, - node: &NodeProto, -) -> TractResult<(Box, Vec)> { - let axis = node.get_attr_opt("axis")?.unwrap_or(1); +fn dft(ctx: &ParsingContext, node: &NodeProto) -> TractResult<(Box, Vec)> { let inverse = node.get_attr_opt("inverse")?.unwrap_or(0i64) != 0; let onesided = node.get_attr_opt("onesided")?.unwrap_or(0) != 0; - if node.input.len() > 1 { - bail!("length input is not implemented") + let mut optional_inputs = optional_inputs(node).skip(1); + let length_input = optional_inputs.next().unwrap(); + let axis_input = optional_inputs.next().unwrap(); + if ctx.onnx_operator_set_version < 20 { + let axis = node.get_attr_opt("axis")?.unwrap_or(1); + Ok(( + expand(Dft17 { axis, inverse, onesided, has_length_input: node.input.len() == 2 }), + vec![], + )) + } else { + Ok((expand(dbg!(Dft { axis_input, inverse, onesided, length_input })), vec![])) } - Ok((expand(Dft { axis, inverse, onesided, has_length_input: node.input.len() == 2 }), vec![])) } fn stft( @@ -67,18 +71,16 @@ fn window( } #[derive(Clone, Debug, Hash)] -struct Dft { +struct Dft17 { axis: usize, inverse: bool, onesided: bool, has_length_input: bool, } - - -impl Expansion for Dft { +impl Expansion for Dft17 { fn name(&self) -> StaticName { - "DFT".into() + "DFT17".into() } fn rules<'r, 'p: 'r, 's: 'r>( @@ -95,22 +97,7 @@ impl Expansion for Dft { if self.has_length_input { s.equals(&inputs[1].rank, 0)?; } - s.given(&inputs[0].rank, |s, rank| { - for ax in 0..rank as usize - 1 { - if ax != self.axis { - s.equals(&inputs[0].shape[ax], &outputs[0].shape[ax])?; - } - } - s.equals(&outputs[0].shape[rank as usize - 1], 2.to_dim())?; - Ok(()) - })?; - if self.has_length_input { - s.given(&inputs[1].value[0], |s, len| { - s.equals(len.to_dim(), &outputs[0].shape[self.axis]) - })?; - } else { - s.equals(&inputs[0].shape[self.axis], &outputs[0].shape[self.axis])?; - } + dft_rules(s, inputs, outputs, 1, if self.has_length_input { Some(1) } else { None })?; Ok(()) } @@ -121,7 +108,7 @@ impl Expansion for Dft { inputs: &[OutletId], ) -> TractResult> { let fact = model.outlet_fact(inputs[0])?.clone(); - let mut wire: TVec = inputs.into(); + let mut wire = tvec!(inputs[0]); if fact.shape.last() == Some(&1.to_dim()) { let mut pads = vec![(0, 0); fact.rank() - 1]; pads.push((0, 1)); @@ -157,6 +144,99 @@ impl Expansion for Dft { } } +#[derive(Clone, Debug, Hash)] +struct Dft { + inverse: bool, + onesided: bool, + length_input: Option, + axis_input: Option, +} + +impl Expansion for Dft { + fn name(&self) -> StaticName { + "DFT".into() + } + + fn rules<'r, 'p: 'r, 's: 'r>( + &'s self, + s: &mut Solver<'r>, + inputs: &'p [TensorProxy], + outputs: &'p [TensorProxy], + ) -> InferenceResult { + check_input_arity( + inputs, + 1 + self.length_input.is_some() as usize + self.axis_input.is_some() as usize, + )?; + check_output_arity(outputs, 1)?; + + s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?; + s.equals(&inputs[0].rank, &outputs[0].rank)?; + if let Some(len_input) = self.length_input { + s.equals(&inputs[len_input].rank, 0)?; + } + if let Some(axis_input) = self.axis_input { + s.given(&inputs[axis_input].value, |s, axis| { + let axis = axis.cast_to_scalar::()? as usize; + dft_rules(s, inputs, outputs, axis, self.length_input) + })?; + } else { + dft_rules(s, inputs, outputs, 1, self.length_input)?; + } + Ok(()) + } + + fn wire( + &self, + prefix: &str, + model: &mut TypedModel, + inputs: &[OutletId], + ) -> TractResult> { + let axis = if let Some(axis) = self.axis_input { + model + .outlet_fact(inputs[axis])? + .konst + .as_ref() + .and_then(|k| k.cast_to_scalar::().ok()) + .context("Axis input must be a known scalar")? as usize + } else { + 1 + }; + Dft17 { + axis, + inverse: self.inverse, + onesided: self.onesided, + has_length_input: self.length_input.is_some(), + } + .wire(prefix, model, inputs) + } +} + +fn dft_rules<'r, 'p: 'r>( + s: &mut Solver<'r>, + inputs: &'p [TensorProxy], + outputs: &'p [TensorProxy], + axis: usize, + length_input: Option, +) -> InferenceResult { + s.given(&inputs[0].rank, move |s, rank| { + for ax in 0..rank as usize - 1 { + if ax != axis { + s.equals(&inputs[0].shape[ax], &outputs[0].shape[ax])?; + } + } + s.equals(&outputs[0].shape[rank as usize - 1], 2.to_dim())?; + Ok(()) + })?; + if let Some(len_input) = length_input { + s.given(&inputs[1].value[len_input], move |s, len| { + s.equals(len.to_dim(), &outputs[0].shape[axis]) + })?; + } else { + s.equals(&inputs[0].shape[axis], &outputs[0].shape[axis])?; + } + Ok(()) +} + #[derive(Clone, Debug, Hash)] struct Stft { onesided: bool, @@ -164,8 +244,6 @@ struct Stft { optional_frame_length_input: Option, } - - impl Expansion for Stft { fn name(&self) -> StaticName { "STFT".into() @@ -288,8 +366,6 @@ pub struct MelWeightMatrix { datum_type: DatumType, } - - impl Expansion for MelWeightMatrix { fn name(&self) -> StaticName { "MelWeightMatrix".into() @@ -329,15 +405,16 @@ impl Expansion for MelWeightMatrix { Some(sample_rate), Some(lower_edge_hertz), Some(upper_edge_hertz), - ) = ( - model.outlet_fact(inputs[0])?.konst.as_ref(), - model.outlet_fact(inputs[1])?.konst.as_ref(), - model.outlet_fact(inputs[2])?.konst.as_ref(), - model.outlet_fact(inputs[3])?.konst.as_ref(), - model.outlet_fact(inputs[4])?.konst.as_ref(), - ) else { - bail!("Expect all inputs to be constants") - }; + ) = ( + model.outlet_fact(inputs[0])?.konst.as_ref(), + model.outlet_fact(inputs[1])?.konst.as_ref(), + model.outlet_fact(inputs[2])?.konst.as_ref(), + model.outlet_fact(inputs[3])?.konst.as_ref(), + model.outlet_fact(inputs[4])?.konst.as_ref(), + ) + else { + bail!("Expect all inputs to be constants") + }; let num_mel_bins = num_mel_bins.cast_to_scalar::()? as usize; let dft_length = dft_length.cast_to_scalar::()? as usize; let sample_rate = sample_rate.cast_to_scalar::()? as usize; @@ -423,8 +500,6 @@ pub struct StftWindow { window: StftWindowType, } - - impl Expansion for StftWindow { fn name(&self) -> StaticName { format!("StftWindow<{:?}>", self.window).into() From 109fa18fcb83d18bca556c397646efa73593232c Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 11 Nov 2025 09:59:58 +0100 Subject: [PATCH 2/2] make both dft variants work --- onnx/src/ops/fft.rs | 42 ++++++++++++++++++++----------- test-rt/suite-onnx/node.txt | 4 +-- test-rt/test-onnx-core/Cargo.toml | 2 +- 3 files changed, 29 insertions(+), 19 deletions(-) diff --git a/onnx/src/ops/fft.rs b/onnx/src/ops/fft.rs index eaecac0208..6ff6ea0a75 100644 --- a/onnx/src/ops/fft.rs +++ b/onnx/src/ops/fft.rs @@ -72,7 +72,7 @@ fn window( #[derive(Clone, Debug, Hash)] struct Dft17 { - axis: usize, + axis: i64, inverse: bool, onesided: bool, has_length_input: bool, @@ -80,7 +80,7 @@ struct Dft17 { impl Expansion for Dft17 { fn name(&self) -> StaticName { - "DFT17".into() + "Dft17".into() } fn rules<'r, 'p: 'r, 's: 'r>( @@ -97,8 +97,14 @@ impl Expansion for Dft17 { if self.has_length_input { s.equals(&inputs[1].rank, 0)?; } - dft_rules(s, inputs, outputs, 1, if self.has_length_input { Some(1) } else { None })?; - Ok(()) + let length_input = if self.has_length_input { Some(1) } else { None }; + if self.axis >= 0 { + dft_rules(s, inputs, outputs, 1, length_input) + } else { + s.given(&inputs[0].rank, move |s, rank| { + dft_rules(s, inputs, outputs, (self.axis + rank) as usize, length_input) + }) + } } fn wire( @@ -108,6 +114,7 @@ impl Expansion for Dft17 { inputs: &[OutletId], ) -> TractResult> { let fact = model.outlet_fact(inputs[0])?.clone(); + let axis = if self.axis >= 0 { self.axis } else { self.axis + fact.rank() as i64 } as usize; let mut wire = tvec!(inputs[0]); if fact.shape.last() == Some(&1.to_dim()) { let mut pads = vec![(0, 0); fact.rank() - 1]; @@ -120,20 +127,20 @@ impl Expansion for Dft17 { }; wire = model.wire_node( format!("{prefix}.fft"), - tract_core::ops::fft::Fft { axis: self.axis, inverse: self.inverse }, + tract_core::ops::fft::Fft { axis, inverse: self.inverse }, &wire, )?; if self.inverse { let len = model.add_const( format!("{prefix}.len"), - tensor0(fact.shape[self.axis].clone()).broadcast_into_rank(fact.rank())?, + tensor0(fact.shape[axis].clone()).broadcast_into_rank(fact.rank())?, )?; let casted = model.wire_node(format!("{prefix}.cast"), cast(fact.datum_type), &[len])?; wire = model.wire_node(format!("{prefix}.norm"), div(), &[wire[0], casted[0]])?; } if self.onesided { - let frame = fact.shape[self.axis].clone() / 2 + 1; + let frame = fact.shape[axis].clone() / 2 + 1; wire = model.wire_node( format!("{prefix}.onesided"), tract_core::ops::array::Slice::new(2, 0, frame), @@ -154,7 +161,7 @@ struct Dft { impl Expansion for Dft { fn name(&self) -> StaticName { - "DFT".into() + "Dft".into() } fn rules<'r, 'p: 'r, 's: 'r>( @@ -175,14 +182,19 @@ impl Expansion for Dft { s.equals(&inputs[len_input].rank, 0)?; } if let Some(axis_input) = self.axis_input { - s.given(&inputs[axis_input].value, |s, axis| { - let axis = axis.cast_to_scalar::()? as usize; - dft_rules(s, inputs, outputs, axis, self.length_input) - })?; + s.given(&inputs[axis_input].value, move |s, axis| { + let axis = axis.cast_to_scalar::()?; + if axis >= 0 { + dft_rules(s, inputs, outputs, axis as usize, self.length_input) + } else { + s.given(&inputs[0].rank, move |s, rank| { + dft_rules(s, inputs, outputs, (axis + rank) as usize, self.length_input) + }) + } + }) } else { - dft_rules(s, inputs, outputs, 1, self.length_input)?; + dft_rules(s, inputs, outputs, 1, self.length_input) } - Ok(()) } fn wire( @@ -197,7 +209,7 @@ impl Expansion for Dft { .konst .as_ref() .and_then(|k| k.cast_to_scalar::().ok()) - .context("Axis input must be a known scalar")? as usize + .context("Axis input must be a known scalar")? } else { 1 }; diff --git a/test-rt/suite-onnx/node.txt b/test-rt/suite-onnx/node.txt index af625195ab..8f420016b4 100644 --- a/test-rt/suite-onnx/node.txt +++ b/test-rt/suite-onnx/node.txt @@ -134,9 +134,7 @@ test_cumsum_2d_negative_axis input:x since:13 test_cumsum_2d not-nnef input:x since:13 test_depthtospace.* test_dequantizelinear input:x not-nnef -# test_dft -# test_dft_axis -# test_dft_inverse +test_dft.* input:x test_div.* test_dropout_default test_dropout_default_old diff --git a/test-rt/test-onnx-core/Cargo.toml b/test-rt/test-onnx-core/Cargo.toml index 5185c079c1..458d938336 100644 --- a/test-rt/test-onnx-core/Cargo.toml +++ b/test-rt/test-onnx-core/Cargo.toml @@ -29,7 +29,7 @@ onnx_1_16_2 = ["suite-onnx/onnx_1_16_2"] onnx_1_17_0 = ["suite-onnx/onnx_1_17_0"] onnx_1_18_0 = ["suite-onnx/onnx_1_18_0"] onnx_1_19_1 = ["suite-onnx/onnx_1_19_1"] -default = [ "onnx_1_13_0" ] +default = [ "onnx_1_19_1" ] [build-dependencies] suite-onnx = { path = "../suite-onnx" }