Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 136 additions & 49 deletions onnx/src/ops/fft.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::model::{OnnxOpRegister, ParsingContext};
use crate::model::{optional_inputs, OnnxOpRegister, ParsingContext};
use crate::pb::NodeProto;
use tract_hir::internal::*;
use tract_hir::ops::array::Pad;
Expand All @@ -14,17 +14,21 @@ pub fn register_all_ops(reg: &mut OnnxOpRegister) {
reg.insert("HannWindow", window);
}

fn dft(
_ctx: &ParsingContext,
node: &NodeProto,
) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
let axis = node.get_attr_opt("axis")?.unwrap_or(1);
fn dft(ctx: &ParsingContext, node: &NodeProto) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
let inverse = node.get_attr_opt("inverse")?.unwrap_or(0i64) != 0;
let onesided = node.get_attr_opt("onesided")?.unwrap_or(0) != 0;
if node.input.len() > 1 {
bail!("length input is not implemented")
let mut optional_inputs = optional_inputs(node).skip(1);
let length_input = optional_inputs.next().unwrap();
let axis_input = optional_inputs.next().unwrap();
if ctx.onnx_operator_set_version < 20 {
let axis = node.get_attr_opt("axis")?.unwrap_or(1);
Ok((
expand(Dft17 { axis, inverse, onesided, has_length_input: node.input.len() == 2 }),
vec![],
))
} else {
Ok((expand(dbg!(Dft { axis_input, inverse, onesided, length_input })), vec![]))
}
Ok((expand(Dft { axis, inverse, onesided, has_length_input: node.input.len() == 2 }), vec![]))
}

