From 5c62f57db441a8c90d6cdeed3df1e152797930c4 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 2 Mar 2026 15:56:19 +0000 Subject: [PATCH 01/20] nemo-parakeet-asr: extract TdtModel struct with load and transcribe Moves model loading and inference logic out of main into a TdtModel struct. The load constructor takes impl AsRef for the model directory, and transcribe encapsulates the full preprocessing/encoding/ TDT decoding loop. --- examples/nemo-parakeet-asr/src/main.rs | 133 +++++++++++++++---------- 1 file changed, 82 insertions(+), 51 deletions(-) diff --git a/examples/nemo-parakeet-asr/src/main.rs b/examples/nemo-parakeet-asr/src/main.rs index 82ed8a55c4..4ea43a3797 100644 --- a/examples/nemo-parakeet-asr/src/main.rs +++ b/examples/nemo-parakeet-asr/src/main.rs @@ -1,79 +1,110 @@ use std::fs::File; +use std::path::Path; use anyhow::*; use float_ord::FloatOrd; use itertools::Itertools; use tract_rs::prelude::tract_ndarray::prelude::*; use tract_rs::prelude::*; +use tract_rs::Nnef; fn argmax(slice: &[f32]) -> Option { slice.into_iter().position_max_by_key(|x| FloatOrd(**x)) } -fn main() -> anyhow::Result<()> { - let config: serde_json::Value = - serde_json::from_reader(File::open("assets/model/model_config.json")?)?; - let blank_id = config.pointer("/decoder/vocab_size").unwrap().as_i64().unwrap() as usize; - let vocab = config.pointer("/joint/vocabulary").unwrap().as_array().unwrap(); - let vocab: Vec<&str> = vocab.iter().map(|v| v.as_str().unwrap()).collect(); +struct TdtModel { + preprocessor: Runnable, + encoder: Runnable, + decoder: Runnable, + joint: Runnable, + vocab: Vec, + blank_id: usize, +} + +impl TdtModel { + fn load(model_dir: impl AsRef, nnef: &Nnef, gpu: &Runtime) -> Result { + let model_dir = model_dir.as_ref(); + let config: serde_json::Value = + serde_json::from_reader(File::open(model_dir.join("model_config.json"))?)?; + let blank_id = + config.pointer("/decoder/vocab_size").unwrap().as_i64().unwrap() as usize; + let vocab = config.pointer("/joint/vocabulary").unwrap().as_array().unwrap(); + let vocab: Vec = vocab.iter().map(|v| v.as_str().unwrap().to_owned()).collect(); + + let preprocessor = + nnef.load(model_dir.join("preprocessor.nnef.tgz"))?.into_runnable()?; + + let mut encoder = nnef.load(model_dir.join("encoder.nnef.tgz"))?; + encoder.transform("transformers-detect-all")?; + let encoder = gpu.prepare(encoder)?; + + let decoder = nnef.load(model_dir.join("decoder.nnef.tgz"))?; + let decoder = gpu.prepare(decoder)?; + + let joint = nnef.load(model_dir.join("joint.nnef.tgz"))?; + let joint = gpu.prepare(joint)?; + + Ok(TdtModel { preprocessor, encoder, decoder, joint, vocab, blank_id }) + } + + fn transcribe(&self, wav: &[f32]) -> Result { + let samples: Value = Value::from_slice(&[1, wav.len()], wav)?; + let len: Value = arr1(&[wav.len() as i64]).try_into()?; + + let [features, feat_len] = + self.preprocessor.run([samples, len])?.try_into().unwrap(); + let [encoded, _lens] = + self.encoder.run([features, feat_len])?.try_into().unwrap(); + + let encoded: ArrayD = encoded.view()?.into_owned(); + + let max_frames = encoded.shape()[2]; + let max_len = max_frames * 6 + 10; + let mut hyp = vec![]; + let mut frame_ix = 0; + let mut token = Value::from_slice(&[1, 1], &[0i32])?; + let mut state_0: Value = Array3::::zeros([2, 1, 640]).try_into()?; + let mut state_1: Value = Array3::::zeros([2, 1, 640]).try_into()?; + + [token, state_0, state_1] = + self.decoder.run([token, state_0, state_1])?.try_into().unwrap(); + while hyp.len() < max_len && frame_ix < max_frames { + let frame: Value = + encoded.slice_axis(Axis(2), (frame_ix..frame_ix + 1).into()).try_into()?; + let [logits] = self.joint.run([frame, token.clone()])?.try_into().unwrap(); + let logits = logits.view::()?; + let logits = logits.as_slice().unwrap(); + let token_id = argmax(&logits[0..self.blank_id + 1]).unwrap(); + if token_id == self.blank_id { + frame_ix += argmax(&logits[self.blank_id + 1..]).unwrap_or(0).max(1); + } else { + hyp.push(token_id); + token = Value::from_slice(&[1, 1], &[token_id as i32])?; + [token, state_0, state_1] = + self.decoder.run([token, state_0, state_1])?.try_into().unwrap(); + } + } + + Ok(hyp.into_iter().map(|t| self.vocab[t].as_str()).join("")) + } +} + +fn main() -> anyhow::Result<()> { let nnef = tract_rs::nnef()?.with_tract_core()?.with_tract_transformers()?; let gpu = ["cuda", "metal", "default"] .iter() .find_map(|rt| tract_rs::runtime_for_name(rt).ok()) .unwrap(); - let preprocessor = nnef.load("assets/model/preprocessor.nnef.tgz")?.into_runnable()?; - - let mut encoder = nnef.load("assets/model/encoder.nnef.tgz")?; - encoder.transform("transformers-detect-all")?; - let encoder = gpu.prepare(encoder)?; - - let decoder = nnef.load("assets/model/decoder.nnef.tgz")?; - let decoder = gpu.prepare(decoder)?; - - let joint = nnef.load("assets/model/joint.nnef.tgz")?; - let joint = gpu.prepare(joint)?; + let model = TdtModel::load(Path::new("assets/model"), &nnef, &gpu)?; let wav: Vec = hound::WavReader::open("assets/2086-149220-0033.wav")? .samples::() .map(|x| x.unwrap() as f32) .collect(); - let samples: Value = Value::from_slice(&[1, wav.len()], &wav)?; - let len: Value = arr1(&[wav.len() as i64]).try_into()?; - - let [features, feat_len] = preprocessor.run([samples, len])?.try_into().unwrap(); - let [encoded, _lens] = encoder.run([features, feat_len])?.try_into().unwrap(); - - let encoded: ArrayD = encoded.view()?.into_owned(); - - let max_frames = encoded.shape()[2]; - let max_len = max_frames * 6 + 10; - - let mut hyp = vec![]; - let mut frame_ix = 0; - let mut token = Value::from_slice(&[1, 1], &[0i32])?; - let mut state_0: Value = Array3::::zeros([2, 1, 640]).try_into()?; - let mut state_1: Value = Array3::::zeros([2, 1, 640]).try_into()?; - - [token, state_0, state_1] = decoder.run([token, state_0, state_1])?.try_into().unwrap(); - while hyp.len() < max_len && frame_ix < max_frames { - let frame: Value = - encoded.slice_axis(Axis(2), (frame_ix..frame_ix + 1).into()).try_into()?; - let [logits] = joint.run([frame, token.clone()])?.try_into().unwrap(); - let logits = logits.view::()?; - let logits = logits.as_slice().unwrap(); - let token_id = argmax(&logits[0..blank_id + 1]).unwrap(); - if token_id == blank_id { - frame_ix += argmax(&logits[blank_id + 1..]).unwrap_or(0).max(1); - } else { - hyp.push(token_id); - token = Value::from_slice(&[1, 1], &[token_id as i32])?; - [token, state_0, state_1] = decoder.run([token, state_0, state_1])?.try_into().unwrap(); - } - } - let transcript = hyp.into_iter().map(|t| vocab[t]).join(""); + let transcript = model.transcribe(&wav)?; println!("Transcript: {transcript}"); assert_eq!( transcript, From 1a954c474cef90114a3be5bdef67b839f19c6134 Mon Sep 17 00:00:00 2001 From: User Date: Mon, 2 Mar 2026 16:21:49 +0000 Subject: [PATCH 02/20] nemo-parakeet-asr: add transcribe_beam alongside transcribe_greedy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement frame-synchronous beam search (BEAM_SIZE=4, DUR_BEAM_K=2) following NeMo's BeamTDTInfer approach. Add log_softmax helper. Rename transcribe → transcribe_greedy. Wire main to use transcribe_beam. --- examples/nemo-parakeet-asr/src/main.rs | 148 ++++++++++++++++++++++++- 1 file changed, 146 insertions(+), 2 deletions(-) diff --git a/examples/nemo-parakeet-asr/src/main.rs b/examples/nemo-parakeet-asr/src/main.rs index 4ea43a3797..854b3d1254 100644 --- a/examples/nemo-parakeet-asr/src/main.rs +++ b/examples/nemo-parakeet-asr/src/main.rs @@ -12,6 +12,12 @@ fn argmax(slice: &[f32]) -> Option { slice.into_iter().position_max_by_key(|x| FloatOrd(**x)) } +fn log_softmax(xs: &[f32]) -> Vec { + let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max); + let lse = xs.iter().map(|&x| (x - max).exp()).sum::().ln(); + xs.iter().map(|&x| x - max - lse).collect() +} + struct TdtModel { preprocessor: Runnable, encoder: Runnable, @@ -47,7 +53,7 @@ impl TdtModel { Ok(TdtModel { preprocessor, encoder, decoder, joint, vocab, blank_id }) } - fn transcribe(&self, wav: &[f32]) -> Result { + fn transcribe_greedy(&self, wav: &[f32]) -> Result { let samples: Value = Value::from_slice(&[1, wav.len()], wav)?; let len: Value = arr1(&[wav.len() as i64]).try_into()?; @@ -88,6 +94,144 @@ impl TdtModel { Ok(hyp.into_iter().map(|t| self.vocab[t].as_str()).join("")) } + + fn transcribe_beam(&self, wav: &[f32]) -> Result { + let samples: Value = Value::from_slice(&[1, wav.len()], wav)?; + let len: Value = arr1(&[wav.len() as i64]).try_into()?; + + let [features, feat_len] = + self.preprocessor.run([samples, len])?.try_into().unwrap(); + let [encoded, _lens] = + self.encoder.run([features, feat_len])?.try_into().unwrap(); + + let encoded: ArrayD = encoded.view()?.into_owned(); + let max_frames = encoded.shape()[2]; + + const BEAM_SIZE: usize = 4; + const DUR_BEAM_K: usize = 2; + + struct Beam { + score: f32, + tokens: Vec, + last_frame: usize, + dec_out: Value, + state_0: Value, + state_1: Value, + } + + let init_token = Value::from_slice(&[1, 1], &[0i32])?; + let init_s0: Value = Array3::::zeros([2, 1, 640]).try_into()?; + let init_s1: Value = Array3::::zeros([2, 1, 640]).try_into()?; + let [dec_out, state_0, state_1] = + self.decoder.run([init_token, init_s0, init_s1])?.try_into().unwrap(); + + let mut all_beams: Vec = vec![Beam { + score: 0.0, + tokens: vec![], + last_frame: 0, + dec_out, + state_0, + state_1, + }]; + + for frame_ix in 0..max_frames { + let mut hyps: Vec = Vec::new(); + let mut kept: Vec = Vec::new(); + for b in all_beams.drain(..) { + if b.last_frame == frame_ix { + hyps.push(b); + } else { + kept.push(b); + } + } + + let frame: Value = + encoded.slice_axis(Axis(2), (frame_ix..frame_ix + 1).into()).try_into()?; + + while !hyps.is_empty() { + let best_idx = hyps + .iter() + .enumerate() + .max_by_key(|(_, b)| FloatOrd(b.score)) + .map(|(i, _)| i) + .unwrap(); + let max_hyp = hyps.remove(best_idx); + + let [logits] = self + .joint + .run([frame.clone(), max_hyp.dec_out.clone()])? + .try_into() + .unwrap(); + let logits = logits.view::()?; + let logits = logits.as_slice().unwrap(); + + let log_probs = log_softmax(&logits[0..=self.blank_id]); + let dur_log_probs = log_softmax(&logits[self.blank_id + 1..]); + + // Non-blank expansions: top BEAM_SIZE tokens, duration fixed at 0 + let mut token_scores: Vec<(usize, f32)> = + (0..self.blank_id).map(|ti| (ti, log_probs[ti])).collect(); + token_scores.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); + token_scores.truncate(BEAM_SIZE); + + for (ti, lp) in token_scores { + let new_token = Value::from_slice(&[1, 1], &[ti as i32])?; + let [new_dec_out, new_s0, new_s1] = self + .decoder + .run([new_token, max_hyp.state_0.clone(), max_hyp.state_1.clone()])? + .try_into() + .unwrap(); + let mut new_tokens = max_hyp.tokens.clone(); + new_tokens.push(ti); + hyps.push(Beam { + score: max_hyp.score + lp, + tokens: new_tokens, + last_frame: frame_ix, + dec_out: new_dec_out, + state_0: new_s0, + state_1: new_s1, + }); + } + + // Blank expansions: top DUR_BEAM_K non-zero durations + let mut dur_scores: Vec<(usize, f32)> = + (1..dur_log_probs.len()).map(|di| (di, dur_log_probs[di])).collect(); + dur_scores.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); + dur_scores.truncate(DUR_BEAM_K); + + for (di, dlp) in dur_scores { + kept.push(Beam { + score: max_hyp.score + log_probs[self.blank_id] + dlp, + tokens: max_hyp.tokens.clone(), + last_frame: frame_ix + di, + dec_out: max_hyp.dec_out.clone(), + state_0: max_hyp.state_0.clone(), + state_1: max_hyp.state_1.clone(), + }); + } + + // Prune combined pool to BEAM_SIZE + let mut all: Vec = hyps.drain(..).chain(kept.drain(..)).collect(); + all.sort_by(|a, b| FloatOrd(b.score).cmp(&FloatOrd(a.score))); + all.truncate(BEAM_SIZE); + for b in all { + if b.last_frame == frame_ix { + hyps.push(b); + } else { + kept.push(b); + } + } + } + + all_beams = kept; + } + + let best = all_beams + .into_iter() + .max_by_key(|b| FloatOrd(b.score)) + .ok_or_else(|| anyhow!("no beams survived"))?; + Ok(best.tokens.into_iter().map(|t| self.vocab[t].as_str()).join("")) + } } fn main() -> anyhow::Result<()> { @@ -104,7 +248,7 @@ fn main() -> anyhow::Result<()> { .map(|x| x.unwrap() as f32) .collect(); - let transcript = model.transcribe(&wav)?; + let transcript = model.transcribe_beam(&wav)?; println!("Transcript: {transcript}"); assert_eq!( transcript, From 19e923e84a78852d72fa498de522e17fe80a1f78 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 4 Mar 2026 10:45:06 +0000 Subject: [PATCH 03/20] nemo-parakeet-asr: add per-query decoder/joint call instrumentation Track call count, average batch size, and avg/total duration for the decoder and joint networks via a CallStats struct with a custom Debug impl. Both transcribe_greedy and transcribe_beam return stats alongside the transcript. main runs each decoder twice and prints stats only for the second (post-warmup) run. --- examples/nemo-parakeet-asr/src/main.rs | 80 +++++++++++++++++++++++--- 1 file changed, 73 insertions(+), 7 deletions(-) diff --git a/examples/nemo-parakeet-asr/src/main.rs b/examples/nemo-parakeet-asr/src/main.rs index 854b3d1254..c73bc5cf05 100644 --- a/examples/nemo-parakeet-asr/src/main.rs +++ b/examples/nemo-parakeet-asr/src/main.rs @@ -1,5 +1,6 @@ use std::fs::File; use std::path::Path; +use std::time::Instant; use anyhow::*; use float_ord::FloatOrd; @@ -8,6 +9,36 @@ use tract_rs::prelude::tract_ndarray::prelude::*; use tract_rs::prelude::*; use tract_rs::Nnef; +#[derive(Default)] +struct CallStats { + calls: u32, + total_batch: u64, + total_us: u64, +} + +impl CallStats { + fn record(&mut self, batch: usize, elapsed: std::time::Duration) { + self.calls += 1; + self.total_batch += batch as u64; + self.total_us += elapsed.as_micros() as u64; + } +} + +impl std::fmt::Debug for CallStats { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let avg_batch = + if self.calls == 0 { 0.0 } else { self.total_batch as f64 / self.calls as f64 }; + let avg_ms = + if self.calls == 0 { 0.0 } else { self.total_us as f64 / self.calls as f64 / 1000.0 }; + let total_ms = self.total_us as f64 / 1000.0; + write!( + f, + "calls={:5} avg_batch={avg_batch:.1} avg={avg_ms:.3}ms total={total_ms:.1}ms", + self.calls + ) + } +} + fn argmax(slice: &[f32]) -> Option { slice.into_iter().position_max_by_key(|x| FloatOrd(**x)) } @@ -53,7 +84,7 @@ impl TdtModel { Ok(TdtModel { preprocessor, encoder, decoder, joint, vocab, blank_id }) } - fn transcribe_greedy(&self, wav: &[f32]) -> Result { + fn transcribe_greedy(&self, wav: &[f32]) -> Result<(String, CallStats, CallStats)> { let samples: Value = Value::from_slice(&[1, wav.len()], wav)?; let len: Value = arr1(&[wav.len() as i64]).try_into()?; @@ -63,22 +94,31 @@ impl TdtModel { self.encoder.run([features, feat_len])?.try_into().unwrap(); let encoded: ArrayD = encoded.view()?.into_owned(); + let batch = encoded.shape()[0]; let max_frames = encoded.shape()[2]; let max_len = max_frames * 6 + 10; + let mut decoder_stats = CallStats::default(); + let mut joint_stats = CallStats::default(); + let mut hyp = vec![]; let mut frame_ix = 0; let mut token = Value::from_slice(&[1, 1], &[0i32])?; let mut state_0: Value = Array3::::zeros([2, 1, 640]).try_into()?; let mut state_1: Value = Array3::::zeros([2, 1, 640]).try_into()?; + let t = Instant::now(); [token, state_0, state_1] = self.decoder.run([token, state_0, state_1])?.try_into().unwrap(); + decoder_stats.record(batch, t.elapsed()); + while hyp.len() < max_len && frame_ix < max_frames { let frame: Value = encoded.slice_axis(Axis(2), (frame_ix..frame_ix + 1).into()).try_into()?; + let t = Instant::now(); let [logits] = self.joint.run([frame, token.clone()])?.try_into().unwrap(); + joint_stats.record(batch, t.elapsed()); let logits = logits.view::()?; let logits = logits.as_slice().unwrap(); let token_id = argmax(&logits[0..self.blank_id + 1]).unwrap(); @@ -87,15 +127,17 @@ impl TdtModel { } else { hyp.push(token_id); token = Value::from_slice(&[1, 1], &[token_id as i32])?; + let t = Instant::now(); [token, state_0, state_1] = self.decoder.run([token, state_0, state_1])?.try_into().unwrap(); + decoder_stats.record(batch, t.elapsed()); } } - Ok(hyp.into_iter().map(|t| self.vocab[t].as_str()).join("")) + Ok((hyp.into_iter().map(|t| self.vocab[t].as_str()).join(""), decoder_stats, joint_stats)) } - fn transcribe_beam(&self, wav: &[f32]) -> Result { + fn transcribe_beam(&self, wav: &[f32]) -> Result<(String, CallStats, CallStats)> { let samples: Value = Value::from_slice(&[1, wav.len()], wav)?; let len: Value = arr1(&[wav.len() as i64]).try_into()?; @@ -105,8 +147,12 @@ impl TdtModel { self.encoder.run([features, feat_len])?.try_into().unwrap(); let encoded: ArrayD = encoded.view()?.into_owned(); + let batch = encoded.shape()[0]; let max_frames = encoded.shape()[2]; + let mut decoder_stats = CallStats::default(); + let mut joint_stats = CallStats::default(); + const BEAM_SIZE: usize = 4; const DUR_BEAM_K: usize = 2; @@ -122,8 +168,10 @@ impl TdtModel { let init_token = Value::from_slice(&[1, 1], &[0i32])?; let init_s0: Value = Array3::::zeros([2, 1, 640]).try_into()?; let init_s1: Value = Array3::::zeros([2, 1, 640]).try_into()?; + let t = Instant::now(); let [dec_out, state_0, state_1] = self.decoder.run([init_token, init_s0, init_s1])?.try_into().unwrap(); + decoder_stats.record(batch, t.elapsed()); let mut all_beams: Vec = vec![Beam { score: 0.0, @@ -157,11 +205,13 @@ impl TdtModel { .unwrap(); let max_hyp = hyps.remove(best_idx); + let t = Instant::now(); let [logits] = self .joint .run([frame.clone(), max_hyp.dec_out.clone()])? .try_into() .unwrap(); + joint_stats.record(batch, t.elapsed()); let logits = logits.view::()?; let logits = logits.as_slice().unwrap(); @@ -176,11 +226,13 @@ impl TdtModel { for (ti, lp) in token_scores { let new_token = Value::from_slice(&[1, 1], &[ti as i32])?; + let t = Instant::now(); let [new_dec_out, new_s0, new_s1] = self .decoder .run([new_token, max_hyp.state_0.clone(), max_hyp.state_1.clone()])? .try_into() .unwrap(); + decoder_stats.record(batch, t.elapsed()); let mut new_tokens = max_hyp.tokens.clone(); new_tokens.push(ti); hyps.push(Beam { @@ -230,7 +282,7 @@ impl TdtModel { .into_iter() .max_by_key(|b| FloatOrd(b.score)) .ok_or_else(|| anyhow!("no beams survived"))?; - Ok(best.tokens.into_iter().map(|t| self.vocab[t].as_str()).join("")) + Ok((best.tokens.into_iter().map(|t| self.vocab[t].as_str()).join(""), decoder_stats, joint_stats)) } } @@ -248,10 +300,24 @@ fn main() -> anyhow::Result<()> { .map(|x| x.unwrap() as f32) .collect(); - let transcript = model.transcribe_beam(&wav)?; - println!("Transcript: {transcript}"); + model.transcribe_greedy(&wav)?; + let (transcript_g, dec, joint) = model.transcribe_greedy(&wav)?; + eprintln!("[greedy][decoder] {dec:?}"); + eprintln!("[greedy][joint] {joint:?}"); + + model.transcribe_beam(&wav)?; + let (transcript_b, dec, joint) = model.transcribe_beam(&wav)?; + eprintln!("[beam][decoder] {dec:?}"); + eprintln!("[beam][joint] {joint:?}"); + + println!("Greedy: {transcript_g}"); + println!("Beam: {transcript_b}"); + assert_eq!( + transcript_g, + "▁Well,▁I▁don't▁wish▁to▁see▁it▁any▁more,▁observed▁Phoebe,▁turning▁away▁her▁eyes." + ); assert_eq!( - transcript, + transcript_b, "▁Well,▁I▁don't▁wish▁to▁see▁it▁any▁more,▁observed▁Phoebe,▁turning▁away▁her▁eyes." ); Ok(()) From d297103966e5757871b8d123bed9f645d5bd12ad Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 4 Mar 2026 10:45:06 +0000 Subject: [PATCH 04/20] nemo-parakeet-asr: batch decoder calls in transcribe_beam Replace the per-token decoder loop with a single batched call, building [n,1] token and [2,n,640] state tensors, then slicing the [n,hidden] / [2,n,640] outputs back into individual beams. Reduces kernel-launch overhead and enables GPU parallelism; avg_batch now reflects BEAM_SIZE. --- examples/nemo-parakeet-asr/src/main.rs | 45 ++++++++++++++++++++------ 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/examples/nemo-parakeet-asr/src/main.rs b/examples/nemo-parakeet-asr/src/main.rs index c73bc5cf05..95663355d0 100644 --- a/examples/nemo-parakeet-asr/src/main.rs +++ b/examples/nemo-parakeet-asr/src/main.rs @@ -224,15 +224,42 @@ impl TdtModel { token_scores.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); token_scores.truncate(BEAM_SIZE); - for (ti, lp) in token_scores { - let new_token = Value::from_slice(&[1, 1], &[ti as i32])?; - let t = Instant::now(); - let [new_dec_out, new_s0, new_s1] = self - .decoder - .run([new_token, max_hyp.state_0.clone(), max_hyp.state_1.clone()])? - .try_into() - .unwrap(); - decoder_stats.record(batch, t.elapsed()); + let n = token_scores.len(); + + // 1. Build batched inputs + let token_ids: Vec = + token_scores.iter().map(|&(ti, _)| ti as i32).collect(); + let tokens_batch: Value = + Array2::::from_shape_fn((n, 1), |(i, _)| token_ids[i]).try_into()?; + + let s0_src = max_hyp.state_0.view::()?; // [2, 1, 640] + let s0_batch: Value = + Array3::::from_shape_fn((2, n, 640), |(l, _, h)| s0_src[[l, 0, h]]) + .try_into()?; + + let s1_src = max_hyp.state_1.view::()?; + let s1_batch: Value = + Array3::::from_shape_fn((2, n, 640), |(l, _, h)| s1_src[[l, 0, h]]) + .try_into()?; + + // 2. Single decoder call + let t = Instant::now(); + let [dec_out_b, s0_b, s1_b] = + self.decoder.run([tokens_batch, s0_batch, s1_batch])?.try_into().unwrap(); + decoder_stats.record(n, t.elapsed()); + + // 3. Slice outputs and push beams + let dec_arr = dec_out_b.view::()?; // [n, hidden] + let s0_arr = s0_b.view::()?; // [2, n, 640] + let s1_arr = s1_b.view::()?; // [2, n, 640] + + for (i, &(ti, lp)) in token_scores.iter().enumerate() { + let new_dec_out: Value = + dec_arr.slice_axis(Axis(0), (i..i + 1).into()).try_into()?; + let new_s0: Value = + s0_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; + let new_s1: Value = + s1_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; let mut new_tokens = max_hyp.tokens.clone(); new_tokens.push(ti); hyps.push(Beam { From f022d33419b0906e4cb25ed16b5760b2ebecf74e Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 4 Mar 2026 10:45:06 +0000 Subject: [PATCH 05/20] nemo-parakeet-asr: batch joint+decoder across all active hypotheses in transcribe_beam MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Each while-iteration now issues one joint call over all b active hypotheses ([b, enc_dim, 1] × [b, hidden, 1]) and one decoder call over all N token expansions ([N,1] + [2,N,640] states), collapsing what was previously b sequential joint calls and b sequential decoder calls into two batched calls. dec_out shape is [batch, hidden, 1] (not [batch, hidden]), so the batched gather uses Array3 with index [[0, h, 0]] rather than Array2 [[0, h]]. [beam][joint] avg_batch≈3.2 (up to BEAM_SIZE=4) [beam][decoder] avg_batch≈12.8 (up to b×BEAM_SIZE=16) --- examples/nemo-parakeet-asr/src/main.rs | 203 +++++++++++++++---------- 1 file changed, 119 insertions(+), 84 deletions(-) diff --git a/examples/nemo-parakeet-asr/src/main.rs b/examples/nemo-parakeet-asr/src/main.rs index 95663355d0..53dc2bde9d 100644 --- a/examples/nemo-parakeet-asr/src/main.rs +++ b/examples/nemo-parakeet-asr/src/main.rs @@ -149,6 +149,7 @@ impl TdtModel { let encoded: ArrayD = encoded.view()?.into_owned(); let batch = encoded.shape()[0]; let max_frames = encoded.shape()[2]; + let enc_dim = encoded.shape()[1]; let mut decoder_stats = CallStats::default(); let mut joint_stats = CallStats::default(); @@ -193,103 +194,137 @@ impl TdtModel { } } - let frame: Value = - encoded.slice_axis(Axis(2), (frame_ix..frame_ix + 1).into()).try_into()?; - while !hyps.is_empty() { - let best_idx = hyps + let b = hyps.len(); + + // 1. JOINT: single call batched over all B active hypotheses + let frame_batch: Value = { + let enc_arr = encoded.view(); + Array3::::from_shape_fn((b, enc_dim, 1), |(_, e, _)| { + enc_arr[[0, e, frame_ix]] + }) + .try_into()? + }; + let dec_out_batch: Value = { + let views: Vec<_> = hyps + .iter() + .map(|h| h.dec_out.view::()) + .collect::>>()?; + let hidden = views[0].shape()[1]; // dec_out is [1, hidden, 1] + Array3::::from_shape_fn((b, hidden, 1), |(bi, h, _)| views[bi][[0, h, 0]]) + .try_into()? + }; + let t = Instant::now(); + let [logits_b] = + self.joint.run([frame_batch, dec_out_batch])?.try_into().unwrap(); + joint_stats.record(b, t.elapsed()); + + // 2. Per-hyp: token scores + duration expansions into kept + let mut per_hyp_token_scores: Vec> = Vec::with_capacity(b); + { + let logits_arr = logits_b.view::()?; // [b, vocab+dur] + for bi in 0..b { + let row = logits_arr.index_axis(Axis(0), bi); + let row_slice = row.as_slice().unwrap(); + let log_probs = log_softmax(&row_slice[0..=self.blank_id]); + let dur_log_probs = log_softmax(&row_slice[self.blank_id + 1..]); + + let mut ts: Vec<(usize, f32)> = + (0..self.blank_id).map(|ti| (ti, log_probs[ti])).collect(); + ts.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); + ts.truncate(BEAM_SIZE); + per_hyp_token_scores.push(ts); + + let mut ds: Vec<(usize, f32)> = + (1..dur_log_probs.len()).map(|di| (di, dur_log_probs[di])).collect(); + ds.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); + ds.truncate(DUR_BEAM_K); + for (di, dlp) in ds { + kept.push(Beam { + score: hyps[bi].score + log_probs[self.blank_id] + dlp, + tokens: hyps[bi].tokens.clone(), + last_frame: frame_ix + di, + dec_out: hyps[bi].dec_out.clone(), + state_0: hyps[bi].state_0.clone(), + state_1: hyps[bi].state_1.clone(), + }); + } + } + } // logits_arr dropped + + // 3. DECODER: single call batched over all N token expansions + let expansion_hyp_idxs: Vec = per_hyp_token_scores .iter() .enumerate() - .max_by_key(|(_, b)| FloatOrd(b.score)) - .map(|(i, _)| i) - .unwrap(); - let max_hyp = hyps.remove(best_idx); + .flat_map(|(bi, ts)| std::iter::repeat(bi).take(ts.len())) + .collect(); + let token_ids: Vec = per_hyp_token_scores + .iter() + .flat_map(|ts| ts.iter().map(|&(ti, _)| ti as i32)) + .collect(); + let n = token_ids.len(); - let t = Instant::now(); - let [logits] = self - .joint - .run([frame.clone(), max_hyp.dec_out.clone()])? - .try_into() - .unwrap(); - joint_stats.record(batch, t.elapsed()); - let logits = logits.view::()?; - let logits = logits.as_slice().unwrap(); - - let log_probs = log_softmax(&logits[0..=self.blank_id]); - let dur_log_probs = log_softmax(&logits[self.blank_id + 1..]); - - // Non-blank expansions: top BEAM_SIZE tokens, duration fixed at 0 - let mut token_scores: Vec<(usize, f32)> = - (0..self.blank_id).map(|ti| (ti, log_probs[ti])).collect(); - token_scores.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); - token_scores.truncate(BEAM_SIZE); - - let n = token_scores.len(); - - // 1. Build batched inputs - let token_ids: Vec = - token_scores.iter().map(|&(ti, _)| ti as i32).collect(); let tokens_batch: Value = Array2::::from_shape_fn((n, 1), |(i, _)| token_ids[i]).try_into()?; + let s0_batch: Value = { + let views: Vec<_> = hyps + .iter() + .map(|h| h.state_0.view::()) + .collect::>>()?; + Array3::::from_shape_fn((2, n, 640), |(l, i, h)| { + views[expansion_hyp_idxs[i]][[l, 0, h]] + }) + .try_into()? + }; + let s1_batch: Value = { + let views: Vec<_> = hyps + .iter() + .map(|h| h.state_1.view::()) + .collect::>>()?; + Array3::::from_shape_fn((2, n, 640), |(l, i, h)| { + views[expansion_hyp_idxs[i]][[l, 0, h]] + }) + .try_into()? + }; - let s0_src = max_hyp.state_0.view::()?; // [2, 1, 640] - let s0_batch: Value = - Array3::::from_shape_fn((2, n, 640), |(l, _, h)| s0_src[[l, 0, h]]) - .try_into()?; - - let s1_src = max_hyp.state_1.view::()?; - let s1_batch: Value = - Array3::::from_shape_fn((2, n, 640), |(l, _, h)| s1_src[[l, 0, h]]) - .try_into()?; - - // 2. Single decoder call let t = Instant::now(); let [dec_out_b, s0_b, s1_b] = self.decoder.run([tokens_batch, s0_batch, s1_batch])?.try_into().unwrap(); decoder_stats.record(n, t.elapsed()); - // 3. Slice outputs and push beams - let dec_arr = dec_out_b.view::()?; // [n, hidden] - let s0_arr = s0_b.view::()?; // [2, n, 640] - let s1_arr = s1_b.view::()?; // [2, n, 640] - - for (i, &(ti, lp)) in token_scores.iter().enumerate() { - let new_dec_out: Value = - dec_arr.slice_axis(Axis(0), (i..i + 1).into()).try_into()?; - let new_s0: Value = - s0_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; - let new_s1: Value = - s1_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; - let mut new_tokens = max_hyp.tokens.clone(); - new_tokens.push(ti); - hyps.push(Beam { - score: max_hyp.score + lp, - tokens: new_tokens, - last_frame: frame_ix, - dec_out: new_dec_out, - state_0: new_s0, - state_1: new_s1, - }); - } - - // Blank expansions: top DUR_BEAM_K non-zero durations - let mut dur_scores: Vec<(usize, f32)> = - (1..dur_log_probs.len()).map(|di| (di, dur_log_probs[di])).collect(); - dur_scores.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); - dur_scores.truncate(DUR_BEAM_K); - - for (di, dlp) in dur_scores { - kept.push(Beam { - score: max_hyp.score + log_probs[self.blank_id] + dlp, - tokens: max_hyp.tokens.clone(), - last_frame: frame_ix + di, - dec_out: max_hyp.dec_out.clone(), - state_0: max_hyp.state_0.clone(), - state_1: max_hyp.state_1.clone(), - }); - } + // 4. Slice and build new beams + let new_hyps: Vec = { + let dec_arr = dec_out_b.view::()?; // [N, hidden] + let s0_arr = s0_b.view::()?; // [2, N, 640] + let s1_arr = s1_b.view::()?; // [2, N, 640] + let mut out = Vec::with_capacity(n); + let mut i = 0; + for (bi, ts) in per_hyp_token_scores.iter().enumerate() { + for &(ti, lp) in ts { + let new_dec_out: Value = + dec_arr.slice_axis(Axis(0), (i..i + 1).into()).try_into()?; + let new_s0: Value = + s0_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; + let new_s1: Value = + s1_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; + let mut new_tokens = hyps[bi].tokens.clone(); + new_tokens.push(ti); + out.push(Beam { + score: hyps[bi].score + lp, + tokens: new_tokens, + last_frame: frame_ix, + dec_out: new_dec_out, + state_0: new_s0, + state_1: new_s1, + }); + i += 1; + } + } + out + }; + hyps = new_hyps; - // Prune combined pool to BEAM_SIZE + // 5. Prune combined pool to BEAM_SIZE let mut all: Vec = hyps.drain(..).chain(kept.drain(..)).collect(); all.sort_by(|a, b| FloatOrd(b.score).cmp(&FloatOrd(a.score))); all.truncate(BEAM_SIZE); From cc418b38404800005a03a79f1d0695a5efa101ac Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 4 Mar 2026 10:45:06 +0000 Subject: [PATCH 06/20] nemo-parakeet-asr: add clap CLI with configurable beam parameters Adds BeamConfig (--beam-size, --dur-beam-k) via clap 4 derive API, replacing the hard-coded BEAM_SIZE and DUR_BEAM_K constants. --- examples/nemo-parakeet-asr/Cargo.toml | 1 + examples/nemo-parakeet-asr/src/beam.rs | 230 ++++++++++++++++++ examples/nemo-parakeet-asr/src/main.rs | 307 +++---------------------- 3 files changed, 258 insertions(+), 280 deletions(-) create mode 100644 examples/nemo-parakeet-asr/src/beam.rs diff --git a/examples/nemo-parakeet-asr/Cargo.toml b/examples/nemo-parakeet-asr/Cargo.toml index a7fd2d6e53..9588ba1b90 100644 --- a/examples/nemo-parakeet-asr/Cargo.toml +++ b/examples/nemo-parakeet-asr/Cargo.toml @@ -5,6 +5,7 @@ edition = "2024" [dependencies] anyhow.workspace = true +clap = { version = "4", features = ["derive"] } float-ord.workspace = true hound = "3.5.1" itertools.workspace = true diff --git a/examples/nemo-parakeet-asr/src/beam.rs b/examples/nemo-parakeet-asr/src/beam.rs new file mode 100644 index 0000000000..edaae411b2 --- /dev/null +++ b/examples/nemo-parakeet-asr/src/beam.rs @@ -0,0 +1,230 @@ +use std::time::Instant; + +use anyhow::*; +use clap::Args; +use float_ord::FloatOrd; +use itertools::Itertools; +use tract_rs::prelude::tract_ndarray::prelude::*; +use tract_rs::prelude::*; + +#[derive(Args, Clone)] +pub struct BeamConfig { + /// Number of active hypotheses to keep after each pruning step + #[arg(long, default_value_t = 4)] + pub beam_size: usize, + + /// Number of duration candidates to expand per hypothesis + #[arg(long, default_value_t = 2)] + pub dur_beam_k: usize, +} + +struct Beam { + score: f32, + tokens: Vec, + last_frame: usize, + dec_out: Value, + state_0: Value, + state_1: Value, +} + +pub fn transcribe_beam( + model: &crate::TdtModel, + wav: &[f32], + cfg: &BeamConfig, +) -> Result<(String, crate::CallStats, crate::CallStats)> { + let samples: Value = Value::from_slice(&[1, wav.len()], wav)?; + let len: Value = arr1(&[wav.len() as i64]).try_into()?; + + let [features, feat_len] = + model.preprocessor.run([samples, len])?.try_into().unwrap(); + let [encoded, _lens] = + model.encoder.run([features, feat_len])?.try_into().unwrap(); + + let encoded: ArrayD = encoded.view()?.into_owned(); + let batch = encoded.shape()[0]; + let max_frames = encoded.shape()[2]; + let enc_dim = encoded.shape()[1]; + + let mut decoder_stats = crate::CallStats::default(); + let mut joint_stats = crate::CallStats::default(); + + let init_token = Value::from_slice(&[1, 1], &[0i32])?; + let init_s0: Value = Array3::::zeros([2, 1, 640]).try_into()?; + let init_s1: Value = Array3::::zeros([2, 1, 640]).try_into()?; + let t = Instant::now(); + let [dec_out, state_0, state_1] = + model.decoder.run([init_token, init_s0, init_s1])?.try_into().unwrap(); + decoder_stats.record(batch, t.elapsed()); + + let mut all_beams: Vec = vec![Beam { + score: 0.0, + tokens: vec![], + last_frame: 0, + dec_out, + state_0, + state_1, + }]; + + for frame_ix in 0..max_frames { + let mut hyps: Vec = Vec::new(); + let mut kept: Vec = Vec::new(); + for b in all_beams.drain(..) { + if b.last_frame == frame_ix { + hyps.push(b); + } else { + kept.push(b); + } + } + + while !hyps.is_empty() { + let b = hyps.len(); + + // 1. JOINT: single call batched over all B active hypotheses + let frame_batch: Value = { + let enc_arr = encoded.view(); + Array3::::from_shape_fn((b, enc_dim, 1), |(_, e, _)| { + enc_arr[[0, e, frame_ix]] + }) + .try_into()? + }; + let dec_out_batch: Value = { + let views: Vec<_> = hyps + .iter() + .map(|h| h.dec_out.view::()) + .collect::>>()?; + let hidden = views[0].shape()[1]; // dec_out is [1, hidden, 1] + Array3::::from_shape_fn((b, hidden, 1), |(bi, h, _)| views[bi][[0, h, 0]]) + .try_into()? + }; + let t = Instant::now(); + let [logits_b] = + model.joint.run([frame_batch, dec_out_batch])?.try_into().unwrap(); + joint_stats.record(b, t.elapsed()); + + // 2. Per-hyp: token scores + duration expansions into kept + let mut per_hyp_token_scores: Vec> = Vec::with_capacity(b); + { + let logits_arr = logits_b.view::()?; // [b, vocab+dur] + for bi in 0..b { + let row = logits_arr.index_axis(Axis(0), bi); + let row_slice = row.as_slice().unwrap(); + let log_probs = crate::log_softmax(&row_slice[0..=model.blank_id]); + let dur_log_probs = crate::log_softmax(&row_slice[model.blank_id + 1..]); + + let mut ts: Vec<(usize, f32)> = + (0..model.blank_id).map(|ti| (ti, log_probs[ti])).collect(); + ts.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); + ts.truncate(cfg.beam_size); + per_hyp_token_scores.push(ts); + + let mut ds: Vec<(usize, f32)> = + (1..dur_log_probs.len()).map(|di| (di, dur_log_probs[di])).collect(); + ds.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); + ds.truncate(cfg.dur_beam_k); + for (di, dlp) in ds { + kept.push(Beam { + score: hyps[bi].score + log_probs[model.blank_id] + dlp, + tokens: hyps[bi].tokens.clone(), + last_frame: frame_ix + di, + dec_out: hyps[bi].dec_out.clone(), + state_0: hyps[bi].state_0.clone(), + state_1: hyps[bi].state_1.clone(), + }); + } + } + } // logits_arr dropped + + // 3. DECODER: single call batched over all N token expansions + let expansion_hyp_idxs: Vec = per_hyp_token_scores + .iter() + .enumerate() + .flat_map(|(bi, ts)| std::iter::repeat(bi).take(ts.len())) + .collect(); + let token_ids: Vec = per_hyp_token_scores + .iter() + .flat_map(|ts| ts.iter().map(|&(ti, _)| ti as i32)) + .collect(); + let n = token_ids.len(); + + let tokens_batch: Value = + Array2::::from_shape_fn((n, 1), |(i, _)| token_ids[i]).try_into()?; + let s0_batch: Value = { + let views: Vec<_> = hyps + .iter() + .map(|h| h.state_0.view::()) + .collect::>>()?; + Array3::::from_shape_fn((2, n, 640), |(l, i, h)| { + views[expansion_hyp_idxs[i]][[l, 0, h]] + }) + .try_into()? + }; + let s1_batch: Value = { + let views: Vec<_> = hyps + .iter() + .map(|h| h.state_1.view::()) + .collect::>>()?; + Array3::::from_shape_fn((2, n, 640), |(l, i, h)| { + views[expansion_hyp_idxs[i]][[l, 0, h]] + }) + .try_into()? + }; + + let t = Instant::now(); + let [dec_out_b, s0_b, s1_b] = + model.decoder.run([tokens_batch, s0_batch, s1_batch])?.try_into().unwrap(); + decoder_stats.record(n, t.elapsed()); + + // 4. Slice and build new beams + let new_hyps: Vec = { + let dec_arr = dec_out_b.view::()?; // [N, hidden] + let s0_arr = s0_b.view::()?; // [2, N, 640] + let s1_arr = s1_b.view::()?; // [2, N, 640] + let mut out = Vec::with_capacity(n); + let mut i = 0; + for (bi, ts) in per_hyp_token_scores.iter().enumerate() { + for &(ti, lp) in ts { + let new_dec_out: Value = + dec_arr.slice_axis(Axis(0), (i..i + 1).into()).try_into()?; + let new_s0: Value = + s0_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; + let new_s1: Value = + s1_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; + let mut new_tokens = hyps[bi].tokens.clone(); + new_tokens.push(ti); + out.push(Beam { + score: hyps[bi].score + lp, + tokens: new_tokens, + last_frame: frame_ix, + dec_out: new_dec_out, + state_0: new_s0, + state_1: new_s1, + }); + i += 1; + } + } + out + }; + hyps = new_hyps; + + // 5. Prune combined pool to BEAM_SIZE + let mut all: Vec = hyps.drain(..).chain(kept.drain(..)).collect(); + all.sort_by(|a, b| FloatOrd(b.score).cmp(&FloatOrd(a.score))); + all.truncate(cfg.beam_size); + for b in all { + if b.last_frame == frame_ix { + hyps.push(b); + } else { + kept.push(b); + } + } + } + + all_beams = kept; + } + + let best = all_beams + .into_iter() + .max_by_key(|b| FloatOrd(b.score)) + .ok_or_else(|| anyhow!("no beams survived"))?; + Ok((best.tokens.into_iter().map(|t| model.vocab[t].as_str()).join(""), decoder_stats, joint_stats)) +} diff --git a/examples/nemo-parakeet-asr/src/main.rs b/examples/nemo-parakeet-asr/src/main.rs index 53dc2bde9d..e7a41dbab1 100644 --- a/examples/nemo-parakeet-asr/src/main.rs +++ b/examples/nemo-parakeet-asr/src/main.rs @@ -1,23 +1,32 @@ use std::fs::File; use std::path::Path; -use std::time::Instant; use anyhow::*; +use clap::Parser; use float_ord::FloatOrd; use itertools::Itertools; -use tract_rs::prelude::tract_ndarray::prelude::*; use tract_rs::prelude::*; use tract_rs::Nnef; +mod greedy; +mod beam; + +#[derive(Parser)] +#[command(about = "NeMo Parakeet ASR inference")] +struct Args { + #[command(flatten)] + beam: beam::BeamConfig, +} + #[derive(Default)] -struct CallStats { +pub(crate) struct CallStats { calls: u32, total_batch: u64, total_us: u64, } impl CallStats { - fn record(&mut self, batch: usize, elapsed: std::time::Duration) { + pub(crate) fn record(&mut self, batch: usize, elapsed: std::time::Duration) { self.calls += 1; self.total_batch += batch as u64; self.total_us += elapsed.as_micros() as u64; @@ -39,23 +48,23 @@ impl std::fmt::Debug for CallStats { } } -fn argmax(slice: &[f32]) -> Option { +pub(crate) fn argmax(slice: &[f32]) -> Option { slice.into_iter().position_max_by_key(|x| FloatOrd(**x)) } -fn log_softmax(xs: &[f32]) -> Vec { +pub(crate) fn log_softmax(xs: &[f32]) -> Vec { let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max); let lse = xs.iter().map(|&x| (x - max).exp()).sum::().ln(); xs.iter().map(|&x| x - max - lse).collect() } -struct TdtModel { - preprocessor: Runnable, - encoder: Runnable, - decoder: Runnable, - joint: Runnable, - vocab: Vec, - blank_id: usize, +pub(crate) struct TdtModel { + pub(crate) preprocessor: Runnable, + pub(crate) encoder: Runnable, + pub(crate) decoder: Runnable, + pub(crate) joint: Runnable, + pub(crate) vocab: Vec, + pub(crate) blank_id: usize, } impl TdtModel { @@ -83,272 +92,10 @@ impl TdtModel { Ok(TdtModel { preprocessor, encoder, decoder, joint, vocab, blank_id }) } - - fn transcribe_greedy(&self, wav: &[f32]) -> Result<(String, CallStats, CallStats)> { - let samples: Value = Value::from_slice(&[1, wav.len()], wav)?; - let len: Value = arr1(&[wav.len() as i64]).try_into()?; - - let [features, feat_len] = - self.preprocessor.run([samples, len])?.try_into().unwrap(); - let [encoded, _lens] = - self.encoder.run([features, feat_len])?.try_into().unwrap(); - - let encoded: ArrayD = encoded.view()?.into_owned(); - let batch = encoded.shape()[0]; - - let max_frames = encoded.shape()[2]; - let max_len = max_frames * 6 + 10; - - let mut decoder_stats = CallStats::default(); - let mut joint_stats = CallStats::default(); - - let mut hyp = vec![]; - let mut frame_ix = 0; - let mut token = Value::from_slice(&[1, 1], &[0i32])?; - let mut state_0: Value = Array3::::zeros([2, 1, 640]).try_into()?; - let mut state_1: Value = Array3::::zeros([2, 1, 640]).try_into()?; - - let t = Instant::now(); - [token, state_0, state_1] = - self.decoder.run([token, state_0, state_1])?.try_into().unwrap(); - decoder_stats.record(batch, t.elapsed()); - - while hyp.len() < max_len && frame_ix < max_frames { - let frame: Value = - encoded.slice_axis(Axis(2), (frame_ix..frame_ix + 1).into()).try_into()?; - let t = Instant::now(); - let [logits] = self.joint.run([frame, token.clone()])?.try_into().unwrap(); - joint_stats.record(batch, t.elapsed()); - let logits = logits.view::()?; - let logits = logits.as_slice().unwrap(); - let token_id = argmax(&logits[0..self.blank_id + 1]).unwrap(); - if token_id == self.blank_id { - frame_ix += argmax(&logits[self.blank_id + 1..]).unwrap_or(0).max(1); - } else { - hyp.push(token_id); - token = Value::from_slice(&[1, 1], &[token_id as i32])?; - let t = Instant::now(); - [token, state_0, state_1] = - self.decoder.run([token, state_0, state_1])?.try_into().unwrap(); - decoder_stats.record(batch, t.elapsed()); - } - } - - Ok((hyp.into_iter().map(|t| self.vocab[t].as_str()).join(""), decoder_stats, joint_stats)) - } - - fn transcribe_beam(&self, wav: &[f32]) -> Result<(String, CallStats, CallStats)> { - let samples: Value = Value::from_slice(&[1, wav.len()], wav)?; - let len: Value = arr1(&[wav.len() as i64]).try_into()?; - - let [features, feat_len] = - self.preprocessor.run([samples, len])?.try_into().unwrap(); - let [encoded, _lens] = - self.encoder.run([features, feat_len])?.try_into().unwrap(); - - let encoded: ArrayD = encoded.view()?.into_owned(); - let batch = encoded.shape()[0]; - let max_frames = encoded.shape()[2]; - let enc_dim = encoded.shape()[1]; - - let mut decoder_stats = CallStats::default(); - let mut joint_stats = CallStats::default(); - - const BEAM_SIZE: usize = 4; - const DUR_BEAM_K: usize = 2; - - struct Beam { - score: f32, - tokens: Vec, - last_frame: usize, - dec_out: Value, - state_0: Value, - state_1: Value, - } - - let init_token = Value::from_slice(&[1, 1], &[0i32])?; - let init_s0: Value = Array3::::zeros([2, 1, 640]).try_into()?; - let init_s1: Value = Array3::::zeros([2, 1, 640]).try_into()?; - let t = Instant::now(); - let [dec_out, state_0, state_1] = - self.decoder.run([init_token, init_s0, init_s1])?.try_into().unwrap(); - decoder_stats.record(batch, t.elapsed()); - - let mut all_beams: Vec = vec![Beam { - score: 0.0, - tokens: vec![], - last_frame: 0, - dec_out, - state_0, - state_1, - }]; - - for frame_ix in 0..max_frames { - let mut hyps: Vec = Vec::new(); - let mut kept: Vec = Vec::new(); - for b in all_beams.drain(..) { - if b.last_frame == frame_ix { - hyps.push(b); - } else { - kept.push(b); - } - } - - while !hyps.is_empty() { - let b = hyps.len(); - - // 1. JOINT: single call batched over all B active hypotheses - let frame_batch: Value = { - let enc_arr = encoded.view(); - Array3::::from_shape_fn((b, enc_dim, 1), |(_, e, _)| { - enc_arr[[0, e, frame_ix]] - }) - .try_into()? - }; - let dec_out_batch: Value = { - let views: Vec<_> = hyps - .iter() - .map(|h| h.dec_out.view::()) - .collect::>>()?; - let hidden = views[0].shape()[1]; // dec_out is [1, hidden, 1] - Array3::::from_shape_fn((b, hidden, 1), |(bi, h, _)| views[bi][[0, h, 0]]) - .try_into()? - }; - let t = Instant::now(); - let [logits_b] = - self.joint.run([frame_batch, dec_out_batch])?.try_into().unwrap(); - joint_stats.record(b, t.elapsed()); - - // 2. Per-hyp: token scores + duration expansions into kept - let mut per_hyp_token_scores: Vec> = Vec::with_capacity(b); - { - let logits_arr = logits_b.view::()?; // [b, vocab+dur] - for bi in 0..b { - let row = logits_arr.index_axis(Axis(0), bi); - let row_slice = row.as_slice().unwrap(); - let log_probs = log_softmax(&row_slice[0..=self.blank_id]); - let dur_log_probs = log_softmax(&row_slice[self.blank_id + 1..]); - - let mut ts: Vec<(usize, f32)> = - (0..self.blank_id).map(|ti| (ti, log_probs[ti])).collect(); - ts.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); - ts.truncate(BEAM_SIZE); - per_hyp_token_scores.push(ts); - - let mut ds: Vec<(usize, f32)> = - (1..dur_log_probs.len()).map(|di| (di, dur_log_probs[di])).collect(); - ds.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); - ds.truncate(DUR_BEAM_K); - for (di, dlp) in ds { - kept.push(Beam { - score: hyps[bi].score + log_probs[self.blank_id] + dlp, - tokens: hyps[bi].tokens.clone(), - last_frame: frame_ix + di, - dec_out: hyps[bi].dec_out.clone(), - state_0: hyps[bi].state_0.clone(), - state_1: hyps[bi].state_1.clone(), - }); - } - } - } // logits_arr dropped - - // 3. DECODER: single call batched over all N token expansions - let expansion_hyp_idxs: Vec = per_hyp_token_scores - .iter() - .enumerate() - .flat_map(|(bi, ts)| std::iter::repeat(bi).take(ts.len())) - .collect(); - let token_ids: Vec = per_hyp_token_scores - .iter() - .flat_map(|ts| ts.iter().map(|&(ti, _)| ti as i32)) - .collect(); - let n = token_ids.len(); - - let tokens_batch: Value = - Array2::::from_shape_fn((n, 1), |(i, _)| token_ids[i]).try_into()?; - let s0_batch: Value = { - let views: Vec<_> = hyps - .iter() - .map(|h| h.state_0.view::()) - .collect::>>()?; - Array3::::from_shape_fn((2, n, 640), |(l, i, h)| { - views[expansion_hyp_idxs[i]][[l, 0, h]] - }) - .try_into()? - }; - let s1_batch: Value = { - let views: Vec<_> = hyps - .iter() - .map(|h| h.state_1.view::()) - .collect::>>()?; - Array3::::from_shape_fn((2, n, 640), |(l, i, h)| { - views[expansion_hyp_idxs[i]][[l, 0, h]] - }) - .try_into()? - }; - - let t = Instant::now(); - let [dec_out_b, s0_b, s1_b] = - self.decoder.run([tokens_batch, s0_batch, s1_batch])?.try_into().unwrap(); - decoder_stats.record(n, t.elapsed()); - - // 4. Slice and build new beams - let new_hyps: Vec = { - let dec_arr = dec_out_b.view::()?; // [N, hidden] - let s0_arr = s0_b.view::()?; // [2, N, 640] - let s1_arr = s1_b.view::()?; // [2, N, 640] - let mut out = Vec::with_capacity(n); - let mut i = 0; - for (bi, ts) in per_hyp_token_scores.iter().enumerate() { - for &(ti, lp) in ts { - let new_dec_out: Value = - dec_arr.slice_axis(Axis(0), (i..i + 1).into()).try_into()?; - let new_s0: Value = - s0_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; - let new_s1: Value = - s1_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; - let mut new_tokens = hyps[bi].tokens.clone(); - new_tokens.push(ti); - out.push(Beam { - score: hyps[bi].score + lp, - tokens: new_tokens, - last_frame: frame_ix, - dec_out: new_dec_out, - state_0: new_s0, - state_1: new_s1, - }); - i += 1; - } - } - out - }; - hyps = new_hyps; - - // 5. Prune combined pool to BEAM_SIZE - let mut all: Vec = hyps.drain(..).chain(kept.drain(..)).collect(); - all.sort_by(|a, b| FloatOrd(b.score).cmp(&FloatOrd(a.score))); - all.truncate(BEAM_SIZE); - for b in all { - if b.last_frame == frame_ix { - hyps.push(b); - } else { - kept.push(b); - } - } - } - - all_beams = kept; - } - - let best = all_beams - .into_iter() - .max_by_key(|b| FloatOrd(b.score)) - .ok_or_else(|| anyhow!("no beams survived"))?; - Ok((best.tokens.into_iter().map(|t| self.vocab[t].as_str()).join(""), decoder_stats, joint_stats)) - } } fn main() -> anyhow::Result<()> { + let args = Args::parse(); let nnef = tract_rs::nnef()?.with_tract_core()?.with_tract_transformers()?; let gpu = ["cuda", "metal", "default"] .iter() @@ -362,13 +109,13 @@ fn main() -> anyhow::Result<()> { .map(|x| x.unwrap() as f32) .collect(); - model.transcribe_greedy(&wav)?; - let (transcript_g, dec, joint) = model.transcribe_greedy(&wav)?; + greedy::transcribe_greedy(&model, &wav)?; + let (transcript_g, dec, joint) = greedy::transcribe_greedy(&model, &wav)?; eprintln!("[greedy][decoder] {dec:?}"); eprintln!("[greedy][joint] {joint:?}"); - model.transcribe_beam(&wav)?; - let (transcript_b, dec, joint) = model.transcribe_beam(&wav)?; + beam::transcribe_beam(&model, &wav, &args.beam)?; + let (transcript_b, dec, joint) = beam::transcribe_beam(&model, &wav, &args.beam)?; eprintln!("[beam][decoder] {dec:?}"); eprintln!("[beam][joint] {joint:?}"); From 756b1e4fb8f8b600ba6fb0ddeccf3b11e4afabf1 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 4 Mar 2026 10:45:06 +0000 Subject: [PATCH 07/20] nemo-parakeet-asr: add ALSD beam decoder --- examples/nemo-parakeet-asr/src/alsd.rs | 237 +++++++++++++++++++++++++ examples/nemo-parakeet-asr/src/main.rs | 13 ++ 2 files changed, 250 insertions(+) create mode 100644 examples/nemo-parakeet-asr/src/alsd.rs diff --git a/examples/nemo-parakeet-asr/src/alsd.rs b/examples/nemo-parakeet-asr/src/alsd.rs new file mode 100644 index 0000000000..972f430512 --- /dev/null +++ b/examples/nemo-parakeet-asr/src/alsd.rs @@ -0,0 +1,237 @@ +use std::time::Instant; + +use anyhow::*; +use clap::Args; +use float_ord::FloatOrd; +use itertools::Itertools; +use tract_rs::prelude::tract_ndarray::prelude::*; +use tract_rs::prelude::*; + +#[derive(Args, Clone)] +pub struct AlsdConfig { + /// Beam width for ALSD decoding + #[arg(long, default_value_t = 4)] + pub alsd_beam_size: usize, + + /// Duration candidates per hypothesis in ALSD + #[arg(long, default_value_t = 2)] + pub alsd_dur_beam_k: usize, + + /// Max non-blank tokens emitted per frame per hypothesis + #[arg(long, default_value_t = 10)] + pub alsd_max_symbols_per_frame: usize, +} + +struct AlsdHyp { + score: f32, + tokens: Vec, + current_frame: usize, + symbols_this_frame: usize, + dec_out: Value, + state_0: Value, + state_1: Value, +} + +pub fn transcribe_alsd( + model: &crate::TdtModel, + wav: &[f32], + cfg: &AlsdConfig, +) -> Result<(String, crate::CallStats, crate::CallStats)> { + let samples: Value = Value::from_slice(&[1, wav.len()], wav)?; + let len: Value = arr1(&[wav.len() as i64]).try_into()?; + + let [features, feat_len] = + model.preprocessor.run([samples, len])?.try_into().unwrap(); + let [encoded, _lens] = + model.encoder.run([features, feat_len])?.try_into().unwrap(); + + let encoded: ArrayD = encoded.view()?.into_owned(); + let batch = encoded.shape()[0]; + let max_frames = encoded.shape()[2]; + let enc_dim = encoded.shape()[1]; + + let mut decoder_stats = crate::CallStats::default(); + let mut joint_stats = crate::CallStats::default(); + + let init_token = Value::from_slice(&[1, 1], &[0i32])?; + let init_s0: Value = Array3::::zeros([2, 1, 640]).try_into()?; + let init_s1: Value = Array3::::zeros([2, 1, 640]).try_into()?; + let t = Instant::now(); + let [dec_out, state_0, state_1] = + model.decoder.run([init_token, init_s0, init_s1])?.try_into().unwrap(); + decoder_stats.record(batch, t.elapsed()); + + let mut beam: Vec = vec![AlsdHyp { + score: 0.0, + tokens: vec![], + current_frame: 0, + symbols_this_frame: 0, + dec_out, + state_0, + state_1, + }]; + + loop { + // Split into active (still have frames to consume) and completed + let mut active: Vec = Vec::new(); + let mut completed: Vec = Vec::new(); + for h in beam.drain(..) { + if h.current_frame < max_frames { + active.push(h); + } else { + completed.push(h); + } + } + + if active.is_empty() { + beam = completed; + break; + } + + let b = active.len(); + + // 1. JOINT: each hypothesis uses its own current_frame + let frame_batch: Value = { + let enc_arr = encoded.view(); + Array3::::from_shape_fn((b, enc_dim, 1), |(bi, e, _)| { + enc_arr[[0, e, active[bi].current_frame]] + }) + .try_into()? + }; + let dec_out_batch: Value = { + let views: Vec<_> = active + .iter() + .map(|h| h.dec_out.view::()) + .collect::>>()?; + let hidden = views[0].shape()[1]; // dec_out is [1, hidden, 1] + Array3::::from_shape_fn((b, hidden, 1), |(bi, h, _)| views[bi][[0, h, 0]]) + .try_into()? + }; + let t = Instant::now(); + let [logits_b] = + model.joint.run([frame_batch, dec_out_batch])?.try_into().unwrap(); + joint_stats.record(b, t.elapsed()); + + // 2. Per-hyp: blank+duration → next; non-blank → expansion list + let mut next: Vec = completed; + let mut per_hyp_token_scores: Vec> = Vec::with_capacity(b); + { + let logits_arr = logits_b.view::()?; // [b, vocab+dur] + for bi in 0..b { + let row = logits_arr.index_axis(Axis(0), bi); + let row_slice = row.as_slice().unwrap(); + let log_probs = crate::log_softmax(&row_slice[0..=model.blank_id]); + let dur_log_probs = crate::log_softmax(&row_slice[model.blank_id + 1..]); + + // Blank + duration expansions advance the frame and reset symbol count + let mut ds: Vec<(usize, f32)> = + (1..dur_log_probs.len()).map(|di| (di, dur_log_probs[di])).collect(); + ds.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); + ds.truncate(cfg.alsd_dur_beam_k); + for (di, dlp) in ds { + next.push(AlsdHyp { + score: active[bi].score + log_probs[model.blank_id] + dlp, + tokens: active[bi].tokens.clone(), + current_frame: active[bi].current_frame + di, + symbols_this_frame: 0, + dec_out: active[bi].dec_out.clone(), + state_0: active[bi].state_0.clone(), + state_1: active[bi].state_1.clone(), + }); + } + + // Non-blank expansions stay on the same frame (symbol count checked) + if active[bi].symbols_this_frame < cfg.alsd_max_symbols_per_frame { + let mut ts: Vec<(usize, f32)> = + (0..model.blank_id).map(|ti| (ti, log_probs[ti])).collect(); + ts.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); + ts.truncate(cfg.alsd_beam_size); + per_hyp_token_scores.push(ts); + } else { + per_hyp_token_scores.push(vec![]); + } + } + } // logits_arr dropped + + // 3. DECODER: single batched call over all non-blank expansions + let expansion_hyp_idxs: Vec = per_hyp_token_scores + .iter() + .enumerate() + .flat_map(|(bi, ts)| std::iter::repeat(bi).take(ts.len())) + .collect(); + let token_ids: Vec = per_hyp_token_scores + .iter() + .flat_map(|ts| ts.iter().map(|&(ti, _)| ti as i32)) + .collect(); + let n = token_ids.len(); + + if n > 0 { + let tokens_batch: Value = + Array2::::from_shape_fn((n, 1), |(i, _)| token_ids[i]).try_into()?; + let s0_batch: Value = { + let views: Vec<_> = active + .iter() + .map(|h| h.state_0.view::()) + .collect::>>()?; + Array3::::from_shape_fn((2, n, 640), |(l, i, h)| { + views[expansion_hyp_idxs[i]][[l, 0, h]] + }) + .try_into()? + }; + let s1_batch: Value = { + let views: Vec<_> = active + .iter() + .map(|h| h.state_1.view::()) + .collect::>>()?; + Array3::::from_shape_fn((2, n, 640), |(l, i, h)| { + views[expansion_hyp_idxs[i]][[l, 0, h]] + }) + .try_into()? + }; + + let t = Instant::now(); + let [dec_out_b, s0_b, s1_b] = + model.decoder.run([tokens_batch, s0_batch, s1_batch])?.try_into().unwrap(); + decoder_stats.record(n, t.elapsed()); + + // 4. Slice per-expansion outputs and push into next + let dec_arr = dec_out_b.view::()?; // [N, hidden, 1] + let s0_arr = s0_b.view::()?; // [2, N, 640] + let s1_arr = s1_b.view::()?; // [2, N, 640] + let mut i = 0; + for (bi, ts) in per_hyp_token_scores.iter().enumerate() { + for &(ti, lp) in ts { + let new_dec_out: Value = + dec_arr.slice_axis(Axis(0), (i..i + 1).into()).try_into()?; + let new_s0: Value = + s0_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; + let new_s1: Value = + s1_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; + let mut new_tokens = active[bi].tokens.clone(); + new_tokens.push(ti); + next.push(AlsdHyp { + score: active[bi].score + lp, + tokens: new_tokens, + current_frame: active[bi].current_frame, + symbols_this_frame: active[bi].symbols_this_frame + 1, + dec_out: new_dec_out, + state_0: new_s0, + state_1: new_s1, + }); + i += 1; + } + } + } + + // 5. Global prune + next.sort_by(|a, b| FloatOrd(b.score).cmp(&FloatOrd(a.score))); + next.truncate(cfg.alsd_beam_size); + beam = next; + } + + let best = beam + .into_iter() + .max_by_key(|b| FloatOrd(b.score)) + .ok_or_else(|| anyhow!("no beams survived"))?; + Ok((best.tokens.into_iter().map(|t| model.vocab[t].as_str()).join(""), decoder_stats, joint_stats)) +} diff --git a/examples/nemo-parakeet-asr/src/main.rs b/examples/nemo-parakeet-asr/src/main.rs index e7a41dbab1..2cd784f7f6 100644 --- a/examples/nemo-parakeet-asr/src/main.rs +++ b/examples/nemo-parakeet-asr/src/main.rs @@ -10,12 +10,15 @@ use tract_rs::Nnef; mod greedy; mod beam; +mod alsd; #[derive(Parser)] #[command(about = "NeMo Parakeet ASR inference")] struct Args { #[command(flatten)] beam: beam::BeamConfig, + #[command(flatten)] + alsd: alsd::AlsdConfig, } #[derive(Default)] @@ -119,8 +122,14 @@ fn main() -> anyhow::Result<()> { eprintln!("[beam][decoder] {dec:?}"); eprintln!("[beam][joint] {joint:?}"); + alsd::transcribe_alsd(&model, &wav, &args.alsd)?; + let (transcript_a, dec, joint) = alsd::transcribe_alsd(&model, &wav, &args.alsd)?; + eprintln!("[alsd][decoder] {dec:?}"); + eprintln!("[alsd][joint] {joint:?}"); + println!("Greedy: {transcript_g}"); println!("Beam: {transcript_b}"); + println!("ALSD: {transcript_a}"); assert_eq!( transcript_g, "▁Well,▁I▁don't▁wish▁to▁see▁it▁any▁more,▁observed▁Phoebe,▁turning▁away▁her▁eyes." @@ -129,5 +138,9 @@ fn main() -> anyhow::Result<()> { transcript_b, "▁Well,▁I▁don't▁wish▁to▁see▁it▁any▁more,▁observed▁Phoebe,▁turning▁away▁her▁eyes." ); + assert_eq!( + transcript_a, + "▁Well,▁I▁don't▁wish▁to▁see▁it▁any▁more,▁observed▁Phoebe,▁turning▁away▁her▁eyes." + ); Ok(()) } From 4c81bd2d73f6007f32de25caee41094023ebef69 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 4 Mar 2026 10:45:06 +0000 Subject: [PATCH 08/20] nemo-parakeet-asr: fuse duplicate hypotheses at pruning step in beam and ALSD MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit At each prune, sort by score then drop any candidate whose key was already seen — guaranteeing the survivor is always the best-scoring one. Key is (tokens, last_frame) for beam and (tokens, current_frame, symbols_this_frame) for ALSD, where the extra field is needed because hypotheses at the same frame with different symbol counts follow different future paths. --- examples/nemo-parakeet-asr/src/alsd.rs | 6 +++++- examples/nemo-parakeet-asr/src/beam.rs | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/nemo-parakeet-asr/src/alsd.rs b/examples/nemo-parakeet-asr/src/alsd.rs index 972f430512..9bad1a092d 100644 --- a/examples/nemo-parakeet-asr/src/alsd.rs +++ b/examples/nemo-parakeet-asr/src/alsd.rs @@ -223,8 +223,12 @@ pub fn transcribe_alsd( } } - // 5. Global prune + // 5. Fuse duplicates then prune to alsd_beam_size. + // Key is (tokens, current_frame, symbols_this_frame): same tokens → same decoder + // state; same frame → same encoder embedding; same symbol count → same gate behavior. next.sort_by(|a, b| FloatOrd(b.score).cmp(&FloatOrd(a.score))); + let mut seen = std::collections::HashSet::<(Vec, usize, usize)>::new(); + next.retain(|h| seen.insert((h.tokens.clone(), h.current_frame, h.symbols_this_frame))); next.truncate(cfg.alsd_beam_size); beam = next; } diff --git a/examples/nemo-parakeet-asr/src/beam.rs b/examples/nemo-parakeet-asr/src/beam.rs index edaae411b2..aece428b46 100644 --- a/examples/nemo-parakeet-asr/src/beam.rs +++ b/examples/nemo-parakeet-asr/src/beam.rs @@ -206,9 +206,13 @@ pub fn transcribe_beam( }; hyps = new_hyps; - // 5. Prune combined pool to BEAM_SIZE + // 5. Fuse duplicates then prune to BEAM_SIZE. + // After sorting by score, the first occurrence of each (tokens, last_frame) + // key is the best one; retain only that first occurrence. let mut all: Vec = hyps.drain(..).chain(kept.drain(..)).collect(); all.sort_by(|a, b| FloatOrd(b.score).cmp(&FloatOrd(a.score))); + let mut seen = std::collections::HashSet::<(Vec, usize)>::new(); + all.retain(|h| seen.insert((h.tokens.clone(), h.last_frame))); all.truncate(cfg.beam_size); for b in all { if b.last_frame == frame_ix { From 8febcd7fc8e55bd31c4747612cc97bff205d0f1c Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 4 Mar 2026 10:45:06 +0000 Subject: [PATCH 09/20] nemo-parakeet-asr: consolidate stats into DecodingStats, add preprocessor/encoder timing, nn/host split Replace the pair of CallStats return values with a single DecodingStats that covers all four stages (preprocessor, encoder, decoder, joint). The summary line now shows total elapsed, RTFx, nn time (sum of all model calls), and host time (elapsed - nn: search logic, batching, tensor prep). --- examples/nemo-parakeet-asr/src/alsd.rs | 20 ++++--- examples/nemo-parakeet-asr/src/beam.rs | 20 ++++--- examples/nemo-parakeet-asr/src/greedy.rs | 64 +++++++++++++++++++++++ examples/nemo-parakeet-asr/src/main.rs | 66 ++++++++++++++++++++---- 4 files changed, 143 insertions(+), 27 deletions(-) create mode 100644 examples/nemo-parakeet-asr/src/greedy.rs diff --git a/examples/nemo-parakeet-asr/src/alsd.rs b/examples/nemo-parakeet-asr/src/alsd.rs index 9bad1a092d..3dd012ddbe 100644 --- a/examples/nemo-parakeet-asr/src/alsd.rs +++ b/examples/nemo-parakeet-asr/src/alsd.rs @@ -36,30 +36,34 @@ pub fn transcribe_alsd( model: &crate::TdtModel, wav: &[f32], cfg: &AlsdConfig, -) -> Result<(String, crate::CallStats, crate::CallStats)> { +) -> Result<(String, crate::DecodingStats)> { + let mut stats = crate::DecodingStats::default(); + let samples: Value = Value::from_slice(&[1, wav.len()], wav)?; let len: Value = arr1(&[wav.len() as i64]).try_into()?; + let t = Instant::now(); let [features, feat_len] = model.preprocessor.run([samples, len])?.try_into().unwrap(); + stats.preprocessor.record(1, t.elapsed()); + + let t = Instant::now(); let [encoded, _lens] = model.encoder.run([features, feat_len])?.try_into().unwrap(); + stats.encoder.record(1, t.elapsed()); let encoded: ArrayD = encoded.view()?.into_owned(); let batch = encoded.shape()[0]; let max_frames = encoded.shape()[2]; let enc_dim = encoded.shape()[1]; - let mut decoder_stats = crate::CallStats::default(); - let mut joint_stats = crate::CallStats::default(); - let init_token = Value::from_slice(&[1, 1], &[0i32])?; let init_s0: Value = Array3::::zeros([2, 1, 640]).try_into()?; let init_s1: Value = Array3::::zeros([2, 1, 640]).try_into()?; let t = Instant::now(); let [dec_out, state_0, state_1] = model.decoder.run([init_token, init_s0, init_s1])?.try_into().unwrap(); - decoder_stats.record(batch, t.elapsed()); + stats.decoder.record(batch, t.elapsed()); let mut beam: Vec = vec![AlsdHyp { score: 0.0, @@ -110,7 +114,7 @@ pub fn transcribe_alsd( let t = Instant::now(); let [logits_b] = model.joint.run([frame_batch, dec_out_batch])?.try_into().unwrap(); - joint_stats.record(b, t.elapsed()); + stats.joint.record(b, t.elapsed()); // 2. Per-hyp: blank+duration → next; non-blank → expansion list let mut next: Vec = completed; @@ -192,7 +196,7 @@ pub fn transcribe_alsd( let t = Instant::now(); let [dec_out_b, s0_b, s1_b] = model.decoder.run([tokens_batch, s0_batch, s1_batch])?.try_into().unwrap(); - decoder_stats.record(n, t.elapsed()); + stats.decoder.record(n, t.elapsed()); // 4. Slice per-expansion outputs and push into next let dec_arr = dec_out_b.view::()?; // [N, hidden, 1] @@ -237,5 +241,5 @@ pub fn transcribe_alsd( .into_iter() .max_by_key(|b| FloatOrd(b.score)) .ok_or_else(|| anyhow!("no beams survived"))?; - Ok((best.tokens.into_iter().map(|t| model.vocab[t].as_str()).join(""), decoder_stats, joint_stats)) + Ok((best.tokens.into_iter().map(|t| model.vocab[t].as_str()).join(""), stats)) } diff --git a/examples/nemo-parakeet-asr/src/beam.rs b/examples/nemo-parakeet-asr/src/beam.rs index aece428b46..219b9703f3 100644 --- a/examples/nemo-parakeet-asr/src/beam.rs +++ b/examples/nemo-parakeet-asr/src/beam.rs @@ -31,30 +31,34 @@ pub fn transcribe_beam( model: &crate::TdtModel, wav: &[f32], cfg: &BeamConfig, -) -> Result<(String, crate::CallStats, crate::CallStats)> { +) -> Result<(String, crate::DecodingStats)> { + let mut stats = crate::DecodingStats::default(); + let samples: Value = Value::from_slice(&[1, wav.len()], wav)?; let len: Value = arr1(&[wav.len() as i64]).try_into()?; + let t = Instant::now(); let [features, feat_len] = model.preprocessor.run([samples, len])?.try_into().unwrap(); + stats.preprocessor.record(1, t.elapsed()); + + let t = Instant::now(); let [encoded, _lens] = model.encoder.run([features, feat_len])?.try_into().unwrap(); + stats.encoder.record(1, t.elapsed()); let encoded: ArrayD = encoded.view()?.into_owned(); let batch = encoded.shape()[0]; let max_frames = encoded.shape()[2]; let enc_dim = encoded.shape()[1]; - let mut decoder_stats = crate::CallStats::default(); - let mut joint_stats = crate::CallStats::default(); - let init_token = Value::from_slice(&[1, 1], &[0i32])?; let init_s0: Value = Array3::::zeros([2, 1, 640]).try_into()?; let init_s1: Value = Array3::::zeros([2, 1, 640]).try_into()?; let t = Instant::now(); let [dec_out, state_0, state_1] = model.decoder.run([init_token, init_s0, init_s1])?.try_into().unwrap(); - decoder_stats.record(batch, t.elapsed()); + stats.decoder.record(batch, t.elapsed()); let mut all_beams: Vec = vec![Beam { score: 0.0, @@ -99,7 +103,7 @@ pub fn transcribe_beam( let t = Instant::now(); let [logits_b] = model.joint.run([frame_batch, dec_out_batch])?.try_into().unwrap(); - joint_stats.record(b, t.elapsed()); + stats.joint.record(b, t.elapsed()); // 2. Per-hyp: token scores + duration expansions into kept let mut per_hyp_token_scores: Vec> = Vec::with_capacity(b); @@ -172,7 +176,7 @@ pub fn transcribe_beam( let t = Instant::now(); let [dec_out_b, s0_b, s1_b] = model.decoder.run([tokens_batch, s0_batch, s1_batch])?.try_into().unwrap(); - decoder_stats.record(n, t.elapsed()); + stats.decoder.record(n, t.elapsed()); // 4. Slice and build new beams let new_hyps: Vec = { @@ -230,5 +234,5 @@ pub fn transcribe_beam( .into_iter() .max_by_key(|b| FloatOrd(b.score)) .ok_or_else(|| anyhow!("no beams survived"))?; - Ok((best.tokens.into_iter().map(|t| model.vocab[t].as_str()).join(""), decoder_stats, joint_stats)) + Ok((best.tokens.into_iter().map(|t| model.vocab[t].as_str()).join(""), stats)) } diff --git a/examples/nemo-parakeet-asr/src/greedy.rs b/examples/nemo-parakeet-asr/src/greedy.rs new file mode 100644 index 0000000000..2e1521227a --- /dev/null +++ b/examples/nemo-parakeet-asr/src/greedy.rs @@ -0,0 +1,64 @@ +use std::time::Instant; + +use anyhow::*; +use itertools::Itertools; +use tract_rs::prelude::tract_ndarray::prelude::*; +use tract_rs::prelude::*; + +pub fn transcribe_greedy( + model: &crate::TdtModel, + wav: &[f32], +) -> Result<(String, crate::DecodingStats)> { + let mut stats = crate::DecodingStats::default(); + + let samples: Value = Value::from_slice(&[1, wav.len()], wav)?; + let len: Value = arr1(&[wav.len() as i64]).try_into()?; + + let t = Instant::now(); + let [features, feat_len] = + model.preprocessor.run([samples, len])?.try_into().unwrap(); + stats.preprocessor.record(1, t.elapsed()); + + let t = Instant::now(); + let [encoded, _lens] = + model.encoder.run([features, feat_len])?.try_into().unwrap(); + stats.encoder.record(1, t.elapsed()); + + let encoded: ArrayD = encoded.view()?.into_owned(); + let max_frames = encoded.shape()[2]; + let max_len = max_frames * 6 + 10; + + let mut hyp = vec![]; + let mut frame_ix = 0; + let mut token = Value::from_slice(&[1, 1], &[0i32])?; + let mut state_0: Value = Array3::::zeros([2, 1, 640]).try_into()?; + let mut state_1: Value = Array3::::zeros([2, 1, 640]).try_into()?; + + let t = Instant::now(); + [token, state_0, state_1] = + model.decoder.run([token, state_0, state_1])?.try_into().unwrap(); + stats.decoder.record(1, t.elapsed()); + + while hyp.len() < max_len && frame_ix < max_frames { + let frame: Value = + encoded.slice_axis(Axis(2), (frame_ix..frame_ix + 1).into()).try_into()?; + let t = Instant::now(); + let [logits] = model.joint.run([frame, token.clone()])?.try_into().unwrap(); + stats.joint.record(1, t.elapsed()); + let logits = logits.view::()?; + let logits = logits.as_slice().unwrap(); + let token_id = crate::argmax(&logits[0..model.blank_id + 1]).unwrap(); + if token_id == model.blank_id { + frame_ix += crate::argmax(&logits[model.blank_id + 1..]).unwrap_or(0).max(1); + } else { + hyp.push(token_id); + token = Value::from_slice(&[1, 1], &[token_id as i32])?; + let t = Instant::now(); + [token, state_0, state_1] = + model.decoder.run([token, state_0, state_1])?.try_into().unwrap(); + stats.decoder.record(1, t.elapsed()); + } + } + + Ok((hyp.into_iter().map(|t| model.vocab[t].as_str()).join(""), stats)) +} diff --git a/examples/nemo-parakeet-asr/src/main.rs b/examples/nemo-parakeet-asr/src/main.rs index 2cd784f7f6..109ebc15aa 100644 --- a/examples/nemo-parakeet-asr/src/main.rs +++ b/examples/nemo-parakeet-asr/src/main.rs @@ -34,6 +34,10 @@ impl CallStats { self.total_batch += batch as u64; self.total_us += elapsed.as_micros() as u64; } + + pub(crate) fn total_ms(&self) -> f64 { + self.total_us as f64 / 1000.0 + } } impl std::fmt::Debug for CallStats { @@ -51,6 +55,23 @@ impl std::fmt::Debug for CallStats { } } +#[derive(Default)] +pub(crate) struct DecodingStats { + pub(crate) preprocessor: CallStats, + pub(crate) encoder: CallStats, + pub(crate) decoder: CallStats, + pub(crate) joint: CallStats, +} + +impl DecodingStats { + pub(crate) fn nn_ms(&self) -> f64 { + self.preprocessor.total_ms() + + self.encoder.total_ms() + + self.decoder.total_ms() + + self.joint.total_ms() + } +} + pub(crate) fn argmax(slice: &[f32]) -> Option { slice.into_iter().position_max_by_key(|x| FloatOrd(**x)) } @@ -107,25 +128,48 @@ fn main() -> anyhow::Result<()> { let model = TdtModel::load(Path::new("assets/model"), &nnef, &gpu)?; - let wav: Vec = hound::WavReader::open("assets/2086-149220-0033.wav")? - .samples::() + let mut wav_reader = hound::WavReader::open("assets/2086-149220-0033.wav")?; + let sample_rate = wav_reader.spec().sample_rate as f64; + let wav: Vec = wav_reader.samples::() .map(|x| x.unwrap() as f32) .collect(); + let audio_s = wav.len() as f64 / sample_rate; greedy::transcribe_greedy(&model, &wav)?; - let (transcript_g, dec, joint) = greedy::transcribe_greedy(&model, &wav)?; - eprintln!("[greedy][decoder] {dec:?}"); - eprintln!("[greedy][joint] {joint:?}"); + let t = std::time::Instant::now(); + let (transcript_g, stats) = greedy::transcribe_greedy(&model, &wav)?; + let elapsed = t.elapsed().as_secs_f64(); + let elapsed_ms = elapsed * 1000.0; + let nn_ms = stats.nn_ms(); + eprintln!("[greedy] {elapsed_ms:.1}ms RTFx={:.1} nn={nn_ms:.1}ms host={:.1}ms", audio_s / elapsed, elapsed_ms - nn_ms); + eprintln!(" [preprocessor] {:?}", stats.preprocessor); + eprintln!(" [encoder] {:?}", stats.encoder); + eprintln!(" [decoder] {:?}", stats.decoder); + eprintln!(" [joint] {:?}", stats.joint); beam::transcribe_beam(&model, &wav, &args.beam)?; - let (transcript_b, dec, joint) = beam::transcribe_beam(&model, &wav, &args.beam)?; - eprintln!("[beam][decoder] {dec:?}"); - eprintln!("[beam][joint] {joint:?}"); + let t = std::time::Instant::now(); + let (transcript_b, stats) = beam::transcribe_beam(&model, &wav, &args.beam)?; + let elapsed = t.elapsed().as_secs_f64(); + let elapsed_ms = elapsed * 1000.0; + let nn_ms = stats.nn_ms(); + eprintln!("[beam] {elapsed_ms:.1}ms RTFx={:.1} nn={nn_ms:.1}ms host={:.1}ms", audio_s / elapsed, elapsed_ms - nn_ms); + eprintln!(" [preprocessor] {:?}", stats.preprocessor); + eprintln!(" [encoder] {:?}", stats.encoder); + eprintln!(" [decoder] {:?}", stats.decoder); + eprintln!(" [joint] {:?}", stats.joint); alsd::transcribe_alsd(&model, &wav, &args.alsd)?; - let (transcript_a, dec, joint) = alsd::transcribe_alsd(&model, &wav, &args.alsd)?; - eprintln!("[alsd][decoder] {dec:?}"); - eprintln!("[alsd][joint] {joint:?}"); + let t = std::time::Instant::now(); + let (transcript_a, stats) = alsd::transcribe_alsd(&model, &wav, &args.alsd)?; + let elapsed = t.elapsed().as_secs_f64(); + let elapsed_ms = elapsed * 1000.0; + let nn_ms = stats.nn_ms(); + eprintln!("[alsd] {elapsed_ms:.1}ms RTFx={:.1} nn={nn_ms:.1}ms host={:.1}ms", audio_s / elapsed, elapsed_ms - nn_ms); + eprintln!(" [preprocessor] {:?}", stats.preprocessor); + eprintln!(" [encoder] {:?}", stats.encoder); + eprintln!(" [decoder] {:?}", stats.decoder); + eprintln!(" [joint] {:?}", stats.joint); println!("Greedy: {transcript_g}"); println!("Beam: {transcript_b}"); From b51512dbff6b713ef3e99212d6b9ca0cbf60014c Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 4 Mar 2026 10:45:06 +0000 Subject: [PATCH 10/20] nemo-parakeet-asr: multi-file CLI with reference beam and tick/cross output MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Accept positional wav paths/dirs on the CLI (dirs recurse, sorted) - Run a high-quality reference beam (beam_size=10, dur_beam_k=5) per file as silent ground truth; show transcript + duration on the file header line - Print green ✓ / red ✗ (with ref/got lines) instead of assert_eq! - Indent greedy/beam/alsd blocks under the file name - Clean SentencePiece ▁ markers in displayed transcripts - Warmup all four decoders on the first file before timing --- examples/nemo-parakeet-asr/src/main.rs | 170 ++++++++++++++++--------- 1 file changed, 111 insertions(+), 59 deletions(-) diff --git a/examples/nemo-parakeet-asr/src/main.rs b/examples/nemo-parakeet-asr/src/main.rs index 109ebc15aa..ef245854f5 100644 --- a/examples/nemo-parakeet-asr/src/main.rs +++ b/examples/nemo-parakeet-asr/src/main.rs @@ -1,5 +1,6 @@ use std::fs::File; -use std::path::Path; +use std::path::{Path, PathBuf}; +use std::time::Instant; use anyhow::*; use clap::Parser; @@ -19,6 +20,8 @@ struct Args { beam: beam::BeamConfig, #[command(flatten)] alsd: alsd::AlsdConfig, + #[arg(required = true)] + inputs: Vec, } #[derive(Default)] @@ -118,6 +121,72 @@ impl TdtModel { } } +fn collect_wavs_from_dir(dir: &Path) -> Vec { + let mut results = Vec::new(); + if let Some(entries) = std::fs::read_dir(dir).ok() { + for entry in entries.flatten() { + let path: PathBuf = entry.path(); + if path.is_dir() { + results.extend(collect_wavs_from_dir(&path)); + } else if path.extension().and_then(|e: &std::ffi::OsStr| e.to_str()) == Some("wav") { + results.push(path); + } + } + } + results +} + +fn collect_wavs(inputs: &[PathBuf]) -> Vec { + let mut results = Vec::new(); + for input in inputs { + if input.is_dir() { + results.extend(collect_wavs_from_dir(input)); + } else { + results.push(input.clone()); + } + } + results.sort(); + results +} + +fn load_wav(path: &Path) -> Result<(Vec, f64)> { + let mut wav_reader = hound::WavReader::open(path)?; + let sample_rate = wav_reader.spec().sample_rate as f64; + let samples: Vec = wav_reader.samples::() + .map(|x| x.unwrap() as f32) + .collect(); + let audio_duration_s = samples.len() as f64 / sample_rate; + Ok((samples, audio_duration_s)) +} + +fn clean(s: &str) -> String { + s.replace('▁', " ").trim_start().to_owned() +} + +fn print_result( + label: &str, + transcript: &str, + reference: &str, + elapsed_ms: f64, + rtfx: f64, + nn_ms: f64, + stats: &DecodingStats, +) { + let mark = if transcript == reference { "\x1b[32m✓\x1b[0m" } else { "\x1b[31m✗\x1b[0m" }; + eprintln!( + " [{label}] {elapsed_ms:.1}ms RTFx={rtfx:.1} nn={nn_ms:.1}ms host={:.1}ms {mark}", + elapsed_ms - nn_ms + ); + if transcript != reference { + eprintln!(" ref: {}", clean(reference)); + eprintln!(" got: {}", clean(transcript)); + } + eprintln!(" [preprocessor] {:?}", stats.preprocessor); + eprintln!(" [encoder] {:?}", stats.encoder); + eprintln!(" [decoder] {:?}", stats.decoder); + eprintln!(" [joint] {:?}", stats.joint); +} + fn main() -> anyhow::Result<()> { let args = Args::parse(); let nnef = tract_rs::nnef()?.with_tract_core()?.with_tract_transformers()?; @@ -128,63 +197,46 @@ fn main() -> anyhow::Result<()> { let model = TdtModel::load(Path::new("assets/model"), &nnef, &gpu)?; - let mut wav_reader = hound::WavReader::open("assets/2086-149220-0033.wav")?; - let sample_rate = wav_reader.spec().sample_rate as f64; - let wav: Vec = wav_reader.samples::() - .map(|x| x.unwrap() as f32) - .collect(); - let audio_s = wav.len() as f64 / sample_rate; - - greedy::transcribe_greedy(&model, &wav)?; - let t = std::time::Instant::now(); - let (transcript_g, stats) = greedy::transcribe_greedy(&model, &wav)?; - let elapsed = t.elapsed().as_secs_f64(); - let elapsed_ms = elapsed * 1000.0; - let nn_ms = stats.nn_ms(); - eprintln!("[greedy] {elapsed_ms:.1}ms RTFx={:.1} nn={nn_ms:.1}ms host={:.1}ms", audio_s / elapsed, elapsed_ms - nn_ms); - eprintln!(" [preprocessor] {:?}", stats.preprocessor); - eprintln!(" [encoder] {:?}", stats.encoder); - eprintln!(" [decoder] {:?}", stats.decoder); - eprintln!(" [joint] {:?}", stats.joint); - - beam::transcribe_beam(&model, &wav, &args.beam)?; - let t = std::time::Instant::now(); - let (transcript_b, stats) = beam::transcribe_beam(&model, &wav, &args.beam)?; - let elapsed = t.elapsed().as_secs_f64(); - let elapsed_ms = elapsed * 1000.0; - let nn_ms = stats.nn_ms(); - eprintln!("[beam] {elapsed_ms:.1}ms RTFx={:.1} nn={nn_ms:.1}ms host={:.1}ms", audio_s / elapsed, elapsed_ms - nn_ms); - eprintln!(" [preprocessor] {:?}", stats.preprocessor); - eprintln!(" [encoder] {:?}", stats.encoder); - eprintln!(" [decoder] {:?}", stats.decoder); - eprintln!(" [joint] {:?}", stats.joint); - - alsd::transcribe_alsd(&model, &wav, &args.alsd)?; - let t = std::time::Instant::now(); - let (transcript_a, stats) = alsd::transcribe_alsd(&model, &wav, &args.alsd)?; - let elapsed = t.elapsed().as_secs_f64(); - let elapsed_ms = elapsed * 1000.0; - let nn_ms = stats.nn_ms(); - eprintln!("[alsd] {elapsed_ms:.1}ms RTFx={:.1} nn={nn_ms:.1}ms host={:.1}ms", audio_s / elapsed, elapsed_ms - nn_ms); - eprintln!(" [preprocessor] {:?}", stats.preprocessor); - eprintln!(" [encoder] {:?}", stats.encoder); - eprintln!(" [decoder] {:?}", stats.decoder); - eprintln!(" [joint] {:?}", stats.joint); - - println!("Greedy: {transcript_g}"); - println!("Beam: {transcript_b}"); - println!("ALSD: {transcript_a}"); - assert_eq!( - transcript_g, - "▁Well,▁I▁don't▁wish▁to▁see▁it▁any▁more,▁observed▁Phoebe,▁turning▁away▁her▁eyes." - ); - assert_eq!( - transcript_b, - "▁Well,▁I▁don't▁wish▁to▁see▁it▁any▁more,▁observed▁Phoebe,▁turning▁away▁her▁eyes." - ); - assert_eq!( - transcript_a, - "▁Well,▁I▁don't▁wish▁to▁see▁it▁any▁more,▁observed▁Phoebe,▁turning▁away▁her▁eyes." - ); + let gt_cfg = beam::BeamConfig { beam_size: 10, dur_beam_k: 5 }; + let wavs = collect_wavs(&args.inputs); + + // Warmup on first file (all 4 decoders) + if let Some(first) = wavs.first() { + let (wav, _) = load_wav(first)?; + greedy::transcribe_greedy(&model, &wav)?; + beam::transcribe_beam(&model, &wav, &args.beam)?; + alsd::transcribe_alsd(&model, &wav, &args.alsd)?; + beam::transcribe_beam(&model, &wav, >_cfg)?; + } + + for wav_path in &wavs { + let (wav, audio_s) = load_wav(wav_path)?; + + // Reference — high-quality beam as ground truth (silent) + let (reference, _) = beam::transcribe_beam(&model, &wav, >_cfg)?; + eprintln!("{} ({audio_s:.1}s) {}", wav_path.display(), clean(&reference)); + + // Greedy + let t = Instant::now(); + let (transcript, stats) = greedy::transcribe_greedy(&model, &wav)?; + let elapsed = t.elapsed().as_secs_f64(); + let elapsed_ms = elapsed * 1000.0; + print_result("greedy", &transcript, &reference, elapsed_ms, audio_s / elapsed, stats.nn_ms(), &stats); + + // Beam + let t = Instant::now(); + let (transcript, stats) = beam::transcribe_beam(&model, &wav, &args.beam)?; + let elapsed = t.elapsed().as_secs_f64(); + let elapsed_ms = elapsed * 1000.0; + print_result("beam", &transcript, &reference, elapsed_ms, audio_s / elapsed, stats.nn_ms(), &stats); + + // ALSD + let t = Instant::now(); + let (transcript, stats) = alsd::transcribe_alsd(&model, &wav, &args.alsd)?; + let elapsed = t.elapsed().as_secs_f64(); + let elapsed_ms = elapsed * 1000.0; + print_result("alsd", &transcript, &reference, elapsed_ms, audio_s / elapsed, stats.nn_ms(), &stats); + } + Ok(()) } From 0d06180022eca6ab4a555c0356caa4a67eae82d4 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 4 Mar 2026 10:45:06 +0000 Subject: [PATCH 11/20] nemo-parakeet-asr: single-decoder CLI with summary, progress bar, and display options MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Select algorithm via --decoder greedy|beam|alsd (default: greedy) - Run only reference + one decoder per file; accumulate RTFx and exact-match count - Per-file line: filename, ✓/✗, signal duration, decoding time, RTFx, transcript - --stats: show per-sub-model timing breakdown under each file - --no-details: suppress per-file lines and header; show a progress bar with ETA - Summary line (always shown): algo+params, N/total exact, overall RTFx - Rename dur-beam-k -> beam-dur-k (and alsd equivalent) for consistent prefix nesting - Add progress_bar 1.4.0 dependency --- examples/nemo-parakeet-asr/Cargo.toml | 1 + examples/nemo-parakeet-asr/src/alsd.rs | 4 +- examples/nemo-parakeet-asr/src/beam.rs | 4 +- examples/nemo-parakeet-asr/src/main.rs | 115 +++++++++++++++---------- 4 files changed, 75 insertions(+), 49 deletions(-) diff --git a/examples/nemo-parakeet-asr/Cargo.toml b/examples/nemo-parakeet-asr/Cargo.toml index 9588ba1b90..a8df64c339 100644 --- a/examples/nemo-parakeet-asr/Cargo.toml +++ b/examples/nemo-parakeet-asr/Cargo.toml @@ -8,6 +8,7 @@ anyhow.workspace = true clap = { version = "4", features = ["derive"] } float-ord.workspace = true hound = "3.5.1" +progress_bar = "1.4.0" itertools.workspace = true serde_json.workspace = true tract-rs.workspace = true diff --git a/examples/nemo-parakeet-asr/src/alsd.rs b/examples/nemo-parakeet-asr/src/alsd.rs index 3dd012ddbe..f9853f3060 100644 --- a/examples/nemo-parakeet-asr/src/alsd.rs +++ b/examples/nemo-parakeet-asr/src/alsd.rs @@ -15,7 +15,7 @@ pub struct AlsdConfig { /// Duration candidates per hypothesis in ALSD #[arg(long, default_value_t = 2)] - pub alsd_dur_beam_k: usize, + pub alsd_beam_dur_k: usize, /// Max non-blank tokens emitted per frame per hypothesis #[arg(long, default_value_t = 10)] @@ -131,7 +131,7 @@ pub fn transcribe_alsd( let mut ds: Vec<(usize, f32)> = (1..dur_log_probs.len()).map(|di| (di, dur_log_probs[di])).collect(); ds.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); - ds.truncate(cfg.alsd_dur_beam_k); + ds.truncate(cfg.alsd_beam_dur_k); for (di, dlp) in ds { next.push(AlsdHyp { score: active[bi].score + log_probs[model.blank_id] + dlp, diff --git a/examples/nemo-parakeet-asr/src/beam.rs b/examples/nemo-parakeet-asr/src/beam.rs index 219b9703f3..2be4b4e0ee 100644 --- a/examples/nemo-parakeet-asr/src/beam.rs +++ b/examples/nemo-parakeet-asr/src/beam.rs @@ -15,7 +15,7 @@ pub struct BeamConfig { /// Number of duration candidates to expand per hypothesis #[arg(long, default_value_t = 2)] - pub dur_beam_k: usize, + pub beam_dur_k: usize, } struct Beam { @@ -124,7 +124,7 @@ pub fn transcribe_beam( let mut ds: Vec<(usize, f32)> = (1..dur_log_probs.len()).map(|di| (di, dur_log_probs[di])).collect(); ds.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); - ds.truncate(cfg.dur_beam_k); + ds.truncate(cfg.beam_dur_k); for (di, dlp) in ds { kept.push(Beam { score: hyps[bi].score + log_probs[model.blank_id] + dlp, diff --git a/examples/nemo-parakeet-asr/src/main.rs b/examples/nemo-parakeet-asr/src/main.rs index ef245854f5..72b9eed40d 100644 --- a/examples/nemo-parakeet-asr/src/main.rs +++ b/examples/nemo-parakeet-asr/src/main.rs @@ -4,6 +4,7 @@ use std::time::Instant; use anyhow::*; use clap::Parser; +use progress_bar::*; use float_ord::FloatOrd; use itertools::Itertools; use tract_rs::prelude::*; @@ -13,6 +14,9 @@ mod greedy; mod beam; mod alsd; +#[derive(clap::ValueEnum, Clone)] +enum Decoder { Greedy, Beam, Alsd } + #[derive(Parser)] #[command(about = "NeMo Parakeet ASR inference")] struct Args { @@ -20,6 +24,12 @@ struct Args { beam: beam::BeamConfig, #[command(flatten)] alsd: alsd::AlsdConfig, + #[arg(long, value_enum, default_value_t = Decoder::Greedy)] + decoder: Decoder, + #[arg(long)] + stats: bool, + #[arg(long)] + no_details: bool, #[arg(required = true)] inputs: Vec, } @@ -163,28 +173,20 @@ fn clean(s: &str) -> String { s.replace('▁', " ").trim_start().to_owned() } -fn print_result( - label: &str, - transcript: &str, - reference: &str, - elapsed_ms: f64, - rtfx: f64, - nn_ms: f64, - stats: &DecodingStats, -) { - let mark = if transcript == reference { "\x1b[32m✓\x1b[0m" } else { "\x1b[31m✗\x1b[0m" }; - eprintln!( - " [{label}] {elapsed_ms:.1}ms RTFx={rtfx:.1} nn={nn_ms:.1}ms host={:.1}ms {mark}", - elapsed_ms - nn_ms - ); - if transcript != reference { - eprintln!(" ref: {}", clean(reference)); - eprintln!(" got: {}", clean(transcript)); +fn decoder_label(args: &Args) -> String { + match args.decoder { + Decoder::Greedy => "greedy".to_string(), + Decoder::Beam => format!("beam beam_size={} beam_dur_k={}", args.beam.beam_size, args.beam.beam_dur_k), + Decoder::Alsd => format!("alsd beam_size={} beam_dur_k={} max_symbols={}", args.alsd.alsd_beam_size, args.alsd.alsd_beam_dur_k, args.alsd.alsd_max_symbols_per_frame), + } +} + +fn run_decoder(model: &TdtModel, wav: &[f32], args: &Args) -> Result<(String, DecodingStats)> { + match args.decoder { + Decoder::Greedy => greedy::transcribe_greedy(model, wav), + Decoder::Beam => beam::transcribe_beam(model, wav, &args.beam), + Decoder::Alsd => alsd::transcribe_alsd(model, wav, &args.alsd), } - eprintln!(" [preprocessor] {:?}", stats.preprocessor); - eprintln!(" [encoder] {:?}", stats.encoder); - eprintln!(" [decoder] {:?}", stats.decoder); - eprintln!(" [joint] {:?}", stats.joint); } fn main() -> anyhow::Result<()> { @@ -197,46 +199,69 @@ fn main() -> anyhow::Result<()> { let model = TdtModel::load(Path::new("assets/model"), &nnef, &gpu)?; - let gt_cfg = beam::BeamConfig { beam_size: 10, dur_beam_k: 5 }; + let gt_cfg = beam::BeamConfig { beam_size: 10, beam_dur_k: 5 }; let wavs = collect_wavs(&args.inputs); - // Warmup on first file (all 4 decoders) + let label = decoder_label(&args); + if !args.no_details { + eprintln!("{} files={}", label, wavs.len()); + } else { + init_progress_bar_with_eta(wavs.len()); + set_progress_bar_action("Decoding", Color::Blue, Style::Bold); + } + + // Warmup on first file if let Some(first) = wavs.first() { let (wav, _) = load_wav(first)?; - greedy::transcribe_greedy(&model, &wav)?; - beam::transcribe_beam(&model, &wav, &args.beam)?; - alsd::transcribe_alsd(&model, &wav, &args.alsd)?; beam::transcribe_beam(&model, &wav, >_cfg)?; + run_decoder(&model, &wav, &args)?; } + let mut total_audio_s = 0.0f64; + let mut total_elapsed_s = 0.0f64; + let mut exact = 0usize; + for wav_path in &wavs { let (wav, audio_s) = load_wav(wav_path)?; - - // Reference — high-quality beam as ground truth (silent) let (reference, _) = beam::transcribe_beam(&model, &wav, >_cfg)?; - eprintln!("{} ({audio_s:.1}s) {}", wav_path.display(), clean(&reference)); - // Greedy let t = Instant::now(); - let (transcript, stats) = greedy::transcribe_greedy(&model, &wav)?; + let (transcript, stats) = run_decoder(&model, &wav, &args)?; let elapsed = t.elapsed().as_secs_f64(); - let elapsed_ms = elapsed * 1000.0; - print_result("greedy", &transcript, &reference, elapsed_ms, audio_s / elapsed, stats.nn_ms(), &stats); - // Beam - let t = Instant::now(); - let (transcript, stats) = beam::transcribe_beam(&model, &wav, &args.beam)?; - let elapsed = t.elapsed().as_secs_f64(); - let elapsed_ms = elapsed * 1000.0; - print_result("beam", &transcript, &reference, elapsed_ms, audio_s / elapsed, stats.nn_ms(), &stats); + total_audio_s += audio_s; + total_elapsed_s += elapsed; - // ALSD - let t = Instant::now(); - let (transcript, stats) = alsd::transcribe_alsd(&model, &wav, &args.alsd)?; - let elapsed = t.elapsed().as_secs_f64(); - let elapsed_ms = elapsed * 1000.0; - print_result("alsd", &transcript, &reference, elapsed_ms, audio_s / elapsed, stats.nn_ms(), &stats); + let ok = transcript == reference; + if ok { exact += 1; } + + if args.no_details { + inc_progress_bar(); + } else { + let mark = if ok { "\x1b[32m✓\x1b[0m" } else { "\x1b[31m✗\x1b[0m" }; + let elapsed_ms = elapsed * 1000.0; + let nn_ms = stats.nn_ms(); + eprintln!("{} {mark} {audio_s:.1}s {elapsed_ms:.1}ms RTFx={:.1} {}", + wav_path.display(), audio_s / elapsed, clean(&transcript)); + if !ok { + eprintln!(" ref: {}", clean(&reference)); + eprintln!(" got: {}", clean(&transcript)); + } + if args.stats { + eprintln!(" {elapsed_ms:.1}ms RTFx={:.1} nn={nn_ms:.1}ms host={:.1}ms", + audio_s / elapsed, elapsed_ms - nn_ms); + eprintln!(" [preprocessor] {:?}", stats.preprocessor); + eprintln!(" [encoder] {:?}", stats.encoder); + eprintln!(" [decoder] {:?}", stats.decoder); + eprintln!(" [joint] {:?}", stats.joint); + } + } + } + + if args.no_details { + finalize_progress_bar(); } + eprintln!("{} {}/{} exact RTFx={:.1}", label, exact, wavs.len(), total_audio_s / total_elapsed_s); Ok(()) } From ccf3404d8108ed7a277a1f1a9d86817b14a7438f Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 4 Mar 2026 10:45:06 +0000 Subject: [PATCH 12/20] nemo-parakeet-asr: add --write-gt and read pre-generated transcripts Add --write-gt flag to run the ground-truth beam decoder (beam_size=10, beam_dur_k=5) and write cleaned transcripts to .txt files beside each wav. The normal evaluation loop now reads those .txt files as reference instead of re-running the GT decoder on every pass, making parameter search faster. --- examples/nemo-parakeet-asr/src/main.rs | 42 +++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/examples/nemo-parakeet-asr/src/main.rs b/examples/nemo-parakeet-asr/src/main.rs index 72b9eed40d..71f0e8ad42 100644 --- a/examples/nemo-parakeet-asr/src/main.rs +++ b/examples/nemo-parakeet-asr/src/main.rs @@ -30,6 +30,9 @@ struct Args { stats: bool, #[arg(long)] no_details: bool, + /// Run ground-truth decoder and write transcript to a .txt file beside each wav + #[arg(long)] + write_gt: bool, #[arg(required = true)] inputs: Vec, } @@ -189,6 +192,35 @@ fn run_decoder(model: &TdtModel, wav: &[f32], args: &Args) -> Result<(String, De } } +fn write_gt(model: &TdtModel, wavs: &[PathBuf], no_details: bool) -> anyhow::Result<()> { + let gt_cfg = beam::BeamConfig { beam_size: 10, beam_dur_k: 5 }; + // Warmup + if let Some(first) = wavs.first() { + let (wav, _) = load_wav(first)?; + beam::transcribe_beam(model, &wav, >_cfg)?; + } + if no_details { + init_progress_bar_with_eta(wavs.len()); + set_progress_bar_action("Writing GT", Color::Blue, Style::Bold); + } + for wav_path in wavs { + let (wav, _) = load_wav(wav_path)?; + let (transcript, _) = beam::transcribe_beam(model, &wav, >_cfg)?; + let txt_path = wav_path.with_extension("txt"); + std::fs::write(&txt_path, clean(&transcript))?; + if no_details { + inc_progress_bar(); + } else { + eprintln!("{} -> {}", wav_path.display(), txt_path.display()); + } + } + if no_details { + finalize_progress_bar(); + } + eprintln!("wrote {} transcript(s)", wavs.len()); + Ok(()) +} + fn main() -> anyhow::Result<()> { let args = Args::parse(); let nnef = tract_rs::nnef()?.with_tract_core()?.with_tract_transformers()?; @@ -198,10 +230,12 @@ fn main() -> anyhow::Result<()> { .unwrap(); let model = TdtModel::load(Path::new("assets/model"), &nnef, &gpu)?; - - let gt_cfg = beam::BeamConfig { beam_size: 10, beam_dur_k: 5 }; let wavs = collect_wavs(&args.inputs); + if args.write_gt { + return write_gt(&model, &wavs, args.no_details); + } + let label = decoder_label(&args); if !args.no_details { eprintln!("{} files={}", label, wavs.len()); @@ -213,7 +247,6 @@ fn main() -> anyhow::Result<()> { // Warmup on first file if let Some(first) = wavs.first() { let (wav, _) = load_wav(first)?; - beam::transcribe_beam(&model, &wav, >_cfg)?; run_decoder(&model, &wav, &args)?; } @@ -223,7 +256,8 @@ fn main() -> anyhow::Result<()> { for wav_path in &wavs { let (wav, audio_s) = load_wav(wav_path)?; - let (reference, _) = beam::transcribe_beam(&model, &wav, >_cfg)?; + let reference = std::fs::read_to_string(wav_path.with_extension("txt")) + .with_context(|| format!("no ground-truth transcript for {} (run with --write-gt first)", wav_path.display()))?; let t = Instant::now(); let (transcript, stats) = run_decoder(&model, &wav, &args)?; From 9d62445ffcca4721cf9a489b66b00a80173edbf9 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 4 Mar 2026 11:26:51 +0000 Subject: [PATCH 13/20] nemo-parakeet-asr: add --param-search, migrate to indicatif Sweep 19 hardcoded decoder configs (greedy, beam, alsd variants) over all WAVs, printing EPR and RTFx as TSV to stdout for easy pasting into Sheets. Two stacked indicatif progress bars on stderr track configs and files; mp.suspend() prevents bar redraws from overwriting TSV rows. Also replaces progress_bar crate with indicatif for write_gt and the normal decode loop. --- examples/nemo-parakeet-asr/Cargo.toml | 2 +- examples/nemo-parakeet-asr/src/main.rs | 178 ++++++++++++++++++++++--- 2 files changed, 160 insertions(+), 20 deletions(-) diff --git a/examples/nemo-parakeet-asr/Cargo.toml b/examples/nemo-parakeet-asr/Cargo.toml index a8df64c339..26cc34caa4 100644 --- a/examples/nemo-parakeet-asr/Cargo.toml +++ b/examples/nemo-parakeet-asr/Cargo.toml @@ -8,7 +8,7 @@ anyhow.workspace = true clap = { version = "4", features = ["derive"] } float-ord.workspace = true hound = "3.5.1" -progress_bar = "1.4.0" +indicatif = "0.17" itertools.workspace = true serde_json.workspace = true tract-rs.workspace = true diff --git a/examples/nemo-parakeet-asr/src/main.rs b/examples/nemo-parakeet-asr/src/main.rs index 71f0e8ad42..d205b99091 100644 --- a/examples/nemo-parakeet-asr/src/main.rs +++ b/examples/nemo-parakeet-asr/src/main.rs @@ -4,8 +4,8 @@ use std::time::Instant; use anyhow::*; use clap::Parser; -use progress_bar::*; use float_ord::FloatOrd; +use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; use itertools::Itertools; use tract_rs::prelude::*; use tract_rs::Nnef; @@ -33,6 +33,9 @@ struct Args { /// Run ground-truth decoder and write transcript to a .txt file beside each wav #[arg(long)] write_gt: bool, + /// Sweep hardcoded decoder configs and print TSV results to stdout + #[arg(long)] + param_search: bool, #[arg(required = true)] inputs: Vec, } @@ -192,6 +195,131 @@ fn run_decoder(model: &TdtModel, wav: &[f32], args: &Args) -> Result<(String, De } } +// ─── SearchConfig ──────────────────────────────────────────────────────────── + +enum SearchConfig { + Greedy, + Beam(beam::BeamConfig), + Alsd(alsd::AlsdConfig), +} + +impl SearchConfig { + fn label(&self) -> String { + match self { + SearchConfig::Greedy => "greedy".to_owned(), + SearchConfig::Beam(c) => format!("beam_{}_{}", c.beam_size, c.beam_dur_k), + SearchConfig::Alsd(c) => format!("alsd_{}_{}_{}", c.alsd_beam_size, c.alsd_beam_dur_k, c.alsd_max_symbols_per_frame), + } + } + + fn run(&self, model: &TdtModel, wav: &[f32]) -> Result<(String, DecodingStats)> { + match self { + SearchConfig::Greedy => greedy::transcribe_greedy(model, wav), + SearchConfig::Beam(c) => beam::transcribe_beam(model, wav, c), + SearchConfig::Alsd(c) => alsd::transcribe_alsd(model, wav, c), + } + } +} + +fn search_configs() -> Vec { + fn b(beam_size: usize, beam_dur_k: usize) -> SearchConfig { + SearchConfig::Beam(beam::BeamConfig { beam_size, beam_dur_k }) + } + fn a(alsd_beam_size: usize, alsd_beam_dur_k: usize, alsd_max_symbols_per_frame: usize) -> SearchConfig { + SearchConfig::Alsd(alsd::AlsdConfig { alsd_beam_size, alsd_beam_dur_k, alsd_max_symbols_per_frame }) + } + vec![ + SearchConfig::Greedy, + b(1, 1), + b(2, 1), + b(2, 2), + b(4, 1), + b(4, 2), + b(4, 4), + b(8, 2), + b(8, 4), + a(1, 1, 10), + a(2, 1, 10), + a(2, 2, 10), + a(4, 1, 10), + a(4, 2, 10), + a(4, 4, 10), + a(8, 2, 10), + a(8, 4, 10), + a(4, 2, 3), + a(4, 2, 30), + ] +} + +fn param_search(model: &TdtModel, wavs: &[PathBuf]) -> Result<()> { + // Warmup + if let Some(first) = wavs.first() { + let (wav, _) = load_wav(first)?; + greedy::transcribe_greedy(model, &wav)?; + } + + let configs = search_configs(); + + let mp = MultiProgress::new(); + let cfg_style = ProgressStyle::with_template( + "Configs {bar:40} {pos:>3}/{len} {msg}" + ).unwrap(); + let file_style = ProgressStyle::with_template( + "Files {bar:40} {pos:>3}/{len}" + ).unwrap(); + + let cfg_bar = mp.add(ProgressBar::new(configs.len() as u64)); + cfg_bar.set_style(cfg_style); + let file_bar = mp.add(ProgressBar::new(wavs.len() as u64)); + file_bar.set_style(file_style); + + mp.suspend(|| println!("label\tEPR\tRTFx")); + + for cfg in &configs { + let label = cfg.label(); + cfg_bar.set_message(label.clone()); + + file_bar.reset(); + file_bar.set_length(wavs.len() as u64); + + let mut total_audio_s = 0.0f64; + let mut total_elapsed_s = 0.0f64; + let mut exact = 0usize; + let mut total = 0usize; + + for wav_path in wavs { + let (wav, audio_s) = load_wav(wav_path)?; + let reference = std::fs::read_to_string(wav_path.with_extension("txt")) + .with_context(|| format!("no ground-truth transcript for {} (run with --write-gt first)", wav_path.display()))?; + let reference = reference.trim_end_matches('\n').to_owned(); + + let t = Instant::now(); + let (transcript, _) = cfg.run(model, &wav)?; + let elapsed = t.elapsed().as_secs_f64(); + + total_audio_s += audio_s; + total_elapsed_s += elapsed; + total += 1; + if clean(&transcript) == reference { exact += 1; } + + file_bar.inc(1); + } + + let epr = if total > 0 { exact as f64 / total as f64 } else { 0.0 }; + let rtfx = if total_elapsed_s > 0.0 { total_audio_s / total_elapsed_s } else { 0.0 }; + mp.suspend(|| println!("{}\t{:.4}\t{:.1}", label, epr, rtfx)); + + cfg_bar.inc(1); + } + + cfg_bar.finish(); + file_bar.finish(); + + Ok(()) +} + +// ─── write_gt ──────────────────────────────────────────────────────────────── + fn write_gt(model: &TdtModel, wavs: &[PathBuf], no_details: bool) -> anyhow::Result<()> { let gt_cfg = beam::BeamConfig { beam_size: 10, beam_dur_k: 5 }; // Warmup @@ -199,24 +327,27 @@ fn write_gt(model: &TdtModel, wavs: &[PathBuf], no_details: bool) -> anyhow::Res let (wav, _) = load_wav(first)?; beam::transcribe_beam(model, &wav, >_cfg)?; } - if no_details { - init_progress_bar_with_eta(wavs.len()); - set_progress_bar_action("Writing GT", Color::Blue, Style::Bold); - } + let pb = if no_details { + let pb = ProgressBar::new(wavs.len() as u64); + pb.set_style( + ProgressStyle::with_template("Writing GT {bar:40} {pos:>3}/{len}").unwrap() + ); + Some(pb) + } else { + None + }; for wav_path in wavs { let (wav, _) = load_wav(wav_path)?; let (transcript, _) = beam::transcribe_beam(model, &wav, >_cfg)?; let txt_path = wav_path.with_extension("txt"); std::fs::write(&txt_path, clean(&transcript))?; - if no_details { - inc_progress_bar(); + if let Some(ref pb) = pb { + pb.inc(1); } else { eprintln!("{} -> {}", wav_path.display(), txt_path.display()); } } - if no_details { - finalize_progress_bar(); - } + if let Some(pb) = pb { pb.finish(); } eprintln!("wrote {} transcript(s)", wavs.len()); Ok(()) } @@ -236,12 +367,13 @@ fn main() -> anyhow::Result<()> { return write_gt(&model, &wavs, args.no_details); } + if args.param_search { + return param_search(&model, &wavs); + } + let label = decoder_label(&args); if !args.no_details { eprintln!("{} files={}", label, wavs.len()); - } else { - init_progress_bar_with_eta(wavs.len()); - set_progress_bar_action("Decoding", Color::Blue, Style::Bold); } // Warmup on first file @@ -250,6 +382,16 @@ fn main() -> anyhow::Result<()> { run_decoder(&model, &wav, &args)?; } + let pb = if args.no_details { + let pb = ProgressBar::new(wavs.len() as u64); + pb.set_style( + ProgressStyle::with_template("Decoding {bar:40} {pos:>3}/{len}").unwrap() + ); + Some(pb) + } else { + None + }; + let mut total_audio_s = 0.0f64; let mut total_elapsed_s = 0.0f64; let mut exact = 0usize; @@ -266,11 +408,11 @@ fn main() -> anyhow::Result<()> { total_audio_s += audio_s; total_elapsed_s += elapsed; - let ok = transcript == reference; + let ok = clean(&transcript) == reference.trim_end_matches('\n'); if ok { exact += 1; } - if args.no_details { - inc_progress_bar(); + if let Some(ref pb) = pb { + pb.inc(1); } else { let mark = if ok { "\x1b[32m✓\x1b[0m" } else { "\x1b[31m✗\x1b[0m" }; let elapsed_ms = elapsed * 1000.0; @@ -292,9 +434,7 @@ fn main() -> anyhow::Result<()> { } } - if args.no_details { - finalize_progress_bar(); - } + if let Some(pb) = pb { pb.finish(); } eprintln!("{} {}/{} exact RTFx={:.1}", label, exact, wavs.len(), total_audio_s / total_elapsed_s); Ok(()) From 733f7b6aa974922370e894f9f75f15a0fce5ac73 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Thu, 5 Mar 2026 09:19:22 +0000 Subject: [PATCH 14/20] nemo-parakeet-asr: add per-model time profile columns to --param-search Output now includes pre%, enc%, dec%, joint%, host% columns showing each component's share of total wall time, making it easy to spot where time is spent across decoder configs. --- examples/nemo-parakeet-asr/src/main.rs | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/examples/nemo-parakeet-asr/src/main.rs b/examples/nemo-parakeet-asr/src/main.rs index d205b99091..c12374b801 100644 --- a/examples/nemo-parakeet-asr/src/main.rs +++ b/examples/nemo-parakeet-asr/src/main.rs @@ -44,7 +44,7 @@ struct Args { pub(crate) struct CallStats { calls: u32, total_batch: u64, - total_us: u64, + pub(crate) total_us: u64, } impl CallStats { @@ -273,7 +273,7 @@ fn param_search(model: &TdtModel, wavs: &[PathBuf]) -> Result<()> { let file_bar = mp.add(ProgressBar::new(wavs.len() as u64)); file_bar.set_style(file_style); - mp.suspend(|| println!("label\tEPR\tRTFx")); + mp.suspend(|| println!("label\tEPR\tRTFx\tpre%\tenc%\tdec%\tjoint%\thost%")); for cfg in &configs { let label = cfg.label(); @@ -286,6 +286,10 @@ fn param_search(model: &TdtModel, wavs: &[PathBuf]) -> Result<()> { let mut total_elapsed_s = 0.0f64; let mut exact = 0usize; let mut total = 0usize; + let mut pre_us = 0u64; + let mut enc_us = 0u64; + let mut dec_us = 0u64; + let mut joint_us = 0u64; for wav_path in wavs { let (wav, audio_s) = load_wav(wav_path)?; @@ -294,20 +298,30 @@ fn param_search(model: &TdtModel, wavs: &[PathBuf]) -> Result<()> { let reference = reference.trim_end_matches('\n').to_owned(); let t = Instant::now(); - let (transcript, _) = cfg.run(model, &wav)?; + let (transcript, stats) = cfg.run(model, &wav)?; let elapsed = t.elapsed().as_secs_f64(); total_audio_s += audio_s; total_elapsed_s += elapsed; total += 1; if clean(&transcript) == reference { exact += 1; } + pre_us += stats.preprocessor.total_us; + enc_us += stats.encoder.total_us; + dec_us += stats.decoder.total_us; + joint_us += stats.joint.total_us; file_bar.inc(1); } let epr = if total > 0 { exact as f64 / total as f64 } else { 0.0 }; let rtfx = if total_elapsed_s > 0.0 { total_audio_s / total_elapsed_s } else { 0.0 }; - mp.suspend(|| println!("{}\t{:.4}\t{:.1}", label, epr, rtfx)); + let total_us = (total_elapsed_s * 1_000_000.0) as u64; + let pct = |us: u64| if total_us > 0 { us as f64 / total_us as f64 * 100.0 } else { 0.0 }; + let nn_us = pre_us + enc_us + dec_us + joint_us; + let host_us = total_us.saturating_sub(nn_us); + mp.suspend(|| println!("{}\t{:.4}\t{:.1}\t{:.1}\t{:.1}\t{:.1}\t{:.1}\t{:.1}", + label, epr, rtfx, + pct(pre_us), pct(enc_us), pct(dec_us), pct(joint_us), pct(host_us))); cfg_bar.inc(1); } From 1e4ca168caf69be700b28ac09ff948fa90b45434 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Thu, 5 Mar 2026 09:20:47 +0000 Subject: [PATCH 15/20] nemo-parakeet-asr: rename ALSD decoder to FBSD The decoder is not a faithful implementation of alignment-length synchronous decoding (ALSD): it does not enforce the alignment-length synchronization invariant, and uses a per-frame symbol cap from TSD. Rename to FBSD (Frame-asynchronous Beam Search Decoding) to reflect what it actually does. --- .../src/{alsd.rs => fbsd.rs} | 40 +++++++++---------- examples/nemo-parakeet-asr/src/main.rs | 20 +++++----- 2 files changed, 30 insertions(+), 30 deletions(-) rename examples/nemo-parakeet-asr/src/{alsd.rs => fbsd.rs} (91%) diff --git a/examples/nemo-parakeet-asr/src/alsd.rs b/examples/nemo-parakeet-asr/src/fbsd.rs similarity index 91% rename from examples/nemo-parakeet-asr/src/alsd.rs rename to examples/nemo-parakeet-asr/src/fbsd.rs index f9853f3060..ba836a398d 100644 --- a/examples/nemo-parakeet-asr/src/alsd.rs +++ b/examples/nemo-parakeet-asr/src/fbsd.rs @@ -8,21 +8,21 @@ use tract_rs::prelude::tract_ndarray::prelude::*; use tract_rs::prelude::*; #[derive(Args, Clone)] -pub struct AlsdConfig { - /// Beam width for ALSD decoding +pub struct FbsdConfig { + /// Beam width for FBSD decoding #[arg(long, default_value_t = 4)] - pub alsd_beam_size: usize, + pub fbsd_beam_size: usize, - /// Duration candidates per hypothesis in ALSD + /// Duration candidates per hypothesis in FBSD #[arg(long, default_value_t = 2)] - pub alsd_beam_dur_k: usize, + pub fbsd_beam_dur_k: usize, /// Max non-blank tokens emitted per frame per hypothesis #[arg(long, default_value_t = 10)] - pub alsd_max_symbols_per_frame: usize, + pub fbsd_max_symbols_per_frame: usize, } -struct AlsdHyp { +struct FbsdHyp { score: f32, tokens: Vec, current_frame: usize, @@ -32,10 +32,10 @@ struct AlsdHyp { state_1: Value, } -pub fn transcribe_alsd( +pub fn transcribe_fbsd( model: &crate::TdtModel, wav: &[f32], - cfg: &AlsdConfig, + cfg: &FbsdConfig, ) -> Result<(String, crate::DecodingStats)> { let mut stats = crate::DecodingStats::default(); @@ -65,7 +65,7 @@ pub fn transcribe_alsd( model.decoder.run([init_token, init_s0, init_s1])?.try_into().unwrap(); stats.decoder.record(batch, t.elapsed()); - let mut beam: Vec = vec![AlsdHyp { + let mut beam: Vec = vec![FbsdHyp { score: 0.0, tokens: vec![], current_frame: 0, @@ -77,8 +77,8 @@ pub fn transcribe_alsd( loop { // Split into active (still have frames to consume) and completed - let mut active: Vec = Vec::new(); - let mut completed: Vec = Vec::new(); + let mut active: Vec = Vec::new(); + let mut completed: Vec = Vec::new(); for h in beam.drain(..) { if h.current_frame < max_frames { active.push(h); @@ -117,7 +117,7 @@ pub fn transcribe_alsd( stats.joint.record(b, t.elapsed()); // 2. Per-hyp: blank+duration → next; non-blank → expansion list - let mut next: Vec = completed; + let mut next: Vec = completed; let mut per_hyp_token_scores: Vec> = Vec::with_capacity(b); { let logits_arr = logits_b.view::()?; // [b, vocab+dur] @@ -131,9 +131,9 @@ pub fn transcribe_alsd( let mut ds: Vec<(usize, f32)> = (1..dur_log_probs.len()).map(|di| (di, dur_log_probs[di])).collect(); ds.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); - ds.truncate(cfg.alsd_beam_dur_k); + ds.truncate(cfg.fbsd_beam_dur_k); for (di, dlp) in ds { - next.push(AlsdHyp { + next.push(FbsdHyp { score: active[bi].score + log_probs[model.blank_id] + dlp, tokens: active[bi].tokens.clone(), current_frame: active[bi].current_frame + di, @@ -145,11 +145,11 @@ pub fn transcribe_alsd( } // Non-blank expansions stay on the same frame (symbol count checked) - if active[bi].symbols_this_frame < cfg.alsd_max_symbols_per_frame { + if active[bi].symbols_this_frame < cfg.fbsd_max_symbols_per_frame { let mut ts: Vec<(usize, f32)> = (0..model.blank_id).map(|ti| (ti, log_probs[ti])).collect(); ts.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); - ts.truncate(cfg.alsd_beam_size); + ts.truncate(cfg.fbsd_beam_size); per_hyp_token_scores.push(ts); } else { per_hyp_token_scores.push(vec![]); @@ -213,7 +213,7 @@ pub fn transcribe_alsd( s1_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; let mut new_tokens = active[bi].tokens.clone(); new_tokens.push(ti); - next.push(AlsdHyp { + next.push(FbsdHyp { score: active[bi].score + lp, tokens: new_tokens, current_frame: active[bi].current_frame, @@ -227,13 +227,13 @@ pub fn transcribe_alsd( } } - // 5. Fuse duplicates then prune to alsd_beam_size. + // 5. Fuse duplicates then prune to fbsd_beam_size. // Key is (tokens, current_frame, symbols_this_frame): same tokens → same decoder // state; same frame → same encoder embedding; same symbol count → same gate behavior. next.sort_by(|a, b| FloatOrd(b.score).cmp(&FloatOrd(a.score))); let mut seen = std::collections::HashSet::<(Vec, usize, usize)>::new(); next.retain(|h| seen.insert((h.tokens.clone(), h.current_frame, h.symbols_this_frame))); - next.truncate(cfg.alsd_beam_size); + next.truncate(cfg.fbsd_beam_size); beam = next; } diff --git a/examples/nemo-parakeet-asr/src/main.rs b/examples/nemo-parakeet-asr/src/main.rs index c12374b801..b3b7389df0 100644 --- a/examples/nemo-parakeet-asr/src/main.rs +++ b/examples/nemo-parakeet-asr/src/main.rs @@ -12,10 +12,10 @@ use tract_rs::Nnef; mod greedy; mod beam; -mod alsd; +mod fbsd; #[derive(clap::ValueEnum, Clone)] -enum Decoder { Greedy, Beam, Alsd } +enum Decoder { Greedy, Beam, Fbsd } #[derive(Parser)] #[command(about = "NeMo Parakeet ASR inference")] @@ -23,7 +23,7 @@ struct Args { #[command(flatten)] beam: beam::BeamConfig, #[command(flatten)] - alsd: alsd::AlsdConfig, + fbsd: fbsd::FbsdConfig, #[arg(long, value_enum, default_value_t = Decoder::Greedy)] decoder: Decoder, #[arg(long)] @@ -183,7 +183,7 @@ fn decoder_label(args: &Args) -> String { match args.decoder { Decoder::Greedy => "greedy".to_string(), Decoder::Beam => format!("beam beam_size={} beam_dur_k={}", args.beam.beam_size, args.beam.beam_dur_k), - Decoder::Alsd => format!("alsd beam_size={} beam_dur_k={} max_symbols={}", args.alsd.alsd_beam_size, args.alsd.alsd_beam_dur_k, args.alsd.alsd_max_symbols_per_frame), + Decoder::Fbsd => format!("fbsd beam_size={} beam_dur_k={} max_symbols={}", args.fbsd.fbsd_beam_size, args.fbsd.fbsd_beam_dur_k, args.fbsd.fbsd_max_symbols_per_frame), } } @@ -191,7 +191,7 @@ fn run_decoder(model: &TdtModel, wav: &[f32], args: &Args) -> Result<(String, De match args.decoder { Decoder::Greedy => greedy::transcribe_greedy(model, wav), Decoder::Beam => beam::transcribe_beam(model, wav, &args.beam), - Decoder::Alsd => alsd::transcribe_alsd(model, wav, &args.alsd), + Decoder::Fbsd => fbsd::transcribe_fbsd(model, wav, &args.fbsd), } } @@ -200,7 +200,7 @@ fn run_decoder(model: &TdtModel, wav: &[f32], args: &Args) -> Result<(String, De enum SearchConfig { Greedy, Beam(beam::BeamConfig), - Alsd(alsd::AlsdConfig), + Fbsd(fbsd::FbsdConfig), } impl SearchConfig { @@ -208,7 +208,7 @@ impl SearchConfig { match self { SearchConfig::Greedy => "greedy".to_owned(), SearchConfig::Beam(c) => format!("beam_{}_{}", c.beam_size, c.beam_dur_k), - SearchConfig::Alsd(c) => format!("alsd_{}_{}_{}", c.alsd_beam_size, c.alsd_beam_dur_k, c.alsd_max_symbols_per_frame), + SearchConfig::Fbsd(c) => format!("fbsd_{}_{}_{}", c.fbsd_beam_size, c.fbsd_beam_dur_k, c.fbsd_max_symbols_per_frame), } } @@ -216,7 +216,7 @@ impl SearchConfig { match self { SearchConfig::Greedy => greedy::transcribe_greedy(model, wav), SearchConfig::Beam(c) => beam::transcribe_beam(model, wav, c), - SearchConfig::Alsd(c) => alsd::transcribe_alsd(model, wav, c), + SearchConfig::Fbsd(c) => fbsd::transcribe_fbsd(model, wav, c), } } } @@ -225,8 +225,8 @@ fn search_configs() -> Vec { fn b(beam_size: usize, beam_dur_k: usize) -> SearchConfig { SearchConfig::Beam(beam::BeamConfig { beam_size, beam_dur_k }) } - fn a(alsd_beam_size: usize, alsd_beam_dur_k: usize, alsd_max_symbols_per_frame: usize) -> SearchConfig { - SearchConfig::Alsd(alsd::AlsdConfig { alsd_beam_size, alsd_beam_dur_k, alsd_max_symbols_per_frame }) + fn a(beam_size: usize, beam_dur_k: usize, max_symbols: usize) -> SearchConfig { + SearchConfig::Fbsd(fbsd::FbsdConfig { fbsd_beam_size: beam_size, fbsd_beam_dur_k: beam_dur_k, fbsd_max_symbols_per_frame: max_symbols }) } vec![ SearchConfig::Greedy, From 0af03b07f508a6697029581cc3bbed115e13bfdf Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Thu, 5 Mar 2026 09:51:40 +0000 Subject: [PATCH 16/20] nemo-parakeet-asr: implement real ALSD decoder for TDT The published ALSD algorithm (Saon et al., ICASSP 2020) iterates over alignment steps rather than frames. All hypotheses in the beam advance by exactly one alignment event per step (either one token or one blank). Completed hypotheses (current_frame >= T) are drained to a final list at each step; the loop terminates after T + U_max steps or when the beam empties, whichever comes first. Key differences from FBSD: - Outer loop bounded by T + u_max alignment steps (not open-ended) - No per-frame symbol cap; the step bound naturally limits token chains - Finer-grained pruning: after every single token or blank emission - Dedup key is (tokens, frame); no symbols_this_frame dimension Also adds alsd_* configs to --param-search (beam_size x beam_dur_k grid, u_max=50) and exposes --decoder alsd for interactive use. --- examples/nemo-parakeet-asr/src/alsd.rs | 244 +++++++++++++++++++++++++ examples/nemo-parakeet-asr/src/main.rs | 43 +++-- 2 files changed, 275 insertions(+), 12 deletions(-) create mode 100644 examples/nemo-parakeet-asr/src/alsd.rs diff --git a/examples/nemo-parakeet-asr/src/alsd.rs b/examples/nemo-parakeet-asr/src/alsd.rs new file mode 100644 index 0000000000..9d162cc979 --- /dev/null +++ b/examples/nemo-parakeet-asr/src/alsd.rs @@ -0,0 +1,244 @@ +use std::time::Instant; + +use anyhow::*; +use clap::Args; +use float_ord::FloatOrd; +use itertools::Itertools; +use tract_rs::prelude::tract_ndarray::prelude::*; +use tract_rs::prelude::*; + +#[derive(Args, Clone)] +pub struct AlsdConfig { + /// Beam width for ALSD decoding + #[arg(long, default_value_t = 4)] + pub alsd_beam_size: usize, + + /// Duration candidates per hypothesis in ALSD + #[arg(long, default_value_t = 2)] + pub alsd_beam_dur_k: usize, + + /// Maximum expected output token count (U_max in the paper) + #[arg(long, default_value_t = 50)] + pub alsd_u_max: usize, +} + +struct AlsdHyp { + score: f32, + tokens: Vec, + current_frame: usize, + dec_out: Value, + state_0: Value, + state_1: Value, +} + +pub fn transcribe_alsd( + model: &crate::TdtModel, + wav: &[f32], + cfg: &AlsdConfig, +) -> Result<(String, crate::DecodingStats)> { + let mut stats = crate::DecodingStats::default(); + + let samples: Value = Value::from_slice(&[1, wav.len()], wav)?; + let len: Value = arr1(&[wav.len() as i64]).try_into()?; + + let t = Instant::now(); + let [features, feat_len] = + model.preprocessor.run([samples, len])?.try_into().unwrap(); + stats.preprocessor.record(1, t.elapsed()); + + let t = Instant::now(); + let [encoded, _lens] = + model.encoder.run([features, feat_len])?.try_into().unwrap(); + stats.encoder.record(1, t.elapsed()); + + let encoded: ArrayD = encoded.view()?.into_owned(); + let batch = encoded.shape()[0]; + let max_frames = encoded.shape()[2]; + let enc_dim = encoded.shape()[1]; + + let init_token = Value::from_slice(&[1, 1], &[0i32])?; + let init_s0: Value = Array3::::zeros([2, 1, 640]).try_into()?; + let init_s1: Value = Array3::::zeros([2, 1, 640]).try_into()?; + let t = Instant::now(); + let [dec_out, state_0, state_1] = + model.decoder.run([init_token, init_s0, init_s1])?.try_into().unwrap(); + stats.decoder.record(batch, t.elapsed()); + + let mut beam: Vec = vec![AlsdHyp { + score: 0.0, + tokens: vec![], + current_frame: 0, + dec_out, + state_0, + state_1, + }]; + + let mut final_hyps: Vec = Vec::new(); + + // Outer loop: alignment steps 0 .. T + U_max. + // All hypotheses in `beam` are at the same alignment step. + // At each step every active hyp (current_frame < T) gets one expansion + // (blank+duration OR one non-blank token), so a single step advances + // the alignment length by exactly 1. + for _step in 0..(max_frames + cfg.alsd_u_max) { + // Separate completed hypotheses (exhausted all frames) from active ones. + let mut active: Vec = Vec::new(); + for h in beam.drain(..) { + if h.current_frame >= max_frames { + final_hyps.push(h); + } else { + active.push(h); + } + } + if active.is_empty() { + break; + } + + let b = active.len(); + + // 1. JOINT: batched over all active hypotheses (each at its own frame). + let frame_batch: Value = { + let enc_arr = encoded.view(); + Array3::::from_shape_fn((b, enc_dim, 1), |(bi, e, _)| { + enc_arr[[0, e, active[bi].current_frame]] + }) + .try_into()? + }; + let dec_out_batch: Value = { + let views: Vec<_> = active + .iter() + .map(|h| h.dec_out.view::()) + .collect::>>()?; + let hidden = views[0].shape()[1]; + Array3::::from_shape_fn((b, hidden, 1), |(bi, h, _)| views[bi][[0, h, 0]]) + .try_into()? + }; + let t = Instant::now(); + let [logits_b] = + model.joint.run([frame_batch, dec_out_batch])?.try_into().unwrap(); + stats.joint.record(b, t.elapsed()); + + // 2. Per-hyp: blank+duration children (no decoder update) and + // non-blank token candidates (need decoder update). + let mut next: Vec = Vec::new(); + let mut per_hyp_token_scores: Vec> = Vec::with_capacity(b); + { + let logits_arr = logits_b.view::()?; + for bi in 0..b { + let row = logits_arr.index_axis(Axis(0), bi); + let row_slice = row.as_slice().unwrap(); + let log_probs = crate::log_softmax(&row_slice[0..=model.blank_id]); + let dur_log_probs = crate::log_softmax(&row_slice[model.blank_id + 1..]); + + // Blank + top-k duration expansions (decoder state unchanged). + let mut ds: Vec<(usize, f32)> = + (1..dur_log_probs.len()).map(|di| (di, dur_log_probs[di])).collect(); + ds.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); + ds.truncate(cfg.alsd_beam_dur_k); + for (di, dlp) in ds { + next.push(AlsdHyp { + score: active[bi].score + log_probs[model.blank_id] + dlp, + tokens: active[bi].tokens.clone(), + current_frame: active[bi].current_frame + di, + dec_out: active[bi].dec_out.clone(), + state_0: active[bi].state_0.clone(), + state_1: active[bi].state_1.clone(), + }); + } + + // Top-k non-blank token candidates (decoder step needed). + let mut ts: Vec<(usize, f32)> = + (0..model.blank_id).map(|ti| (ti, log_probs[ti])).collect(); + ts.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); + ts.truncate(cfg.alsd_beam_size); + per_hyp_token_scores.push(ts); + } + } + + // 3. DECODER: single batched call over all non-blank expansions. + let expansion_hyp_idxs: Vec = per_hyp_token_scores + .iter() + .enumerate() + .flat_map(|(bi, ts)| std::iter::repeat(bi).take(ts.len())) + .collect(); + let token_ids: Vec = per_hyp_token_scores + .iter() + .flat_map(|ts| ts.iter().map(|&(ti, _)| ti as i32)) + .collect(); + let n = token_ids.len(); + + if n > 0 { + let tokens_batch: Value = + Array2::::from_shape_fn((n, 1), |(i, _)| token_ids[i]).try_into()?; + let s0_batch: Value = { + let views: Vec<_> = active + .iter() + .map(|h| h.state_0.view::()) + .collect::>>()?; + Array3::::from_shape_fn((2, n, 640), |(l, i, h)| { + views[expansion_hyp_idxs[i]][[l, 0, h]] + }) + .try_into()? + }; + let s1_batch: Value = { + let views: Vec<_> = active + .iter() + .map(|h| h.state_1.view::()) + .collect::>>()?; + Array3::::from_shape_fn((2, n, 640), |(l, i, h)| { + views[expansion_hyp_idxs[i]][[l, 0, h]] + }) + .try_into()? + }; + + let t = Instant::now(); + let [dec_out_b, s0_b, s1_b] = + model.decoder.run([tokens_batch, s0_batch, s1_batch])?.try_into().unwrap(); + stats.decoder.record(n, t.elapsed()); + + let dec_arr = dec_out_b.view::()?; + let s0_arr = s0_b.view::()?; + let s1_arr = s1_b.view::()?; + let mut i = 0; + for (bi, ts) in per_hyp_token_scores.iter().enumerate() { + for &(ti, lp) in ts { + let new_dec_out: Value = + dec_arr.slice_axis(Axis(0), (i..i + 1).into()).try_into()?; + let new_s0: Value = + s0_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; + let new_s1: Value = + s1_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; + let mut new_tokens = active[bi].tokens.clone(); + new_tokens.push(ti); + next.push(AlsdHyp { + score: active[bi].score + lp, + tokens: new_tokens, + current_frame: active[bi].current_frame, // frame unchanged for tokens + dec_out: new_dec_out, + state_0: new_s0, + state_1: new_s1, + }); + i += 1; + } + } + } + + // 4. Prune: sort by score, deduplicate on (tokens, frame), keep top beam_size. + // Two hyps with identical tokens and frame have identical decoder state; + // keep only the highest-scoring one. + next.sort_by(|a, b| FloatOrd(b.score).cmp(&FloatOrd(a.score))); + let mut seen = std::collections::HashSet::<(Vec, usize)>::new(); + next.retain(|h| seen.insert((h.tokens.clone(), h.current_frame))); + next.truncate(cfg.alsd_beam_size); + beam = next; + } + + // Any hyps remaining in beam that didn't finish (step limit hit) go to final. + final_hyps.extend(beam); + + let best = final_hyps + .into_iter() + .max_by_key(|h| FloatOrd(h.score)) + .ok_or_else(|| anyhow!("no hypotheses survived"))?; + Ok((best.tokens.into_iter().map(|t| model.vocab[t].as_str()).join(""), stats)) +} diff --git a/examples/nemo-parakeet-asr/src/main.rs b/examples/nemo-parakeet-asr/src/main.rs index b3b7389df0..2a224fcbd1 100644 --- a/examples/nemo-parakeet-asr/src/main.rs +++ b/examples/nemo-parakeet-asr/src/main.rs @@ -13,9 +13,10 @@ use tract_rs::Nnef; mod greedy; mod beam; mod fbsd; +mod alsd; #[derive(clap::ValueEnum, Clone)] -enum Decoder { Greedy, Beam, Fbsd } +enum Decoder { Greedy, Beam, Fbsd, Alsd } #[derive(Parser)] #[command(about = "NeMo Parakeet ASR inference")] @@ -24,6 +25,8 @@ struct Args { beam: beam::BeamConfig, #[command(flatten)] fbsd: fbsd::FbsdConfig, + #[command(flatten)] + alsd: alsd::AlsdConfig, #[arg(long, value_enum, default_value_t = Decoder::Greedy)] decoder: Decoder, #[arg(long)] @@ -184,6 +187,7 @@ fn decoder_label(args: &Args) -> String { Decoder::Greedy => "greedy".to_string(), Decoder::Beam => format!("beam beam_size={} beam_dur_k={}", args.beam.beam_size, args.beam.beam_dur_k), Decoder::Fbsd => format!("fbsd beam_size={} beam_dur_k={} max_symbols={}", args.fbsd.fbsd_beam_size, args.fbsd.fbsd_beam_dur_k, args.fbsd.fbsd_max_symbols_per_frame), + Decoder::Alsd => format!("alsd beam_size={} beam_dur_k={} u_max={}", args.alsd.alsd_beam_size, args.alsd.alsd_beam_dur_k, args.alsd.alsd_u_max), } } @@ -192,6 +196,7 @@ fn run_decoder(model: &TdtModel, wav: &[f32], args: &Args) -> Result<(String, De Decoder::Greedy => greedy::transcribe_greedy(model, wav), Decoder::Beam => beam::transcribe_beam(model, wav, &args.beam), Decoder::Fbsd => fbsd::transcribe_fbsd(model, wav, &args.fbsd), + Decoder::Alsd => alsd::transcribe_alsd(model, wav, &args.alsd), } } @@ -201,6 +206,7 @@ enum SearchConfig { Greedy, Beam(beam::BeamConfig), Fbsd(fbsd::FbsdConfig), + Alsd(alsd::AlsdConfig), } impl SearchConfig { @@ -209,6 +215,7 @@ impl SearchConfig { SearchConfig::Greedy => "greedy".to_owned(), SearchConfig::Beam(c) => format!("beam_{}_{}", c.beam_size, c.beam_dur_k), SearchConfig::Fbsd(c) => format!("fbsd_{}_{}_{}", c.fbsd_beam_size, c.fbsd_beam_dur_k, c.fbsd_max_symbols_per_frame), + SearchConfig::Alsd(c) => format!("alsd_{}_{}", c.alsd_beam_size, c.alsd_beam_dur_k), } } @@ -217,6 +224,7 @@ impl SearchConfig { SearchConfig::Greedy => greedy::transcribe_greedy(model, wav), SearchConfig::Beam(c) => beam::transcribe_beam(model, wav, c), SearchConfig::Fbsd(c) => fbsd::transcribe_fbsd(model, wav, c), + SearchConfig::Alsd(c) => alsd::transcribe_alsd(model, wav, c), } } } @@ -225,9 +233,12 @@ fn search_configs() -> Vec { fn b(beam_size: usize, beam_dur_k: usize) -> SearchConfig { SearchConfig::Beam(beam::BeamConfig { beam_size, beam_dur_k }) } - fn a(beam_size: usize, beam_dur_k: usize, max_symbols: usize) -> SearchConfig { + fn f(beam_size: usize, beam_dur_k: usize, max_symbols: usize) -> SearchConfig { SearchConfig::Fbsd(fbsd::FbsdConfig { fbsd_beam_size: beam_size, fbsd_beam_dur_k: beam_dur_k, fbsd_max_symbols_per_frame: max_symbols }) } + fn a(beam_size: usize, beam_dur_k: usize) -> SearchConfig { + SearchConfig::Alsd(alsd::AlsdConfig { alsd_beam_size: beam_size, alsd_beam_dur_k: beam_dur_k, alsd_u_max: 50 }) + } vec![ SearchConfig::Greedy, b(1, 1), @@ -238,16 +249,24 @@ fn search_configs() -> Vec { b(4, 4), b(8, 2), b(8, 4), - a(1, 1, 10), - a(2, 1, 10), - a(2, 2, 10), - a(4, 1, 10), - a(4, 2, 10), - a(4, 4, 10), - a(8, 2, 10), - a(8, 4, 10), - a(4, 2, 3), - a(4, 2, 30), + f(1, 1, 10), + f(2, 1, 10), + f(2, 2, 10), + f(4, 1, 10), + f(4, 2, 10), + f(4, 4, 10), + f(8, 2, 10), + f(8, 4, 10), + f(4, 2, 3), + f(4, 2, 30), + a(1, 1), + a(2, 1), + a(2, 2), + a(4, 1), + a(4, 2), + a(4, 4), + a(8, 2), + a(8, 4), ] } From 4d43f5221a57058a53cf2ea506ba27b4d354d575 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Thu, 5 Mar 2026 10:53:32 +0000 Subject: [PATCH 17/20] nemo-parakeet-asr: apply non-blank token duration in all four decoders TDT predicts a duration for every step regardless of whether the token is blank or non-blank. All decoders were ignoring the duration for non-blank emissions, effectively treating every token as d=0. Fix: - greedy: advance frame_ix by argmax(dur) after token emission - beam/fbsd/alsd: compute best_dur = argmax(dur_log_probs) per hyp; add dur_log_probs[best_dur] to each non-blank child's score and advance current_frame by best_dur - fbsd: reset symbols_this_frame to 0 when best_dur > 0 (frame changed) --- examples/nemo-parakeet-asr/src/alsd.rs | 10 ++++++++-- examples/nemo-parakeet-asr/src/beam.rs | 10 ++++++++-- examples/nemo-parakeet-asr/src/fbsd.rs | 14 ++++++++++---- examples/nemo-parakeet-asr/src/greedy.rs | 4 +++- 4 files changed, 29 insertions(+), 9 deletions(-) diff --git a/examples/nemo-parakeet-asr/src/alsd.rs b/examples/nemo-parakeet-asr/src/alsd.rs index 9d162cc979..4f40b3d3f0 100644 --- a/examples/nemo-parakeet-asr/src/alsd.rs +++ b/examples/nemo-parakeet-asr/src/alsd.rs @@ -122,6 +122,7 @@ pub fn transcribe_alsd( // non-blank token candidates (need decoder update). let mut next: Vec = Vec::new(); let mut per_hyp_token_scores: Vec> = Vec::with_capacity(b); + let mut per_hyp_best_dur: Vec<(usize, f32)> = Vec::with_capacity(b); { let logits_arr = logits_b.view::()?; for bi in 0..b { @@ -130,6 +131,10 @@ pub fn transcribe_alsd( let log_probs = crate::log_softmax(&row_slice[0..=model.blank_id]); let dur_log_probs = crate::log_softmax(&row_slice[model.blank_id + 1..]); + let best_dur = dur_log_probs.iter().enumerate() + .max_by_key(|&(_, &v)| FloatOrd(v)).map(|(i, _)| i).unwrap_or(0); + per_hyp_best_dur.push((best_dur, dur_log_probs[best_dur])); + // Blank + top-k duration expansions (decoder state unchanged). let mut ds: Vec<(usize, f32)> = (1..dur_log_probs.len()).map(|di| (di, dur_log_probs[di])).collect(); @@ -201,6 +206,7 @@ pub fn transcribe_alsd( let s1_arr = s1_b.view::()?; let mut i = 0; for (bi, ts) in per_hyp_token_scores.iter().enumerate() { + let (best_dur, best_dur_lp) = per_hyp_best_dur[bi]; for &(ti, lp) in ts { let new_dec_out: Value = dec_arr.slice_axis(Axis(0), (i..i + 1).into()).try_into()?; @@ -211,9 +217,9 @@ pub fn transcribe_alsd( let mut new_tokens = active[bi].tokens.clone(); new_tokens.push(ti); next.push(AlsdHyp { - score: active[bi].score + lp, + score: active[bi].score + lp + best_dur_lp, tokens: new_tokens, - current_frame: active[bi].current_frame, // frame unchanged for tokens + current_frame: active[bi].current_frame + best_dur, dec_out: new_dec_out, state_0: new_s0, state_1: new_s1, diff --git a/examples/nemo-parakeet-asr/src/beam.rs b/examples/nemo-parakeet-asr/src/beam.rs index 2be4b4e0ee..271a32a463 100644 --- a/examples/nemo-parakeet-asr/src/beam.rs +++ b/examples/nemo-parakeet-asr/src/beam.rs @@ -107,6 +107,7 @@ pub fn transcribe_beam( // 2. Per-hyp: token scores + duration expansions into kept let mut per_hyp_token_scores: Vec> = Vec::with_capacity(b); + let mut per_hyp_best_dur: Vec<(usize, f32)> = Vec::with_capacity(b); { let logits_arr = logits_b.view::()?; // [b, vocab+dur] for bi in 0..b { @@ -121,6 +122,10 @@ pub fn transcribe_beam( ts.truncate(cfg.beam_size); per_hyp_token_scores.push(ts); + let best_dur = dur_log_probs.iter().enumerate() + .max_by_key(|&(_, &v)| FloatOrd(v)).map(|(i, _)| i).unwrap_or(0); + per_hyp_best_dur.push((best_dur, dur_log_probs[best_dur])); + let mut ds: Vec<(usize, f32)> = (1..dur_log_probs.len()).map(|di| (di, dur_log_probs[di])).collect(); ds.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); @@ -186,6 +191,7 @@ pub fn transcribe_beam( let mut out = Vec::with_capacity(n); let mut i = 0; for (bi, ts) in per_hyp_token_scores.iter().enumerate() { + let (best_dur, best_dur_lp) = per_hyp_best_dur[bi]; for &(ti, lp) in ts { let new_dec_out: Value = dec_arr.slice_axis(Axis(0), (i..i + 1).into()).try_into()?; @@ -196,9 +202,9 @@ pub fn transcribe_beam( let mut new_tokens = hyps[bi].tokens.clone(); new_tokens.push(ti); out.push(Beam { - score: hyps[bi].score + lp, + score: hyps[bi].score + lp + best_dur_lp, tokens: new_tokens, - last_frame: frame_ix, + last_frame: frame_ix + best_dur, dec_out: new_dec_out, state_0: new_s0, state_1: new_s1, diff --git a/examples/nemo-parakeet-asr/src/fbsd.rs b/examples/nemo-parakeet-asr/src/fbsd.rs index ba836a398d..59e3a8cf3a 100644 --- a/examples/nemo-parakeet-asr/src/fbsd.rs +++ b/examples/nemo-parakeet-asr/src/fbsd.rs @@ -119,6 +119,7 @@ pub fn transcribe_fbsd( // 2. Per-hyp: blank+duration → next; non-blank → expansion list let mut next: Vec = completed; let mut per_hyp_token_scores: Vec> = Vec::with_capacity(b); + let mut per_hyp_best_dur: Vec<(usize, f32)> = Vec::with_capacity(b); { let logits_arr = logits_b.view::()?; // [b, vocab+dur] for bi in 0..b { @@ -127,6 +128,10 @@ pub fn transcribe_fbsd( let log_probs = crate::log_softmax(&row_slice[0..=model.blank_id]); let dur_log_probs = crate::log_softmax(&row_slice[model.blank_id + 1..]); + let best_dur = dur_log_probs.iter().enumerate() + .max_by_key(|&(_, &v)| FloatOrd(v)).map(|(i, _)| i).unwrap_or(0); + per_hyp_best_dur.push((best_dur, dur_log_probs[best_dur])); + // Blank + duration expansions advance the frame and reset symbol count let mut ds: Vec<(usize, f32)> = (1..dur_log_probs.len()).map(|di| (di, dur_log_probs[di])).collect(); @@ -144,7 +149,7 @@ pub fn transcribe_fbsd( }); } - // Non-blank expansions stay on the same frame (symbol count checked) + // Non-blank expansions: gated by per-frame symbol cap if active[bi].symbols_this_frame < cfg.fbsd_max_symbols_per_frame { let mut ts: Vec<(usize, f32)> = (0..model.blank_id).map(|ti| (ti, log_probs[ti])).collect(); @@ -204,6 +209,7 @@ pub fn transcribe_fbsd( let s1_arr = s1_b.view::()?; // [2, N, 640] let mut i = 0; for (bi, ts) in per_hyp_token_scores.iter().enumerate() { + let (best_dur, best_dur_lp) = per_hyp_best_dur[bi]; for &(ti, lp) in ts { let new_dec_out: Value = dec_arr.slice_axis(Axis(0), (i..i + 1).into()).try_into()?; @@ -214,10 +220,10 @@ pub fn transcribe_fbsd( let mut new_tokens = active[bi].tokens.clone(); new_tokens.push(ti); next.push(FbsdHyp { - score: active[bi].score + lp, + score: active[bi].score + lp + best_dur_lp, tokens: new_tokens, - current_frame: active[bi].current_frame, - symbols_this_frame: active[bi].symbols_this_frame + 1, + current_frame: active[bi].current_frame + best_dur, + symbols_this_frame: if best_dur > 0 { 0 } else { active[bi].symbols_this_frame + 1 }, dec_out: new_dec_out, state_0: new_s0, state_1: new_s1, diff --git a/examples/nemo-parakeet-asr/src/greedy.rs b/examples/nemo-parakeet-asr/src/greedy.rs index 2e1521227a..9b1b26a3dd 100644 --- a/examples/nemo-parakeet-asr/src/greedy.rs +++ b/examples/nemo-parakeet-asr/src/greedy.rs @@ -48,10 +48,12 @@ pub fn transcribe_greedy( let logits = logits.view::()?; let logits = logits.as_slice().unwrap(); let token_id = crate::argmax(&logits[0..model.blank_id + 1]).unwrap(); + let dur = crate::argmax(&logits[model.blank_id + 1..]).unwrap_or(0); if token_id == model.blank_id { - frame_ix += crate::argmax(&logits[model.blank_id + 1..]).unwrap_or(0).max(1); + frame_ix += dur.max(1); } else { hyp.push(token_id); + frame_ix += dur; token = Value::from_slice(&[1, 1], &[token_id as i32])?; let t = Instant::now(); [token, state_0, state_1] = From c4808c54e10f31d2098a1f4d03857d86e4a358a8 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Thu, 5 Mar 2026 12:35:38 +0000 Subject: [PATCH 18/20] nemo-parakeet-asr: initialize decoder with blank_id as SOS token NeMo sets _SOS = blank_index. With blank_as_pad=True the blank token maps to a zero-vector embedding, giving a neutral start. All four decoders were passing token 0 () instead, feeding a learned non-zero embedding at the first decoder step. --- examples/nemo-parakeet-asr/src/alsd.rs | 2 +- examples/nemo-parakeet-asr/src/beam.rs | 2 +- examples/nemo-parakeet-asr/src/fbsd.rs | 2 +- examples/nemo-parakeet-asr/src/greedy.rs | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/nemo-parakeet-asr/src/alsd.rs b/examples/nemo-parakeet-asr/src/alsd.rs index 4f40b3d3f0..cfe608b548 100644 --- a/examples/nemo-parakeet-asr/src/alsd.rs +++ b/examples/nemo-parakeet-asr/src/alsd.rs @@ -56,7 +56,7 @@ pub fn transcribe_alsd( let max_frames = encoded.shape()[2]; let enc_dim = encoded.shape()[1]; - let init_token = Value::from_slice(&[1, 1], &[0i32])?; + let init_token = Value::from_slice(&[1, 1], &[model.blank_id as i32])?; let init_s0: Value = Array3::::zeros([2, 1, 640]).try_into()?; let init_s1: Value = Array3::::zeros([2, 1, 640]).try_into()?; let t = Instant::now(); diff --git a/examples/nemo-parakeet-asr/src/beam.rs b/examples/nemo-parakeet-asr/src/beam.rs index 271a32a463..2972736ba2 100644 --- a/examples/nemo-parakeet-asr/src/beam.rs +++ b/examples/nemo-parakeet-asr/src/beam.rs @@ -52,7 +52,7 @@ pub fn transcribe_beam( let max_frames = encoded.shape()[2]; let enc_dim = encoded.shape()[1]; - let init_token = Value::from_slice(&[1, 1], &[0i32])?; + let init_token = Value::from_slice(&[1, 1], &[model.blank_id as i32])?; let init_s0: Value = Array3::::zeros([2, 1, 640]).try_into()?; let init_s1: Value = Array3::::zeros([2, 1, 640]).try_into()?; let t = Instant::now(); diff --git a/examples/nemo-parakeet-asr/src/fbsd.rs b/examples/nemo-parakeet-asr/src/fbsd.rs index 59e3a8cf3a..f1e1b36ff3 100644 --- a/examples/nemo-parakeet-asr/src/fbsd.rs +++ b/examples/nemo-parakeet-asr/src/fbsd.rs @@ -57,7 +57,7 @@ pub fn transcribe_fbsd( let max_frames = encoded.shape()[2]; let enc_dim = encoded.shape()[1]; - let init_token = Value::from_slice(&[1, 1], &[0i32])?; + let init_token = Value::from_slice(&[1, 1], &[model.blank_id as i32])?; let init_s0: Value = Array3::::zeros([2, 1, 640]).try_into()?; let init_s1: Value = Array3::::zeros([2, 1, 640]).try_into()?; let t = Instant::now(); diff --git a/examples/nemo-parakeet-asr/src/greedy.rs b/examples/nemo-parakeet-asr/src/greedy.rs index 9b1b26a3dd..e4deee4045 100644 --- a/examples/nemo-parakeet-asr/src/greedy.rs +++ b/examples/nemo-parakeet-asr/src/greedy.rs @@ -30,7 +30,7 @@ pub fn transcribe_greedy( let mut hyp = vec![]; let mut frame_ix = 0; - let mut token = Value::from_slice(&[1, 1], &[0i32])?; + let mut token = Value::from_slice(&[1, 1], &[model.blank_id as i32])?; let mut state_0: Value = Array3::::zeros([2, 1, 640]).try_into()?; let mut state_1: Value = Array3::::zeros([2, 1, 640]).try_into()?; From 842776965f7df2dd764e90a0250588c8f18574ef Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Thu, 5 Mar 2026 13:41:13 +0000 Subject: [PATCH 19/20] nemo-parakeet-asr: wrap submodel runs in named inline(never) methods for profiling Makes TdtModel Runnables private; exposes run_preprocessor, run_encoder, run_decoder, run_joint as #[inline(never)] pub(crate) methods so profiler stack traces show named frames instead of anonymous run() calls. --- examples/nemo-parakeet-asr/src/alsd.rs | 15 +- examples/nemo-parakeet-asr/src/beam.rs | 15 +- examples/nemo-parakeet-asr/src/fbsd.rs | 15 +- examples/nemo-parakeet-asr/src/greedy.rs | 14 +- examples/nemo-parakeet-asr/src/main.rs | 202 ++++++++++++++++------- 5 files changed, 162 insertions(+), 99 deletions(-) diff --git a/examples/nemo-parakeet-asr/src/alsd.rs b/examples/nemo-parakeet-asr/src/alsd.rs index cfe608b548..a40abb0ac0 100644 --- a/examples/nemo-parakeet-asr/src/alsd.rs +++ b/examples/nemo-parakeet-asr/src/alsd.rs @@ -42,13 +42,11 @@ pub fn transcribe_alsd( let len: Value = arr1(&[wav.len() as i64]).try_into()?; let t = Instant::now(); - let [features, feat_len] = - model.preprocessor.run([samples, len])?.try_into().unwrap(); + let [features, feat_len] = model.run_preprocessor(samples, len)?; stats.preprocessor.record(1, t.elapsed()); let t = Instant::now(); - let [encoded, _lens] = - model.encoder.run([features, feat_len])?.try_into().unwrap(); + let [encoded, _lens] = model.run_encoder(features, feat_len)?; stats.encoder.record(1, t.elapsed()); let encoded: ArrayD = encoded.view()?.into_owned(); @@ -60,8 +58,7 @@ pub fn transcribe_alsd( let init_s0: Value = Array3::::zeros([2, 1, 640]).try_into()?; let init_s1: Value = Array3::::zeros([2, 1, 640]).try_into()?; let t = Instant::now(); - let [dec_out, state_0, state_1] = - model.decoder.run([init_token, init_s0, init_s1])?.try_into().unwrap(); + let [dec_out, state_0, state_1] = model.run_decoder(init_token, init_s0, init_s1)?; stats.decoder.record(batch, t.elapsed()); let mut beam: Vec = vec![AlsdHyp { @@ -114,8 +111,7 @@ pub fn transcribe_alsd( .try_into()? }; let t = Instant::now(); - let [logits_b] = - model.joint.run([frame_batch, dec_out_batch])?.try_into().unwrap(); + let logits_b = model.run_joint(frame_batch, dec_out_batch)?; stats.joint.record(b, t.elapsed()); // 2. Per-hyp: blank+duration children (no decoder update) and @@ -197,8 +193,7 @@ pub fn transcribe_alsd( }; let t = Instant::now(); - let [dec_out_b, s0_b, s1_b] = - model.decoder.run([tokens_batch, s0_batch, s1_batch])?.try_into().unwrap(); + let [dec_out_b, s0_b, s1_b] = model.run_decoder(tokens_batch, s0_batch, s1_batch)?; stats.decoder.record(n, t.elapsed()); let dec_arr = dec_out_b.view::()?; diff --git a/examples/nemo-parakeet-asr/src/beam.rs b/examples/nemo-parakeet-asr/src/beam.rs index 2972736ba2..4ad2d367dc 100644 --- a/examples/nemo-parakeet-asr/src/beam.rs +++ b/examples/nemo-parakeet-asr/src/beam.rs @@ -38,13 +38,11 @@ pub fn transcribe_beam( let len: Value = arr1(&[wav.len() as i64]).try_into()?; let t = Instant::now(); - let [features, feat_len] = - model.preprocessor.run([samples, len])?.try_into().unwrap(); + let [features, feat_len] = model.run_preprocessor(samples, len)?; stats.preprocessor.record(1, t.elapsed()); let t = Instant::now(); - let [encoded, _lens] = - model.encoder.run([features, feat_len])?.try_into().unwrap(); + let [encoded, _lens] = model.run_encoder(features, feat_len)?; stats.encoder.record(1, t.elapsed()); let encoded: ArrayD = encoded.view()?.into_owned(); @@ -56,8 +54,7 @@ pub fn transcribe_beam( let init_s0: Value = Array3::::zeros([2, 1, 640]).try_into()?; let init_s1: Value = Array3::::zeros([2, 1, 640]).try_into()?; let t = Instant::now(); - let [dec_out, state_0, state_1] = - model.decoder.run([init_token, init_s0, init_s1])?.try_into().unwrap(); + let [dec_out, state_0, state_1] = model.run_decoder(init_token, init_s0, init_s1)?; stats.decoder.record(batch, t.elapsed()); let mut all_beams: Vec = vec![Beam { @@ -101,8 +98,7 @@ pub fn transcribe_beam( .try_into()? }; let t = Instant::now(); - let [logits_b] = - model.joint.run([frame_batch, dec_out_batch])?.try_into().unwrap(); + let logits_b = model.run_joint(frame_batch, dec_out_batch)?; stats.joint.record(b, t.elapsed()); // 2. Per-hyp: token scores + duration expansions into kept @@ -179,8 +175,7 @@ pub fn transcribe_beam( }; let t = Instant::now(); - let [dec_out_b, s0_b, s1_b] = - model.decoder.run([tokens_batch, s0_batch, s1_batch])?.try_into().unwrap(); + let [dec_out_b, s0_b, s1_b] = model.run_decoder(tokens_batch, s0_batch, s1_batch)?; stats.decoder.record(n, t.elapsed()); // 4. Slice and build new beams diff --git a/examples/nemo-parakeet-asr/src/fbsd.rs b/examples/nemo-parakeet-asr/src/fbsd.rs index f1e1b36ff3..383e162693 100644 --- a/examples/nemo-parakeet-asr/src/fbsd.rs +++ b/examples/nemo-parakeet-asr/src/fbsd.rs @@ -43,13 +43,11 @@ pub fn transcribe_fbsd( let len: Value = arr1(&[wav.len() as i64]).try_into()?; let t = Instant::now(); - let [features, feat_len] = - model.preprocessor.run([samples, len])?.try_into().unwrap(); + let [features, feat_len] = model.run_preprocessor(samples, len)?; stats.preprocessor.record(1, t.elapsed()); let t = Instant::now(); - let [encoded, _lens] = - model.encoder.run([features, feat_len])?.try_into().unwrap(); + let [encoded, _lens] = model.run_encoder(features, feat_len)?; stats.encoder.record(1, t.elapsed()); let encoded: ArrayD = encoded.view()?.into_owned(); @@ -61,8 +59,7 @@ pub fn transcribe_fbsd( let init_s0: Value = Array3::::zeros([2, 1, 640]).try_into()?; let init_s1: Value = Array3::::zeros([2, 1, 640]).try_into()?; let t = Instant::now(); - let [dec_out, state_0, state_1] = - model.decoder.run([init_token, init_s0, init_s1])?.try_into().unwrap(); + let [dec_out, state_0, state_1] = model.run_decoder(init_token, init_s0, init_s1)?; stats.decoder.record(batch, t.elapsed()); let mut beam: Vec = vec![FbsdHyp { @@ -112,8 +109,7 @@ pub fn transcribe_fbsd( .try_into()? }; let t = Instant::now(); - let [logits_b] = - model.joint.run([frame_batch, dec_out_batch])?.try_into().unwrap(); + let logits_b = model.run_joint(frame_batch, dec_out_batch)?; stats.joint.record(b, t.elapsed()); // 2. Per-hyp: blank+duration → next; non-blank → expansion list @@ -199,8 +195,7 @@ pub fn transcribe_fbsd( }; let t = Instant::now(); - let [dec_out_b, s0_b, s1_b] = - model.decoder.run([tokens_batch, s0_batch, s1_batch])?.try_into().unwrap(); + let [dec_out_b, s0_b, s1_b] = model.run_decoder(tokens_batch, s0_batch, s1_batch)?; stats.decoder.record(n, t.elapsed()); // 4. Slice per-expansion outputs and push into next diff --git a/examples/nemo-parakeet-asr/src/greedy.rs b/examples/nemo-parakeet-asr/src/greedy.rs index e4deee4045..97c3e0896c 100644 --- a/examples/nemo-parakeet-asr/src/greedy.rs +++ b/examples/nemo-parakeet-asr/src/greedy.rs @@ -15,13 +15,11 @@ pub fn transcribe_greedy( let len: Value = arr1(&[wav.len() as i64]).try_into()?; let t = Instant::now(); - let [features, feat_len] = - model.preprocessor.run([samples, len])?.try_into().unwrap(); + let [features, feat_len] = model.run_preprocessor(samples, len)?; stats.preprocessor.record(1, t.elapsed()); let t = Instant::now(); - let [encoded, _lens] = - model.encoder.run([features, feat_len])?.try_into().unwrap(); + let [encoded, _lens] = model.run_encoder(features, feat_len)?; stats.encoder.record(1, t.elapsed()); let encoded: ArrayD = encoded.view()?.into_owned(); @@ -35,15 +33,14 @@ pub fn transcribe_greedy( let mut state_1: Value = Array3::::zeros([2, 1, 640]).try_into()?; let t = Instant::now(); - [token, state_0, state_1] = - model.decoder.run([token, state_0, state_1])?.try_into().unwrap(); + [token, state_0, state_1] = model.run_decoder(token, state_0, state_1)?; stats.decoder.record(1, t.elapsed()); while hyp.len() < max_len && frame_ix < max_frames { let frame: Value = encoded.slice_axis(Axis(2), (frame_ix..frame_ix + 1).into()).try_into()?; let t = Instant::now(); - let [logits] = model.joint.run([frame, token.clone()])?.try_into().unwrap(); + let logits = model.run_joint(frame, token.clone())?; stats.joint.record(1, t.elapsed()); let logits = logits.view::()?; let logits = logits.as_slice().unwrap(); @@ -56,8 +53,7 @@ pub fn transcribe_greedy( frame_ix += dur; token = Value::from_slice(&[1, 1], &[token_id as i32])?; let t = Instant::now(); - [token, state_0, state_1] = - model.decoder.run([token, state_0, state_1])?.try_into().unwrap(); + [token, state_0, state_1] = model.run_decoder(token, state_0, state_1)?; stats.decoder.record(1, t.elapsed()); } } diff --git a/examples/nemo-parakeet-asr/src/main.rs b/examples/nemo-parakeet-asr/src/main.rs index 2a224fcbd1..bdc8a8118b 100644 --- a/examples/nemo-parakeet-asr/src/main.rs +++ b/examples/nemo-parakeet-asr/src/main.rs @@ -10,13 +10,18 @@ use itertools::Itertools; use tract_rs::prelude::*; use tract_rs::Nnef; -mod greedy; +mod alsd; mod beam; mod fbsd; -mod alsd; +mod greedy; #[derive(clap::ValueEnum, Clone)] -enum Decoder { Greedy, Beam, Fbsd, Alsd } +enum Decoder { + Greedy, + Beam, + Fbsd, + Alsd, +} #[derive(Parser)] #[command(about = "NeMo Parakeet ASR inference")] @@ -105,35 +110,60 @@ pub(crate) fn log_softmax(xs: &[f32]) -> Vec { } pub(crate) struct TdtModel { - pub(crate) preprocessor: Runnable, - pub(crate) encoder: Runnable, - pub(crate) decoder: Runnable, - pub(crate) joint: Runnable, + preprocessor: Runnable, + encoder: Runnable, + decoder: Runnable, + joint: Runnable, pub(crate) vocab: Vec, pub(crate) blank_id: usize, } impl TdtModel { + #[inline(never)] + pub(crate) fn run_preprocessor(&self, samples: Value, len: Value) -> Result<[Value; 2]> { + Ok(self.preprocessor.run([samples, len])?.try_into().unwrap()) + } + + #[inline(never)] + pub(crate) fn run_encoder(&self, features: Value, feat_len: Value) -> Result<[Value; 2]> { + Ok(self.encoder.run([features, feat_len])?.try_into().unwrap()) + } + + #[inline(never)] + pub(crate) fn run_decoder(&self, token: Value, s0: Value, s1: Value) -> Result<[Value; 3]> { + Ok(self.decoder.run([token, s0, s1])?.try_into().unwrap()) + } + + #[inline(never)] + pub(crate) fn run_joint(&self, frame: Value, dec_out: Value) -> Result { + let [logits] = self.joint.run([frame, dec_out])?.try_into().unwrap(); + Ok(logits) + } + fn load(model_dir: impl AsRef, nnef: &Nnef, gpu: &Runtime) -> Result { let model_dir = model_dir.as_ref(); let config: serde_json::Value = serde_json::from_reader(File::open(model_dir.join("model_config.json"))?)?; - let blank_id = - config.pointer("/decoder/vocab_size").unwrap().as_i64().unwrap() as usize; + let blank_id = config.pointer("/decoder/vocab_size").unwrap().as_i64().unwrap() as usize; let vocab = config.pointer("/joint/vocabulary").unwrap().as_array().unwrap(); let vocab: Vec = vocab.iter().map(|v| v.as_str().unwrap().to_owned()).collect(); - let preprocessor = - nnef.load(model_dir.join("preprocessor.nnef.tgz"))?.into_runnable()?; + let mut preprocessor = nnef.load(model_dir.join("preprocessor.nnef.tgz"))?; + preprocessor.transform("transformers-detect-all")?; + let preprocessor = gpu.prepare(preprocessor)?; let mut encoder = nnef.load(model_dir.join("encoder.nnef.tgz"))?; encoder.transform("transformers-detect-all")?; let encoder = gpu.prepare(encoder)?; - let decoder = nnef.load(model_dir.join("decoder.nnef.tgz"))?; + let mut decoder = nnef.load(model_dir.join("decoder.nnef.tgz"))?; + decoder.transform("transformers-detect-all")?; + decoder.concretize_symbols([("T", 1)])?; let decoder = gpu.prepare(decoder)?; - let joint = nnef.load(model_dir.join("joint.nnef.tgz"))?; + let mut joint = nnef.load(model_dir.join("joint.nnef.tgz"))?; + joint.transform("transformers-detect-all")?; + joint.concretize_symbols([("R", 1), ("U", 1)])?; let joint = gpu.prepare(joint)?; Ok(TdtModel { preprocessor, encoder, decoder, joint, vocab, blank_id }) @@ -171,9 +201,7 @@ fn collect_wavs(inputs: &[PathBuf]) -> Vec { fn load_wav(path: &Path) -> Result<(Vec, f64)> { let mut wav_reader = hound::WavReader::open(path)?; let sample_rate = wav_reader.spec().sample_rate as f64; - let samples: Vec = wav_reader.samples::() - .map(|x| x.unwrap() as f32) - .collect(); + let samples: Vec = wav_reader.samples::().map(|x| x.unwrap() as f32).collect(); let audio_duration_s = samples.len() as f64 / sample_rate; Ok((samples, audio_duration_s)) } @@ -185,18 +213,28 @@ fn clean(s: &str) -> String { fn decoder_label(args: &Args) -> String { match args.decoder { Decoder::Greedy => "greedy".to_string(), - Decoder::Beam => format!("beam beam_size={} beam_dur_k={}", args.beam.beam_size, args.beam.beam_dur_k), - Decoder::Fbsd => format!("fbsd beam_size={} beam_dur_k={} max_symbols={}", args.fbsd.fbsd_beam_size, args.fbsd.fbsd_beam_dur_k, args.fbsd.fbsd_max_symbols_per_frame), - Decoder::Alsd => format!("alsd beam_size={} beam_dur_k={} u_max={}", args.alsd.alsd_beam_size, args.alsd.alsd_beam_dur_k, args.alsd.alsd_u_max), + Decoder::Beam => { + format!("beam beam_size={} beam_dur_k={}", args.beam.beam_size, args.beam.beam_dur_k) + } + Decoder::Fbsd => format!( + "fbsd beam_size={} beam_dur_k={} max_symbols={}", + args.fbsd.fbsd_beam_size, + args.fbsd.fbsd_beam_dur_k, + args.fbsd.fbsd_max_symbols_per_frame + ), + Decoder::Alsd => format!( + "alsd beam_size={} beam_dur_k={} u_max={}", + args.alsd.alsd_beam_size, args.alsd.alsd_beam_dur_k, args.alsd.alsd_u_max + ), } } fn run_decoder(model: &TdtModel, wav: &[f32], args: &Args) -> Result<(String, DecodingStats)> { match args.decoder { Decoder::Greedy => greedy::transcribe_greedy(model, wav), - Decoder::Beam => beam::transcribe_beam(model, wav, &args.beam), - Decoder::Fbsd => fbsd::transcribe_fbsd(model, wav, &args.fbsd), - Decoder::Alsd => alsd::transcribe_alsd(model, wav, &args.alsd), + Decoder::Beam => beam::transcribe_beam(model, wav, &args.beam), + Decoder::Fbsd => fbsd::transcribe_fbsd(model, wav, &args.fbsd), + Decoder::Alsd => alsd::transcribe_alsd(model, wav, &args.alsd), } } @@ -214,17 +252,20 @@ impl SearchConfig { match self { SearchConfig::Greedy => "greedy".to_owned(), SearchConfig::Beam(c) => format!("beam_{}_{}", c.beam_size, c.beam_dur_k), - SearchConfig::Fbsd(c) => format!("fbsd_{}_{}_{}", c.fbsd_beam_size, c.fbsd_beam_dur_k, c.fbsd_max_symbols_per_frame), + SearchConfig::Fbsd(c) => format!( + "fbsd_{}_{}_{}", + c.fbsd_beam_size, c.fbsd_beam_dur_k, c.fbsd_max_symbols_per_frame + ), SearchConfig::Alsd(c) => format!("alsd_{}_{}", c.alsd_beam_size, c.alsd_beam_dur_k), } } fn run(&self, model: &TdtModel, wav: &[f32]) -> Result<(String, DecodingStats)> { match self { - SearchConfig::Greedy => greedy::transcribe_greedy(model, wav), - SearchConfig::Beam(c) => beam::transcribe_beam(model, wav, c), - SearchConfig::Fbsd(c) => fbsd::transcribe_fbsd(model, wav, c), - SearchConfig::Alsd(c) => alsd::transcribe_alsd(model, wav, c), + SearchConfig::Greedy => greedy::transcribe_greedy(model, wav), + SearchConfig::Beam(c) => beam::transcribe_beam(model, wav, c), + SearchConfig::Fbsd(c) => fbsd::transcribe_fbsd(model, wav, c), + SearchConfig::Alsd(c) => alsd::transcribe_alsd(model, wav, c), } } } @@ -234,10 +275,18 @@ fn search_configs() -> Vec { SearchConfig::Beam(beam::BeamConfig { beam_size, beam_dur_k }) } fn f(beam_size: usize, beam_dur_k: usize, max_symbols: usize) -> SearchConfig { - SearchConfig::Fbsd(fbsd::FbsdConfig { fbsd_beam_size: beam_size, fbsd_beam_dur_k: beam_dur_k, fbsd_max_symbols_per_frame: max_symbols }) + SearchConfig::Fbsd(fbsd::FbsdConfig { + fbsd_beam_size: beam_size, + fbsd_beam_dur_k: beam_dur_k, + fbsd_max_symbols_per_frame: max_symbols, + }) } fn a(beam_size: usize, beam_dur_k: usize) -> SearchConfig { - SearchConfig::Alsd(alsd::AlsdConfig { alsd_beam_size: beam_size, alsd_beam_dur_k: beam_dur_k, alsd_u_max: 50 }) + SearchConfig::Alsd(alsd::AlsdConfig { + alsd_beam_size: beam_size, + alsd_beam_dur_k: beam_dur_k, + alsd_u_max: 50, + }) } vec![ SearchConfig::Greedy, @@ -280,12 +329,8 @@ fn param_search(model: &TdtModel, wavs: &[PathBuf]) -> Result<()> { let configs = search_configs(); let mp = MultiProgress::new(); - let cfg_style = ProgressStyle::with_template( - "Configs {bar:40} {pos:>3}/{len} {msg}" - ).unwrap(); - let file_style = ProgressStyle::with_template( - "Files {bar:40} {pos:>3}/{len}" - ).unwrap(); + let cfg_style = ProgressStyle::with_template("Configs {bar:40} {pos:>3}/{len} {msg}").unwrap(); + let file_style = ProgressStyle::with_template("Files {bar:40} {pos:>3}/{len}").unwrap(); let cfg_bar = mp.add(ProgressBar::new(configs.len() as u64)); cfg_bar.set_style(cfg_style); @@ -312,8 +357,13 @@ fn param_search(model: &TdtModel, wavs: &[PathBuf]) -> Result<()> { for wav_path in wavs { let (wav, audio_s) = load_wav(wav_path)?; - let reference = std::fs::read_to_string(wav_path.with_extension("txt")) - .with_context(|| format!("no ground-truth transcript for {} (run with --write-gt first)", wav_path.display()))?; + let reference = + std::fs::read_to_string(wav_path.with_extension("txt")).with_context(|| { + format!( + "no ground-truth transcript for {} (run with --write-gt first)", + wav_path.display() + ) + })?; let reference = reference.trim_end_matches('\n').to_owned(); let t = Instant::now(); @@ -323,10 +373,12 @@ fn param_search(model: &TdtModel, wavs: &[PathBuf]) -> Result<()> { total_audio_s += audio_s; total_elapsed_s += elapsed; total += 1; - if clean(&transcript) == reference { exact += 1; } - pre_us += stats.preprocessor.total_us; - enc_us += stats.encoder.total_us; - dec_us += stats.decoder.total_us; + if clean(&transcript) == reference { + exact += 1; + } + pre_us += stats.preprocessor.total_us; + enc_us += stats.encoder.total_us; + dec_us += stats.decoder.total_us; joint_us += stats.joint.total_us; file_bar.inc(1); @@ -338,9 +390,19 @@ fn param_search(model: &TdtModel, wavs: &[PathBuf]) -> Result<()> { let pct = |us: u64| if total_us > 0 { us as f64 / total_us as f64 * 100.0 } else { 0.0 }; let nn_us = pre_us + enc_us + dec_us + joint_us; let host_us = total_us.saturating_sub(nn_us); - mp.suspend(|| println!("{}\t{:.4}\t{:.1}\t{:.1}\t{:.1}\t{:.1}\t{:.1}\t{:.1}", - label, epr, rtfx, - pct(pre_us), pct(enc_us), pct(dec_us), pct(joint_us), pct(host_us))); + mp.suspend(|| { + println!( + "{}\t{:.4}\t{:.1}\t{:.1}\t{:.1}\t{:.1}\t{:.1}\t{:.1}", + label, + epr, + rtfx, + pct(pre_us), + pct(enc_us), + pct(dec_us), + pct(joint_us), + pct(host_us) + ) + }); cfg_bar.inc(1); } @@ -362,9 +424,7 @@ fn write_gt(model: &TdtModel, wavs: &[PathBuf], no_details: bool) -> anyhow::Res } let pb = if no_details { let pb = ProgressBar::new(wavs.len() as u64); - pb.set_style( - ProgressStyle::with_template("Writing GT {bar:40} {pos:>3}/{len}").unwrap() - ); + pb.set_style(ProgressStyle::with_template("Writing GT {bar:40} {pos:>3}/{len}").unwrap()); Some(pb) } else { None @@ -380,7 +440,9 @@ fn write_gt(model: &TdtModel, wavs: &[PathBuf], no_details: bool) -> anyhow::Res eprintln!("{} -> {}", wav_path.display(), txt_path.display()); } } - if let Some(pb) = pb { pb.finish(); } + if let Some(pb) = pb { + pb.finish(); + } eprintln!("wrote {} transcript(s)", wavs.len()); Ok(()) } @@ -417,9 +479,7 @@ fn main() -> anyhow::Result<()> { let pb = if args.no_details { let pb = ProgressBar::new(wavs.len() as u64); - pb.set_style( - ProgressStyle::with_template("Decoding {bar:40} {pos:>3}/{len}").unwrap() - ); + pb.set_style(ProgressStyle::with_template("Decoding {bar:40} {pos:>3}/{len}").unwrap()); Some(pb) } else { None @@ -431,8 +491,13 @@ fn main() -> anyhow::Result<()> { for wav_path in &wavs { let (wav, audio_s) = load_wav(wav_path)?; - let reference = std::fs::read_to_string(wav_path.with_extension("txt")) - .with_context(|| format!("no ground-truth transcript for {} (run with --write-gt first)", wav_path.display()))?; + let reference = + std::fs::read_to_string(wav_path.with_extension("txt")).with_context(|| { + format!( + "no ground-truth transcript for {} (run with --write-gt first)", + wav_path.display() + ) + })?; let t = Instant::now(); let (transcript, stats) = run_decoder(&model, &wav, &args)?; @@ -442,7 +507,9 @@ fn main() -> anyhow::Result<()> { total_elapsed_s += elapsed; let ok = clean(&transcript) == reference.trim_end_matches('\n'); - if ok { exact += 1; } + if ok { + exact += 1; + } if let Some(ref pb) = pb { pb.inc(1); @@ -450,15 +517,22 @@ fn main() -> anyhow::Result<()> { let mark = if ok { "\x1b[32m✓\x1b[0m" } else { "\x1b[31m✗\x1b[0m" }; let elapsed_ms = elapsed * 1000.0; let nn_ms = stats.nn_ms(); - eprintln!("{} {mark} {audio_s:.1}s {elapsed_ms:.1}ms RTFx={:.1} {}", - wav_path.display(), audio_s / elapsed, clean(&transcript)); + eprintln!( + "{} {mark} {audio_s:.1}s {elapsed_ms:.1}ms RTFx={:.1} {}", + wav_path.display(), + audio_s / elapsed, + clean(&transcript) + ); if !ok { eprintln!(" ref: {}", clean(&reference)); eprintln!(" got: {}", clean(&transcript)); } if args.stats { - eprintln!(" {elapsed_ms:.1}ms RTFx={:.1} nn={nn_ms:.1}ms host={:.1}ms", - audio_s / elapsed, elapsed_ms - nn_ms); + eprintln!( + " {elapsed_ms:.1}ms RTFx={:.1} nn={nn_ms:.1}ms host={:.1}ms", + audio_s / elapsed, + elapsed_ms - nn_ms + ); eprintln!(" [preprocessor] {:?}", stats.preprocessor); eprintln!(" [encoder] {:?}", stats.encoder); eprintln!(" [decoder] {:?}", stats.decoder); @@ -467,8 +541,16 @@ fn main() -> anyhow::Result<()> { } } - if let Some(pb) = pb { pb.finish(); } - eprintln!("{} {}/{} exact RTFx={:.1}", label, exact, wavs.len(), total_audio_s / total_elapsed_s); + if let Some(pb) = pb { + pb.finish(); + } + eprintln!( + "{} {}/{} exact RTFx={:.1}", + label, + exact, + wavs.len(), + total_audio_s / total_elapsed_s + ); Ok(()) } From 4ced0d36e0cd523dc79741ddf7a17e52604249e8 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Thu, 5 Mar 2026 16:01:08 +0100 Subject: [PATCH 20/20] use prob of best skip, even if it is 0 --- examples/nemo-parakeet-asr/src/alsd.rs | 46 ++++++++++---------------- 1 file changed, 18 insertions(+), 28 deletions(-) diff --git a/examples/nemo-parakeet-asr/src/alsd.rs b/examples/nemo-parakeet-asr/src/alsd.rs index a40abb0ac0..1326e3daf8 100644 --- a/examples/nemo-parakeet-asr/src/alsd.rs +++ b/examples/nemo-parakeet-asr/src/alsd.rs @@ -61,14 +61,8 @@ pub fn transcribe_alsd( let [dec_out, state_0, state_1] = model.run_decoder(init_token, init_s0, init_s1)?; stats.decoder.record(batch, t.elapsed()); - let mut beam: Vec = vec![AlsdHyp { - score: 0.0, - tokens: vec![], - current_frame: 0, - dec_out, - state_0, - state_1, - }]; + let mut beam: Vec = + vec![AlsdHyp { score: 0.0, tokens: vec![], current_frame: 0, dec_out, state_0, state_1 }]; let mut final_hyps: Vec = Vec::new(); @@ -102,10 +96,8 @@ pub fn transcribe_alsd( .try_into()? }; let dec_out_batch: Value = { - let views: Vec<_> = active - .iter() - .map(|h| h.dec_out.view::()) - .collect::>>()?; + let views: Vec<_> = + active.iter().map(|h| h.dec_out.view::()).collect::>>()?; let hidden = views[0].shape()[1]; Array3::::from_shape_fn((b, hidden, 1), |(bi, h, _)| views[bi][[0, h, 0]]) .try_into()? @@ -127,20 +119,24 @@ pub fn transcribe_alsd( let log_probs = crate::log_softmax(&row_slice[0..=model.blank_id]); let dur_log_probs = crate::log_softmax(&row_slice[model.blank_id + 1..]); - let best_dur = dur_log_probs.iter().enumerate() - .max_by_key(|&(_, &v)| FloatOrd(v)).map(|(i, _)| i).unwrap_or(0); + let best_dur = dur_log_probs + .iter() + .enumerate() + .max_by_key(|&(_, &v)| FloatOrd(v)) + .map(|(i, _)| i) + .unwrap_or(0); per_hyp_best_dur.push((best_dur, dur_log_probs[best_dur])); // Blank + top-k duration expansions (decoder state unchanged). let mut ds: Vec<(usize, f32)> = - (1..dur_log_probs.len()).map(|di| (di, dur_log_probs[di])).collect(); + (0..dur_log_probs.len()).map(|di| (di, dur_log_probs[di])).collect(); ds.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); ds.truncate(cfg.alsd_beam_dur_k); for (di, dlp) in ds { next.push(AlsdHyp { score: active[bi].score + log_probs[model.blank_id] + dlp, tokens: active[bi].tokens.clone(), - current_frame: active[bi].current_frame + di, + current_frame: active[bi].current_frame + di.max(1), dec_out: active[bi].dec_out.clone(), state_0: active[bi].state_0.clone(), state_1: active[bi].state_1.clone(), @@ -172,20 +168,16 @@ pub fn transcribe_alsd( let tokens_batch: Value = Array2::::from_shape_fn((n, 1), |(i, _)| token_ids[i]).try_into()?; let s0_batch: Value = { - let views: Vec<_> = active - .iter() - .map(|h| h.state_0.view::()) - .collect::>>()?; + let views: Vec<_> = + active.iter().map(|h| h.state_0.view::()).collect::>>()?; Array3::::from_shape_fn((2, n, 640), |(l, i, h)| { views[expansion_hyp_idxs[i]][[l, 0, h]] }) .try_into()? }; let s1_batch: Value = { - let views: Vec<_> = active - .iter() - .map(|h| h.state_1.view::()) - .collect::>>()?; + let views: Vec<_> = + active.iter().map(|h| h.state_1.view::()).collect::>>()?; Array3::::from_shape_fn((2, n, 640), |(l, i, h)| { views[expansion_hyp_idxs[i]][[l, 0, h]] }) @@ -205,10 +197,8 @@ pub fn transcribe_alsd( for &(ti, lp) in ts { let new_dec_out: Value = dec_arr.slice_axis(Axis(0), (i..i + 1).into()).try_into()?; - let new_s0: Value = - s0_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; - let new_s1: Value = - s1_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; + let new_s0: Value = s0_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; + let new_s1: Value = s1_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; let mut new_tokens = active[bi].tokens.clone(); new_tokens.push(ti); next.push(AlsdHyp {