fn stft(
Expand Down Expand Up @@ -67,18 +71,16 @@ fn window(
}

#[derive(Clone, Debug, Hash)]
struct Dft {
axis: usize,
struct Dft17 {
axis: i64,
inverse: bool,
onesided: bool,
has_length_input: bool,
}



impl Expansion for Dft {
impl Expansion for Dft17 {
fn name(&self) -> StaticName {
"DFT".into()
"Dft17".into()
}

fn rules<'r, 'p: 'r, 's: 'r>(
Expand All @@ -95,23 +97,14 @@ impl Expansion for Dft {
if self.has_length_input {
s.equals(&inputs[1].rank, 0)?;
}
s.given(&inputs[0].rank, |s, rank| {
for ax in 0..rank as usize - 1 {
if ax != self.axis {
s.equals(&inputs[0].shape[ax], &outputs[0].shape[ax])?;
}
}
s.equals(&outputs[0].shape[rank as usize - 1], 2.to_dim())?;
Ok(())
})?;
if self.has_length_input {
s.given(&inputs[1].value[0], |s, len| {
s.equals(len.to_dim(), &outputs[0].shape[self.axis])
})?;
let length_input = if self.has_length_input { Some(1) } else { None };
if self.axis >= 0 {
dft_rules(s, inputs, outputs, 1, length_input)
} else {
s.equals(&inputs[0].shape[self.axis], &outputs[0].shape[self.axis])?;
s.given(&inputs[0].rank, move |s, rank| {
dft_rules(s, inputs, outputs, (self.axis + rank) as usize, length_input)
})
}
Ok(())
}

fn wire(
Expand All @@ -121,7 +114,8 @@ impl Expansion for Dft {
inputs: &[OutletId],
) -> TractResult<TVec<OutletId>> {
let fact = model.outlet_fact(inputs[0])?.clone();
let mut wire: TVec<OutletId> = inputs.into();
let axis = if self.axis >= 0 { self.axis } else { self.axis + fact.rank() as i64 } as usize;
let mut wire = tvec!(inputs[0]);
if fact.shape.last() == Some(&1.to_dim()) {
let mut pads = vec![(0, 0); fact.rank() - 1];
pads.push((0, 1));
Expand All @@ -133,20 +127,20 @@ impl Expansion for Dft {
};
wire = model.wire_node(
format!("{prefix}.fft"),
tract_core::ops::fft::Fft { axis: self.axis, inverse: self.inverse },
tract_core::ops::fft::Fft { axis, inverse: self.inverse },
&wire,
)?;
if self.inverse {
let len = model.add_const(
format!("{prefix}.len"),
tensor0(fact.shape[self.axis].clone()).broadcast_into_rank(fact.rank())?,
tensor0(fact.shape[axis].clone()).broadcast_into_rank(fact.rank())?,
)?;
let casted =
model.wire_node(format!("{prefix}.cast"), cast(fact.datum_type), &[len])?;
wire = model.wire_node(format!("{prefix}.norm"), div(), &[wire[0], casted[0]])?;
}
if self.onesided {
let frame = fact.shape[self.axis].clone() / 2 + 1;
let frame = fact.shape[axis].clone() / 2 + 1;
wire = model.wire_node(
format!("{prefix}.onesided"),
tract_core::ops::array::Slice::new(2, 0, frame),
Expand All @@ -157,15 +151,111 @@ impl Expansion for Dft {
}
}

#[derive(Clone, Debug, Hash)]
struct Dft {
inverse: bool,
onesided: bool,
length_input: Option<usize>,
axis_input: Option<usize>,
}

impl Expansion for Dft {
fn name(&self) -> StaticName {
"Dft".into()
}

fn rules<'r, 'p: 'r, 's: 'r>(
&'s self,
s: &mut Solver<'r>,
inputs: &'p [TensorProxy],
outputs: &'p [TensorProxy],
) -> InferenceResult {
check_input_arity(
inputs,
1 + self.length_input.is_some() as usize + self.axis_input.is_some() as usize,
)?;
check_output_arity(outputs, 1)?;

s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
s.equals(&inputs[0].rank, &outputs[0].rank)?;
if let Some(len_input) = self.length_input {
s.equals(&inputs[len_input].rank, 0)?;
}
if let Some(axis_input) = self.axis_input {
s.given(&inputs[axis_input].value, move |s, axis| {
let axis = axis.cast_to_scalar::<i64>()?;
if axis >= 0 {
dft_rules(s, inputs, outputs, axis as usize, self.length_input)
} else {
s.given(&inputs[0].rank, move |s, rank| {
dft_rules(s, inputs, outputs, (axis + rank) as usize, self.length_input)
})
}
})
} else {
dft_rules(s, inputs, outputs, 1, self.length_input)
}
}

fn wire(
&self,
prefix: &str,
model: &mut TypedModel,
inputs: &[OutletId],
) -> TractResult<TVec<OutletId>> {
let axis = if let Some(axis) = self.axis_input {
model
.outlet_fact(inputs[axis])?
.konst
.as_ref()
.and_then(|k| k.cast_to_scalar::<i64>().ok())
.context("Axis input must be a known scalar")?
} else {
1
};
Dft17 {
axis,
inverse: self.inverse,
onesided: self.onesided,
has_length_input: self.length_input.is_some(),
}
.wire(prefix, model, inputs)
}
}

fn dft_rules<'r, 'p: 'r>(
s: &mut Solver<'r>,
inputs: &'p [TensorProxy],
outputs: &'p [TensorProxy],
axis: usize,
length_input: Option<usize>,
) -> InferenceResult {
s.given(&inputs[0].rank, move |s, rank| {
for ax in 0..rank as usize - 1 {
if ax != axis {
s.equals(&inputs[0].shape[ax], &outputs[0].shape[ax])?;
}
}
s.equals(&outputs[0].shape[rank as usize - 1], 2.to_dim())?;
Ok(())
})?;
if let Some(len_input) = length_input {
s.given(&inputs[1].value[len_input], move |s, len| {
s.equals(len.to_dim(), &outputs[0].shape[axis])
})?;
} else {
s.equals(&inputs[0].shape[axis], &outputs[0].shape[axis])?;
}
Ok(())
}

#[derive(Clone, Debug, Hash)]
struct Stft {
onesided: bool,
optional_window_input: Option<usize>,
optional_frame_length_input: Option<usize>,
}



impl Expansion for Stft {
fn name(&self) -> StaticName {
"STFT".into()
Expand Down Expand Up @@ -288,8 +378,6 @@ pub struct MelWeightMatrix {
datum_type: DatumType,
}



impl Expansion for MelWeightMatrix {
fn name(&self) -> StaticName {
"MelWeightMatrix".into()
Expand Down Expand Up @@ -329,15 +417,16 @@ impl Expansion for MelWeightMatrix {
Some(sample_rate),
Some(lower_edge_hertz),
Some(upper_edge_hertz),
) = (
model.outlet_fact(inputs[0])?.konst.as_ref(),
model.outlet_fact(inputs[1])?.konst.as_ref(),
model.outlet_fact(inputs[2])?.konst.as_ref(),
model.outlet_fact(inputs[3])?.konst.as_ref(),
model.outlet_fact(inputs[4])?.konst.as_ref(),
) else {
bail!("Expect all inputs to be constants")
};
) = (
model.outlet_fact(inputs[0])?.konst.as_ref(),
model.outlet_fact(inputs[1])?.konst.as_ref(),
model.outlet_fact(inputs[2])?.konst.as_ref(),
model.outlet_fact(inputs[3])?.konst.as_ref(),
model.outlet_fact(inputs[4])?.konst.as_ref(),
)
else {
bail!("Expect all inputs to be constants")
};
let num_mel_bins = num_mel_bins.cast_to_scalar::<i64>()? as usize;
let dft_length = dft_length.cast_to_scalar::<i64>()? as usize;
let sample_rate = sample_rate.cast_to_scalar::<i64>()? as usize;
Expand Down Expand Up @@ -423,8 +512,6 @@ pub struct StftWindow {
window: StftWindowType,
}



impl Expansion for StftWindow {
fn name(&self) -> StaticName {
format!("StftWindow<{:?}>", self.window).into()
Expand Down
4 changes: 1 addition & 3 deletions test-rt/suite-onnx/node.txt
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,7 @@ test_cumsum_2d_negative_axis input:x since:13
test_cumsum_2d not-nnef input:x since:13
test_depthtospace.*
test_dequantizelinear input:x not-nnef
# test_dft
# test_dft_axis
# test_dft_inverse
test_dft.* input:x
test_div.*
test_dropout_default
test_dropout_default_old
Expand Down
2 changes: 1 addition & 1 deletion test-rt/test-onnx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ onnx_1_16_2 = ["suite-onnx/onnx_1_16_2"]
onnx_1_17_0 = ["suite-onnx/onnx_1_17_0"]
onnx_1_18_0 = ["suite-onnx/onnx_1_18_0"]
onnx_1_19_1 = ["suite-onnx/onnx_1_19_1"]
default = [ "onnx_1_13_0" ]
default = [ "onnx_1_19_1" ]

[build-dependencies]
suite-onnx = { path = "../suite-onnx" }
Loading