diff --git a/examples/nemo-parakeet-asr/Cargo.toml b/examples/nemo-parakeet-asr/Cargo.toml index a7fd2d6e53..26cc34caa4 100644 --- a/examples/nemo-parakeet-asr/Cargo.toml +++ b/examples/nemo-parakeet-asr/Cargo.toml @@ -5,8 +5,10 @@ edition = "2024" [dependencies] anyhow.workspace = true +clap = { version = "4", features = ["derive"] } float-ord.workspace = true hound = "3.5.1" +indicatif = "0.17" itertools.workspace = true serde_json.workspace = true tract-rs.workspace = true diff --git a/examples/nemo-parakeet-asr/src/alsd.rs b/examples/nemo-parakeet-asr/src/alsd.rs new file mode 100644 index 0000000000..1326e3daf8 --- /dev/null +++ b/examples/nemo-parakeet-asr/src/alsd.rs @@ -0,0 +1,235 @@ +use std::time::Instant; + +use anyhow::*; +use clap::Args; +use float_ord::FloatOrd; +use itertools::Itertools; +use tract_rs::prelude::tract_ndarray::prelude::*; +use tract_rs::prelude::*; + +#[derive(Args, Clone)] +pub struct AlsdConfig { + /// Beam width for ALSD decoding + #[arg(long, default_value_t = 4)] + pub alsd_beam_size: usize, + + /// Duration candidates per hypothesis in ALSD + #[arg(long, default_value_t = 2)] + pub alsd_beam_dur_k: usize, + + /// Maximum expected output token count (U_max in the paper) + #[arg(long, default_value_t = 50)] + pub alsd_u_max: usize, +} + +struct AlsdHyp { + score: f32, + tokens: Vec, + current_frame: usize, + dec_out: Value, + state_0: Value, + state_1: Value, +} + +pub fn transcribe_alsd( + model: &crate::TdtModel, + wav: &[f32], + cfg: &AlsdConfig, +) -> Result<(String, crate::DecodingStats)> { + let mut stats = crate::DecodingStats::default(); + + let samples: Value = Value::from_slice(&[1, wav.len()], wav)?; + let len: Value = arr1(&[wav.len() as i64]).try_into()?; + + let t = Instant::now(); + let [features, feat_len] = model.run_preprocessor(samples, len)?; + stats.preprocessor.record(1, t.elapsed()); + + let t = Instant::now(); + let [encoded, _lens] = model.run_encoder(features, feat_len)?; + stats.encoder.record(1, t.elapsed()); + + let encoded: ArrayD = encoded.view()?.into_owned(); + let batch = encoded.shape()[0]; + let max_frames = encoded.shape()[2]; + let enc_dim = encoded.shape()[1]; + + let init_token = Value::from_slice(&[1, 1], &[model.blank_id as i32])?; + let init_s0: Value = Array3::::zeros([2, 1, 640]).try_into()?; + let init_s1: Value = Array3::::zeros([2, 1, 640]).try_into()?; + let t = Instant::now(); + let [dec_out, state_0, state_1] = model.run_decoder(init_token, init_s0, init_s1)?; + stats.decoder.record(batch, t.elapsed()); + + let mut beam: Vec = + vec![AlsdHyp { score: 0.0, tokens: vec![], current_frame: 0, dec_out, state_0, state_1 }]; + + let mut final_hyps: Vec = Vec::new(); + + // Outer loop: alignment steps 0 .. T + U_max. + // All hypotheses in `beam` are at the same alignment step. + // At each step every active hyp (current_frame < T) gets one expansion + // (blank+duration OR one non-blank token), so a single step advances + // the alignment length by exactly 1. + for _step in 0..(max_frames + cfg.alsd_u_max) { + // Separate completed hypotheses (exhausted all frames) from active ones. + let mut active: Vec = Vec::new(); + for h in beam.drain(..) { + if h.current_frame >= max_frames { + final_hyps.push(h); + } else { + active.push(h); + } + } + if active.is_empty() { + break; + } + + let b = active.len(); + + // 1. JOINT: batched over all active hypotheses (each at its own frame). + let frame_batch: Value = { + let enc_arr = encoded.view(); + Array3::::from_shape_fn((b, enc_dim, 1), |(bi, e, _)| { + enc_arr[[0, e, active[bi].current_frame]] + }) + .try_into()? + }; + let dec_out_batch: Value = { + let views: Vec<_> = + active.iter().map(|h| h.dec_out.view::()).collect::>>()?; + let hidden = views[0].shape()[1]; + Array3::::from_shape_fn((b, hidden, 1), |(bi, h, _)| views[bi][[0, h, 0]]) + .try_into()? + }; + let t = Instant::now(); + let logits_b = model.run_joint(frame_batch, dec_out_batch)?; + stats.joint.record(b, t.elapsed()); + + // 2. Per-hyp: blank+duration children (no decoder update) and + // non-blank token candidates (need decoder update). + let mut next: Vec = Vec::new(); + let mut per_hyp_token_scores: Vec> = Vec::with_capacity(b); + let mut per_hyp_best_dur: Vec<(usize, f32)> = Vec::with_capacity(b); + { + let logits_arr = logits_b.view::()?; + for bi in 0..b { + let row = logits_arr.index_axis(Axis(0), bi); + let row_slice = row.as_slice().unwrap(); + let log_probs = crate::log_softmax(&row_slice[0..=model.blank_id]); + let dur_log_probs = crate::log_softmax(&row_slice[model.blank_id + 1..]); + + let best_dur = dur_log_probs + .iter() + .enumerate() + .max_by_key(|&(_, &v)| FloatOrd(v)) + .map(|(i, _)| i) + .unwrap_or(0); + per_hyp_best_dur.push((best_dur, dur_log_probs[best_dur])); + + // Blank + top-k duration expansions (decoder state unchanged). + let mut ds: Vec<(usize, f32)> = + (0..dur_log_probs.len()).map(|di| (di, dur_log_probs[di])).collect(); + ds.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); + ds.truncate(cfg.alsd_beam_dur_k); + for (di, dlp) in ds { + next.push(AlsdHyp { + score: active[bi].score + log_probs[model.blank_id] + dlp, + tokens: active[bi].tokens.clone(), + current_frame: active[bi].current_frame + di.max(1), + dec_out: active[bi].dec_out.clone(), + state_0: active[bi].state_0.clone(), + state_1: active[bi].state_1.clone(), + }); + } + + // Top-k non-blank token candidates (decoder step needed). + let mut ts: Vec<(usize, f32)> = + (0..model.blank_id).map(|ti| (ti, log_probs[ti])).collect(); + ts.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); + ts.truncate(cfg.alsd_beam_size); + per_hyp_token_scores.push(ts); + } + } + + // 3. DECODER: single batched call over all non-blank expansions. + let expansion_hyp_idxs: Vec = per_hyp_token_scores + .iter() + .enumerate() + .flat_map(|(bi, ts)| std::iter::repeat(bi).take(ts.len())) + .collect(); + let token_ids: Vec = per_hyp_token_scores + .iter() + .flat_map(|ts| ts.iter().map(|&(ti, _)| ti as i32)) + .collect(); + let n = token_ids.len(); + + if n > 0 { + let tokens_batch: Value = + Array2::::from_shape_fn((n, 1), |(i, _)| token_ids[i]).try_into()?; + let s0_batch: Value = { + let views: Vec<_> = + active.iter().map(|h| h.state_0.view::()).collect::>>()?; + Array3::::from_shape_fn((2, n, 640), |(l, i, h)| { + views[expansion_hyp_idxs[i]][[l, 0, h]] + }) + .try_into()? + }; + let s1_batch: Value = { + let views: Vec<_> = + active.iter().map(|h| h.state_1.view::()).collect::>>()?; + Array3::::from_shape_fn((2, n, 640), |(l, i, h)| { + views[expansion_hyp_idxs[i]][[l, 0, h]] + }) + .try_into()? + }; + + let t = Instant::now(); + let [dec_out_b, s0_b, s1_b] = model.run_decoder(tokens_batch, s0_batch, s1_batch)?; + stats.decoder.record(n, t.elapsed()); + + let dec_arr = dec_out_b.view::()?; + let s0_arr = s0_b.view::()?; + let s1_arr = s1_b.view::()?; + let mut i = 0; + for (bi, ts) in per_hyp_token_scores.iter().enumerate() { + let (best_dur, best_dur_lp) = per_hyp_best_dur[bi]; + for &(ti, lp) in ts { + let new_dec_out: Value = + dec_arr.slice_axis(Axis(0), (i..i + 1).into()).try_into()?; + let new_s0: Value = s0_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; + let new_s1: Value = s1_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; + let mut new_tokens = active[bi].tokens.clone(); + new_tokens.push(ti); + next.push(AlsdHyp { + score: active[bi].score + lp + best_dur_lp, + tokens: new_tokens, + current_frame: active[bi].current_frame + best_dur, + dec_out: new_dec_out, + state_0: new_s0, + state_1: new_s1, + }); + i += 1; + } + } + } + + // 4. Prune: sort by score, deduplicate on (tokens, frame), keep top beam_size. + // Two hyps with identical tokens and frame have identical decoder state; + // keep only the highest-scoring one. + next.sort_by(|a, b| FloatOrd(b.score).cmp(&FloatOrd(a.score))); + let mut seen = std::collections::HashSet::<(Vec, usize)>::new(); + next.retain(|h| seen.insert((h.tokens.clone(), h.current_frame))); + next.truncate(cfg.alsd_beam_size); + beam = next; + } + + // Any hyps remaining in beam that didn't finish (step limit hit) go to final. + final_hyps.extend(beam); + + let best = final_hyps + .into_iter() + .max_by_key(|h| FloatOrd(h.score)) + .ok_or_else(|| anyhow!("no hypotheses survived"))?; + Ok((best.tokens.into_iter().map(|t| model.vocab[t].as_str()).join(""), stats)) +} diff --git a/examples/nemo-parakeet-asr/src/beam.rs b/examples/nemo-parakeet-asr/src/beam.rs new file mode 100644 index 0000000000..4ad2d367dc --- /dev/null +++ b/examples/nemo-parakeet-asr/src/beam.rs @@ -0,0 +1,239 @@ +use std::time::Instant; + +use anyhow::*; +use clap::Args; +use float_ord::FloatOrd; +use itertools::Itertools; +use tract_rs::prelude::tract_ndarray::prelude::*; +use tract_rs::prelude::*; + +#[derive(Args, Clone)] +pub struct BeamConfig { + /// Number of active hypotheses to keep after each pruning step + #[arg(long, default_value_t = 4)] + pub beam_size: usize, + + /// Number of duration candidates to expand per hypothesis + #[arg(long, default_value_t = 2)] + pub beam_dur_k: usize, +} + +struct Beam { + score: f32, + tokens: Vec, + last_frame: usize, + dec_out: Value, + state_0: Value, + state_1: Value, +} + +pub fn transcribe_beam( + model: &crate::TdtModel, + wav: &[f32], + cfg: &BeamConfig, +) -> Result<(String, crate::DecodingStats)> { + let mut stats = crate::DecodingStats::default(); + + let samples: Value = Value::from_slice(&[1, wav.len()], wav)?; + let len: Value = arr1(&[wav.len() as i64]).try_into()?; + + let t = Instant::now(); + let [features, feat_len] = model.run_preprocessor(samples, len)?; + stats.preprocessor.record(1, t.elapsed()); + + let t = Instant::now(); + let [encoded, _lens] = model.run_encoder(features, feat_len)?; + stats.encoder.record(1, t.elapsed()); + + let encoded: ArrayD = encoded.view()?.into_owned(); + let batch = encoded.shape()[0]; + let max_frames = encoded.shape()[2]; + let enc_dim = encoded.shape()[1]; + + let init_token = Value::from_slice(&[1, 1], &[model.blank_id as i32])?; + let init_s0: Value = Array3::::zeros([2, 1, 640]).try_into()?; + let init_s1: Value = Array3::::zeros([2, 1, 640]).try_into()?; + let t = Instant::now(); + let [dec_out, state_0, state_1] = model.run_decoder(init_token, init_s0, init_s1)?; + stats.decoder.record(batch, t.elapsed()); + + let mut all_beams: Vec = vec![Beam { + score: 0.0, + tokens: vec![], + last_frame: 0, + dec_out, + state_0, + state_1, + }]; + + for frame_ix in 0..max_frames { + let mut hyps: Vec = Vec::new(); + let mut kept: Vec = Vec::new(); + for b in all_beams.drain(..) { + if b.last_frame == frame_ix { + hyps.push(b); + } else { + kept.push(b); + } + } + + while !hyps.is_empty() { + let b = hyps.len(); + + // 1. JOINT: single call batched over all B active hypotheses + let frame_batch: Value = { + let enc_arr = encoded.view(); + Array3::::from_shape_fn((b, enc_dim, 1), |(_, e, _)| { + enc_arr[[0, e, frame_ix]] + }) + .try_into()? + }; + let dec_out_batch: Value = { + let views: Vec<_> = hyps + .iter() + .map(|h| h.dec_out.view::()) + .collect::>>()?; + let hidden = views[0].shape()[1]; // dec_out is [1, hidden, 1] + Array3::::from_shape_fn((b, hidden, 1), |(bi, h, _)| views[bi][[0, h, 0]]) + .try_into()? + }; + let t = Instant::now(); + let logits_b = model.run_joint(frame_batch, dec_out_batch)?; + stats.joint.record(b, t.elapsed()); + + // 2. Per-hyp: token scores + duration expansions into kept + let mut per_hyp_token_scores: Vec> = Vec::with_capacity(b); + let mut per_hyp_best_dur: Vec<(usize, f32)> = Vec::with_capacity(b); + { + let logits_arr = logits_b.view::()?; // [b, vocab+dur] + for bi in 0..b { + let row = logits_arr.index_axis(Axis(0), bi); + let row_slice = row.as_slice().unwrap(); + let log_probs = crate::log_softmax(&row_slice[0..=model.blank_id]); + let dur_log_probs = crate::log_softmax(&row_slice[model.blank_id + 1..]); + + let mut ts: Vec<(usize, f32)> = + (0..model.blank_id).map(|ti| (ti, log_probs[ti])).collect(); + ts.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); + ts.truncate(cfg.beam_size); + per_hyp_token_scores.push(ts); + + let best_dur = dur_log_probs.iter().enumerate() + .max_by_key(|&(_, &v)| FloatOrd(v)).map(|(i, _)| i).unwrap_or(0); + per_hyp_best_dur.push((best_dur, dur_log_probs[best_dur])); + + let mut ds: Vec<(usize, f32)> = + (1..dur_log_probs.len()).map(|di| (di, dur_log_probs[di])).collect(); + ds.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); + ds.truncate(cfg.beam_dur_k); + for (di, dlp) in ds { + kept.push(Beam { + score: hyps[bi].score + log_probs[model.blank_id] + dlp, + tokens: hyps[bi].tokens.clone(), + last_frame: frame_ix + di, + dec_out: hyps[bi].dec_out.clone(), + state_0: hyps[bi].state_0.clone(), + state_1: hyps[bi].state_1.clone(), + }); + } + } + } // logits_arr dropped + + // 3. DECODER: single call batched over all N token expansions + let expansion_hyp_idxs: Vec = per_hyp_token_scores + .iter() + .enumerate() + .flat_map(|(bi, ts)| std::iter::repeat(bi).take(ts.len())) + .collect(); + let token_ids: Vec = per_hyp_token_scores + .iter() + .flat_map(|ts| ts.iter().map(|&(ti, _)| ti as i32)) + .collect(); + let n = token_ids.len(); + + let tokens_batch: Value = + Array2::::from_shape_fn((n, 1), |(i, _)| token_ids[i]).try_into()?; + let s0_batch: Value = { + let views: Vec<_> = hyps + .iter() + .map(|h| h.state_0.view::()) + .collect::>>()?; + Array3::::from_shape_fn((2, n, 640), |(l, i, h)| { + views[expansion_hyp_idxs[i]][[l, 0, h]] + }) + .try_into()? + }; + let s1_batch: Value = { + let views: Vec<_> = hyps + .iter() + .map(|h| h.state_1.view::()) + .collect::>>()?; + Array3::::from_shape_fn((2, n, 640), |(l, i, h)| { + views[expansion_hyp_idxs[i]][[l, 0, h]] + }) + .try_into()? + }; + + let t = Instant::now(); + let [dec_out_b, s0_b, s1_b] = model.run_decoder(tokens_batch, s0_batch, s1_batch)?; + stats.decoder.record(n, t.elapsed()); + + // 4. Slice and build new beams + let new_hyps: Vec = { + let dec_arr = dec_out_b.view::()?; // [N, hidden] + let s0_arr = s0_b.view::()?; // [2, N, 640] + let s1_arr = s1_b.view::()?; // [2, N, 640] + let mut out = Vec::with_capacity(n); + let mut i = 0; + for (bi, ts) in per_hyp_token_scores.iter().enumerate() { + let (best_dur, best_dur_lp) = per_hyp_best_dur[bi]; + for &(ti, lp) in ts { + let new_dec_out: Value = + dec_arr.slice_axis(Axis(0), (i..i + 1).into()).try_into()?; + let new_s0: Value = + s0_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; + let new_s1: Value = + s1_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; + let mut new_tokens = hyps[bi].tokens.clone(); + new_tokens.push(ti); + out.push(Beam { + score: hyps[bi].score + lp + best_dur_lp, + tokens: new_tokens, + last_frame: frame_ix + best_dur, + dec_out: new_dec_out, + state_0: new_s0, + state_1: new_s1, + }); + i += 1; + } + } + out + }; + hyps = new_hyps; + + // 5. Fuse duplicates then prune to BEAM_SIZE. + // After sorting by score, the first occurrence of each (tokens, last_frame) + // key is the best one; retain only that first occurrence. + let mut all: Vec = hyps.drain(..).chain(kept.drain(..)).collect(); + all.sort_by(|a, b| FloatOrd(b.score).cmp(&FloatOrd(a.score))); + let mut seen = std::collections::HashSet::<(Vec, usize)>::new(); + all.retain(|h| seen.insert((h.tokens.clone(), h.last_frame))); + all.truncate(cfg.beam_size); + for b in all { + if b.last_frame == frame_ix { + hyps.push(b); + } else { + kept.push(b); + } + } + } + + all_beams = kept; + } + + let best = all_beams + .into_iter() + .max_by_key(|b| FloatOrd(b.score)) + .ok_or_else(|| anyhow!("no beams survived"))?; + Ok((best.tokens.into_iter().map(|t| model.vocab[t].as_str()).join(""), stats)) +} diff --git a/examples/nemo-parakeet-asr/src/fbsd.rs b/examples/nemo-parakeet-asr/src/fbsd.rs new file mode 100644 index 0000000000..383e162693 --- /dev/null +++ b/examples/nemo-parakeet-asr/src/fbsd.rs @@ -0,0 +1,246 @@ +use std::time::Instant; + +use anyhow::*; +use clap::Args; +use float_ord::FloatOrd; +use itertools::Itertools; +use tract_rs::prelude::tract_ndarray::prelude::*; +use tract_rs::prelude::*; + +#[derive(Args, Clone)] +pub struct FbsdConfig { + /// Beam width for FBSD decoding + #[arg(long, default_value_t = 4)] + pub fbsd_beam_size: usize, + + /// Duration candidates per hypothesis in FBSD + #[arg(long, default_value_t = 2)] + pub fbsd_beam_dur_k: usize, + + /// Max non-blank tokens emitted per frame per hypothesis + #[arg(long, default_value_t = 10)] + pub fbsd_max_symbols_per_frame: usize, +} + +struct FbsdHyp { + score: f32, + tokens: Vec, + current_frame: usize, + symbols_this_frame: usize, + dec_out: Value, + state_0: Value, + state_1: Value, +} + +pub fn transcribe_fbsd( + model: &crate::TdtModel, + wav: &[f32], + cfg: &FbsdConfig, +) -> Result<(String, crate::DecodingStats)> { + let mut stats = crate::DecodingStats::default(); + + let samples: Value = Value::from_slice(&[1, wav.len()], wav)?; + let len: Value = arr1(&[wav.len() as i64]).try_into()?; + + let t = Instant::now(); + let [features, feat_len] = model.run_preprocessor(samples, len)?; + stats.preprocessor.record(1, t.elapsed()); + + let t = Instant::now(); + let [encoded, _lens] = model.run_encoder(features, feat_len)?; + stats.encoder.record(1, t.elapsed()); + + let encoded: ArrayD = encoded.view()?.into_owned(); + let batch = encoded.shape()[0]; + let max_frames = encoded.shape()[2]; + let enc_dim = encoded.shape()[1]; + + let init_token = Value::from_slice(&[1, 1], &[model.blank_id as i32])?; + let init_s0: Value = Array3::::zeros([2, 1, 640]).try_into()?; + let init_s1: Value = Array3::::zeros([2, 1, 640]).try_into()?; + let t = Instant::now(); + let [dec_out, state_0, state_1] = model.run_decoder(init_token, init_s0, init_s1)?; + stats.decoder.record(batch, t.elapsed()); + + let mut beam: Vec = vec![FbsdHyp { + score: 0.0, + tokens: vec![], + current_frame: 0, + symbols_this_frame: 0, + dec_out, + state_0, + state_1, + }]; + + loop { + // Split into active (still have frames to consume) and completed + let mut active: Vec = Vec::new(); + let mut completed: Vec = Vec::new(); + for h in beam.drain(..) { + if h.current_frame < max_frames { + active.push(h); + } else { + completed.push(h); + } + } + + if active.is_empty() { + beam = completed; + break; + } + + let b = active.len(); + + // 1. JOINT: each hypothesis uses its own current_frame + let frame_batch: Value = { + let enc_arr = encoded.view(); + Array3::::from_shape_fn((b, enc_dim, 1), |(bi, e, _)| { + enc_arr[[0, e, active[bi].current_frame]] + }) + .try_into()? + }; + let dec_out_batch: Value = { + let views: Vec<_> = active + .iter() + .map(|h| h.dec_out.view::()) + .collect::>>()?; + let hidden = views[0].shape()[1]; // dec_out is [1, hidden, 1] + Array3::::from_shape_fn((b, hidden, 1), |(bi, h, _)| views[bi][[0, h, 0]]) + .try_into()? + }; + let t = Instant::now(); + let logits_b = model.run_joint(frame_batch, dec_out_batch)?; + stats.joint.record(b, t.elapsed()); + + // 2. Per-hyp: blank+duration → next; non-blank → expansion list + let mut next: Vec = completed; + let mut per_hyp_token_scores: Vec> = Vec::with_capacity(b); + let mut per_hyp_best_dur: Vec<(usize, f32)> = Vec::with_capacity(b); + { + let logits_arr = logits_b.view::()?; // [b, vocab+dur] + for bi in 0..b { + let row = logits_arr.index_axis(Axis(0), bi); + let row_slice = row.as_slice().unwrap(); + let log_probs = crate::log_softmax(&row_slice[0..=model.blank_id]); + let dur_log_probs = crate::log_softmax(&row_slice[model.blank_id + 1..]); + + let best_dur = dur_log_probs.iter().enumerate() + .max_by_key(|&(_, &v)| FloatOrd(v)).map(|(i, _)| i).unwrap_or(0); + per_hyp_best_dur.push((best_dur, dur_log_probs[best_dur])); + + // Blank + duration expansions advance the frame and reset symbol count + let mut ds: Vec<(usize, f32)> = + (1..dur_log_probs.len()).map(|di| (di, dur_log_probs[di])).collect(); + ds.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); + ds.truncate(cfg.fbsd_beam_dur_k); + for (di, dlp) in ds { + next.push(FbsdHyp { + score: active[bi].score + log_probs[model.blank_id] + dlp, + tokens: active[bi].tokens.clone(), + current_frame: active[bi].current_frame + di, + symbols_this_frame: 0, + dec_out: active[bi].dec_out.clone(), + state_0: active[bi].state_0.clone(), + state_1: active[bi].state_1.clone(), + }); + } + + // Non-blank expansions: gated by per-frame symbol cap + if active[bi].symbols_this_frame < cfg.fbsd_max_symbols_per_frame { + let mut ts: Vec<(usize, f32)> = + (0..model.blank_id).map(|ti| (ti, log_probs[ti])).collect(); + ts.sort_by(|a, b| FloatOrd(b.1).cmp(&FloatOrd(a.1))); + ts.truncate(cfg.fbsd_beam_size); + per_hyp_token_scores.push(ts); + } else { + per_hyp_token_scores.push(vec![]); + } + } + } // logits_arr dropped + + // 3. DECODER: single batched call over all non-blank expansions + let expansion_hyp_idxs: Vec = per_hyp_token_scores + .iter() + .enumerate() + .flat_map(|(bi, ts)| std::iter::repeat(bi).take(ts.len())) + .collect(); + let token_ids: Vec = per_hyp_token_scores + .iter() + .flat_map(|ts| ts.iter().map(|&(ti, _)| ti as i32)) + .collect(); + let n = token_ids.len(); + + if n > 0 { + let tokens_batch: Value = + Array2::::from_shape_fn((n, 1), |(i, _)| token_ids[i]).try_into()?; + let s0_batch: Value = { + let views: Vec<_> = active + .iter() + .map(|h| h.state_0.view::()) + .collect::>>()?; + Array3::::from_shape_fn((2, n, 640), |(l, i, h)| { + views[expansion_hyp_idxs[i]][[l, 0, h]] + }) + .try_into()? + }; + let s1_batch: Value = { + let views: Vec<_> = active + .iter() + .map(|h| h.state_1.view::()) + .collect::>>()?; + Array3::::from_shape_fn((2, n, 640), |(l, i, h)| { + views[expansion_hyp_idxs[i]][[l, 0, h]] + }) + .try_into()? + }; + + let t = Instant::now(); + let [dec_out_b, s0_b, s1_b] = model.run_decoder(tokens_batch, s0_batch, s1_batch)?; + stats.decoder.record(n, t.elapsed()); + + // 4. Slice per-expansion outputs and push into next + let dec_arr = dec_out_b.view::()?; // [N, hidden, 1] + let s0_arr = s0_b.view::()?; // [2, N, 640] + let s1_arr = s1_b.view::()?; // [2, N, 640] + let mut i = 0; + for (bi, ts) in per_hyp_token_scores.iter().enumerate() { + let (best_dur, best_dur_lp) = per_hyp_best_dur[bi]; + for &(ti, lp) in ts { + let new_dec_out: Value = + dec_arr.slice_axis(Axis(0), (i..i + 1).into()).try_into()?; + let new_s0: Value = + s0_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; + let new_s1: Value = + s1_arr.slice_axis(Axis(1), (i..i + 1).into()).try_into()?; + let mut new_tokens = active[bi].tokens.clone(); + new_tokens.push(ti); + next.push(FbsdHyp { + score: active[bi].score + lp + best_dur_lp, + tokens: new_tokens, + current_frame: active[bi].current_frame + best_dur, + symbols_this_frame: if best_dur > 0 { 0 } else { active[bi].symbols_this_frame + 1 }, + dec_out: new_dec_out, + state_0: new_s0, + state_1: new_s1, + }); + i += 1; + } + } + } + + // 5. Fuse duplicates then prune to fbsd_beam_size. + // Key is (tokens, current_frame, symbols_this_frame): same tokens → same decoder + // state; same frame → same encoder embedding; same symbol count → same gate behavior. + next.sort_by(|a, b| FloatOrd(b.score).cmp(&FloatOrd(a.score))); + let mut seen = std::collections::HashSet::<(Vec, usize, usize)>::new(); + next.retain(|h| seen.insert((h.tokens.clone(), h.current_frame, h.symbols_this_frame))); + next.truncate(cfg.fbsd_beam_size); + beam = next; + } + + let best = beam + .into_iter() + .max_by_key(|b| FloatOrd(b.score)) + .ok_or_else(|| anyhow!("no beams survived"))?; + Ok((best.tokens.into_iter().map(|t| model.vocab[t].as_str()).join(""), stats)) +} diff --git a/examples/nemo-parakeet-asr/src/greedy.rs b/examples/nemo-parakeet-asr/src/greedy.rs new file mode 100644 index 0000000000..97c3e0896c --- /dev/null +++ b/examples/nemo-parakeet-asr/src/greedy.rs @@ -0,0 +1,62 @@ +use std::time::Instant; + +use anyhow::*; +use itertools::Itertools; +use tract_rs::prelude::tract_ndarray::prelude::*; +use tract_rs::prelude::*; + +pub fn transcribe_greedy( + model: &crate::TdtModel, + wav: &[f32], +) -> Result<(String, crate::DecodingStats)> { + let mut stats = crate::DecodingStats::default(); + + let samples: Value = Value::from_slice(&[1, wav.len()], wav)?; + let len: Value = arr1(&[wav.len() as i64]).try_into()?; + + let t = Instant::now(); + let [features, feat_len] = model.run_preprocessor(samples, len)?; + stats.preprocessor.record(1, t.elapsed()); + + let t = Instant::now(); + let [encoded, _lens] = model.run_encoder(features, feat_len)?; + stats.encoder.record(1, t.elapsed()); + + let encoded: ArrayD = encoded.view()?.into_owned(); + let max_frames = encoded.shape()[2]; + let max_len = max_frames * 6 + 10; + + let mut hyp = vec![]; + let mut frame_ix = 0; + let mut token = Value::from_slice(&[1, 1], &[model.blank_id as i32])?; + let mut state_0: Value = Array3::::zeros([2, 1, 640]).try_into()?; + let mut state_1: Value = Array3::::zeros([2, 1, 640]).try_into()?; + + let t = Instant::now(); + [token, state_0, state_1] = model.run_decoder(token, state_0, state_1)?; + stats.decoder.record(1, t.elapsed()); + + while hyp.len() < max_len && frame_ix < max_frames { + let frame: Value = + encoded.slice_axis(Axis(2), (frame_ix..frame_ix + 1).into()).try_into()?; + let t = Instant::now(); + let logits = model.run_joint(frame, token.clone())?; + stats.joint.record(1, t.elapsed()); + let logits = logits.view::()?; + let logits = logits.as_slice().unwrap(); + let token_id = crate::argmax(&logits[0..model.blank_id + 1]).unwrap(); + let dur = crate::argmax(&logits[model.blank_id + 1..]).unwrap_or(0); + if token_id == model.blank_id { + frame_ix += dur.max(1); + } else { + hyp.push(token_id); + frame_ix += dur; + token = Value::from_slice(&[1, 1], &[token_id as i32])?; + let t = Instant::now(); + [token, state_0, state_1] = model.run_decoder(token, state_0, state_1)?; + stats.decoder.record(1, t.elapsed()); + } + } + + Ok((hyp.into_iter().map(|t| model.vocab[t].as_str()).join(""), stats)) +} diff --git a/examples/nemo-parakeet-asr/src/main.rs b/examples/nemo-parakeet-asr/src/main.rs index 82ed8a55c4..bdc8a8118b 100644 --- a/examples/nemo-parakeet-asr/src/main.rs +++ b/examples/nemo-parakeet-asr/src/main.rs @@ -1,83 +1,556 @@ use std::fs::File; +use std::path::{Path, PathBuf}; +use std::time::Instant; use anyhow::*; +use clap::Parser; use float_ord::FloatOrd; +use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; use itertools::Itertools; -use tract_rs::prelude::tract_ndarray::prelude::*; use tract_rs::prelude::*; +use tract_rs::Nnef; -fn argmax(slice: &[f32]) -> Option { +mod alsd; +mod beam; +mod fbsd; +mod greedy; + +#[derive(clap::ValueEnum, Clone)] +enum Decoder { + Greedy, + Beam, + Fbsd, + Alsd, +} + +#[derive(Parser)] +#[command(about = "NeMo Parakeet ASR inference")] +struct Args { + #[command(flatten)] + beam: beam::BeamConfig, + #[command(flatten)] + fbsd: fbsd::FbsdConfig, + #[command(flatten)] + alsd: alsd::AlsdConfig, + #[arg(long, value_enum, default_value_t = Decoder::Greedy)] + decoder: Decoder, + #[arg(long)] + stats: bool, + #[arg(long)] + no_details: bool, + /// Run ground-truth decoder and write transcript to a .txt file beside each wav + #[arg(long)] + write_gt: bool, + /// Sweep hardcoded decoder configs and print TSV results to stdout + #[arg(long)] + param_search: bool, + #[arg(required = true)] + inputs: Vec, +} + +#[derive(Default)] +pub(crate) struct CallStats { + calls: u32, + total_batch: u64, + pub(crate) total_us: u64, +} + +impl CallStats { + pub(crate) fn record(&mut self, batch: usize, elapsed: std::time::Duration) { + self.calls += 1; + self.total_batch += batch as u64; + self.total_us += elapsed.as_micros() as u64; + } + + pub(crate) fn total_ms(&self) -> f64 { + self.total_us as f64 / 1000.0 + } +} + +impl std::fmt::Debug for CallStats { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let avg_batch = + if self.calls == 0 { 0.0 } else { self.total_batch as f64 / self.calls as f64 }; + let avg_ms = + if self.calls == 0 { 0.0 } else { self.total_us as f64 / self.calls as f64 / 1000.0 }; + let total_ms = self.total_us as f64 / 1000.0; + write!( + f, + "calls={:5} avg_batch={avg_batch:.1} avg={avg_ms:.3}ms total={total_ms:.1}ms", + self.calls + ) + } +} + +#[derive(Default)] +pub(crate) struct DecodingStats { + pub(crate) preprocessor: CallStats, + pub(crate) encoder: CallStats, + pub(crate) decoder: CallStats, + pub(crate) joint: CallStats, +} + +impl DecodingStats { + pub(crate) fn nn_ms(&self) -> f64 { + self.preprocessor.total_ms() + + self.encoder.total_ms() + + self.decoder.total_ms() + + self.joint.total_ms() + } +} + +pub(crate) fn argmax(slice: &[f32]) -> Option { slice.into_iter().position_max_by_key(|x| FloatOrd(**x)) } -fn main() -> anyhow::Result<()> { - let config: serde_json::Value = - serde_json::from_reader(File::open("assets/model/model_config.json")?)?; - let blank_id = config.pointer("/decoder/vocab_size").unwrap().as_i64().unwrap() as usize; - let vocab = config.pointer("/joint/vocabulary").unwrap().as_array().unwrap(); - let vocab: Vec<&str> = vocab.iter().map(|v| v.as_str().unwrap()).collect(); +pub(crate) fn log_softmax(xs: &[f32]) -> Vec { + let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max); + let lse = xs.iter().map(|&x| (x - max).exp()).sum::().ln(); + xs.iter().map(|&x| x - max - lse).collect() +} + +pub(crate) struct TdtModel { + preprocessor: Runnable, + encoder: Runnable, + decoder: Runnable, + joint: Runnable, + pub(crate) vocab: Vec, + pub(crate) blank_id: usize, +} + +impl TdtModel { + #[inline(never)] + pub(crate) fn run_preprocessor(&self, samples: Value, len: Value) -> Result<[Value; 2]> { + Ok(self.preprocessor.run([samples, len])?.try_into().unwrap()) + } + + #[inline(never)] + pub(crate) fn run_encoder(&self, features: Value, feat_len: Value) -> Result<[Value; 2]> { + Ok(self.encoder.run([features, feat_len])?.try_into().unwrap()) + } + + #[inline(never)] + pub(crate) fn run_decoder(&self, token: Value, s0: Value, s1: Value) -> Result<[Value; 3]> { + Ok(self.decoder.run([token, s0, s1])?.try_into().unwrap()) + } + + #[inline(never)] + pub(crate) fn run_joint(&self, frame: Value, dec_out: Value) -> Result { + let [logits] = self.joint.run([frame, dec_out])?.try_into().unwrap(); + Ok(logits) + } + + fn load(model_dir: impl AsRef, nnef: &Nnef, gpu: &Runtime) -> Result { + let model_dir = model_dir.as_ref(); + let config: serde_json::Value = + serde_json::from_reader(File::open(model_dir.join("model_config.json"))?)?; + let blank_id = config.pointer("/decoder/vocab_size").unwrap().as_i64().unwrap() as usize; + let vocab = config.pointer("/joint/vocabulary").unwrap().as_array().unwrap(); + let vocab: Vec = vocab.iter().map(|v| v.as_str().unwrap().to_owned()).collect(); + + let mut preprocessor = nnef.load(model_dir.join("preprocessor.nnef.tgz"))?; + preprocessor.transform("transformers-detect-all")?; + let preprocessor = gpu.prepare(preprocessor)?; + + let mut encoder = nnef.load(model_dir.join("encoder.nnef.tgz"))?; + encoder.transform("transformers-detect-all")?; + let encoder = gpu.prepare(encoder)?; + + let mut decoder = nnef.load(model_dir.join("decoder.nnef.tgz"))?; + decoder.transform("transformers-detect-all")?; + decoder.concretize_symbols([("T", 1)])?; + let decoder = gpu.prepare(decoder)?; + + let mut joint = nnef.load(model_dir.join("joint.nnef.tgz"))?; + joint.transform("transformers-detect-all")?; + joint.concretize_symbols([("R", 1), ("U", 1)])?; + let joint = gpu.prepare(joint)?; + + Ok(TdtModel { preprocessor, encoder, decoder, joint, vocab, blank_id }) + } +} + +fn collect_wavs_from_dir(dir: &Path) -> Vec { + let mut results = Vec::new(); + if let Some(entries) = std::fs::read_dir(dir).ok() { + for entry in entries.flatten() { + let path: PathBuf = entry.path(); + if path.is_dir() { + results.extend(collect_wavs_from_dir(&path)); + } else if path.extension().and_then(|e: &std::ffi::OsStr| e.to_str()) == Some("wav") { + results.push(path); + } + } + } + results +} + +fn collect_wavs(inputs: &[PathBuf]) -> Vec { + let mut results = Vec::new(); + for input in inputs { + if input.is_dir() { + results.extend(collect_wavs_from_dir(input)); + } else { + results.push(input.clone()); + } + } + results.sort(); + results +} + +fn load_wav(path: &Path) -> Result<(Vec, f64)> { + let mut wav_reader = hound::WavReader::open(path)?; + let sample_rate = wav_reader.spec().sample_rate as f64; + let samples: Vec = wav_reader.samples::().map(|x| x.unwrap() as f32).collect(); + let audio_duration_s = samples.len() as f64 / sample_rate; + Ok((samples, audio_duration_s)) +} + +fn clean(s: &str) -> String { + s.replace('▁', " ").trim_start().to_owned() +} + +fn decoder_label(args: &Args) -> String { + match args.decoder { + Decoder::Greedy => "greedy".to_string(), + Decoder::Beam => { + format!("beam beam_size={} beam_dur_k={}", args.beam.beam_size, args.beam.beam_dur_k) + } + Decoder::Fbsd => format!( + "fbsd beam_size={} beam_dur_k={} max_symbols={}", + args.fbsd.fbsd_beam_size, + args.fbsd.fbsd_beam_dur_k, + args.fbsd.fbsd_max_symbols_per_frame + ), + Decoder::Alsd => format!( + "alsd beam_size={} beam_dur_k={} u_max={}", + args.alsd.alsd_beam_size, args.alsd.alsd_beam_dur_k, args.alsd.alsd_u_max + ), + } +} + +fn run_decoder(model: &TdtModel, wav: &[f32], args: &Args) -> Result<(String, DecodingStats)> { + match args.decoder { + Decoder::Greedy => greedy::transcribe_greedy(model, wav), + Decoder::Beam => beam::transcribe_beam(model, wav, &args.beam), + Decoder::Fbsd => fbsd::transcribe_fbsd(model, wav, &args.fbsd), + Decoder::Alsd => alsd::transcribe_alsd(model, wav, &args.alsd), + } +} + +// ─── SearchConfig ──────────────────────────────────────────────────────────── + +enum SearchConfig { + Greedy, + Beam(beam::BeamConfig), + Fbsd(fbsd::FbsdConfig), + Alsd(alsd::AlsdConfig), +} + +impl SearchConfig { + fn label(&self) -> String { + match self { + SearchConfig::Greedy => "greedy".to_owned(), + SearchConfig::Beam(c) => format!("beam_{}_{}", c.beam_size, c.beam_dur_k), + SearchConfig::Fbsd(c) => format!( + "fbsd_{}_{}_{}", + c.fbsd_beam_size, c.fbsd_beam_dur_k, c.fbsd_max_symbols_per_frame + ), + SearchConfig::Alsd(c) => format!("alsd_{}_{}", c.alsd_beam_size, c.alsd_beam_dur_k), + } + } + + fn run(&self, model: &TdtModel, wav: &[f32]) -> Result<(String, DecodingStats)> { + match self { + SearchConfig::Greedy => greedy::transcribe_greedy(model, wav), + SearchConfig::Beam(c) => beam::transcribe_beam(model, wav, c), + SearchConfig::Fbsd(c) => fbsd::transcribe_fbsd(model, wav, c), + SearchConfig::Alsd(c) => alsd::transcribe_alsd(model, wav, c), + } + } +} + +fn search_configs() -> Vec { + fn b(beam_size: usize, beam_dur_k: usize) -> SearchConfig { + SearchConfig::Beam(beam::BeamConfig { beam_size, beam_dur_k }) + } + fn f(beam_size: usize, beam_dur_k: usize, max_symbols: usize) -> SearchConfig { + SearchConfig::Fbsd(fbsd::FbsdConfig { + fbsd_beam_size: beam_size, + fbsd_beam_dur_k: beam_dur_k, + fbsd_max_symbols_per_frame: max_symbols, + }) + } + fn a(beam_size: usize, beam_dur_k: usize) -> SearchConfig { + SearchConfig::Alsd(alsd::AlsdConfig { + alsd_beam_size: beam_size, + alsd_beam_dur_k: beam_dur_k, + alsd_u_max: 50, + }) + } + vec![ + SearchConfig::Greedy, + b(1, 1), + b(2, 1), + b(2, 2), + b(4, 1), + b(4, 2), + b(4, 4), + b(8, 2), + b(8, 4), + f(1, 1, 10), + f(2, 1, 10), + f(2, 2, 10), + f(4, 1, 10), + f(4, 2, 10), + f(4, 4, 10), + f(8, 2, 10), + f(8, 4, 10), + f(4, 2, 3), + f(4, 2, 30), + a(1, 1), + a(2, 1), + a(2, 2), + a(4, 1), + a(4, 2), + a(4, 4), + a(8, 2), + a(8, 4), + ] +} + +fn param_search(model: &TdtModel, wavs: &[PathBuf]) -> Result<()> { + // Warmup + if let Some(first) = wavs.first() { + let (wav, _) = load_wav(first)?; + greedy::transcribe_greedy(model, &wav)?; + } + + let configs = search_configs(); + + let mp = MultiProgress::new(); + let cfg_style = ProgressStyle::with_template("Configs {bar:40} {pos:>3}/{len} {msg}").unwrap(); + let file_style = ProgressStyle::with_template("Files {bar:40} {pos:>3}/{len}").unwrap(); + + let cfg_bar = mp.add(ProgressBar::new(configs.len() as u64)); + cfg_bar.set_style(cfg_style); + let file_bar = mp.add(ProgressBar::new(wavs.len() as u64)); + file_bar.set_style(file_style); + + mp.suspend(|| println!("label\tEPR\tRTFx\tpre%\tenc%\tdec%\tjoint%\thost%")); + + for cfg in &configs { + let label = cfg.label(); + cfg_bar.set_message(label.clone()); + + file_bar.reset(); + file_bar.set_length(wavs.len() as u64); + + let mut total_audio_s = 0.0f64; + let mut total_elapsed_s = 0.0f64; + let mut exact = 0usize; + let mut total = 0usize; + let mut pre_us = 0u64; + let mut enc_us = 0u64; + let mut dec_us = 0u64; + let mut joint_us = 0u64; + + for wav_path in wavs { + let (wav, audio_s) = load_wav(wav_path)?; + let reference = + std::fs::read_to_string(wav_path.with_extension("txt")).with_context(|| { + format!( + "no ground-truth transcript for {} (run with --write-gt first)", + wav_path.display() + ) + })?; + let reference = reference.trim_end_matches('\n').to_owned(); + + let t = Instant::now(); + let (transcript, stats) = cfg.run(model, &wav)?; + let elapsed = t.elapsed().as_secs_f64(); + + total_audio_s += audio_s; + total_elapsed_s += elapsed; + total += 1; + if clean(&transcript) == reference { + exact += 1; + } + pre_us += stats.preprocessor.total_us; + enc_us += stats.encoder.total_us; + dec_us += stats.decoder.total_us; + joint_us += stats.joint.total_us; + + file_bar.inc(1); + } + + let epr = if total > 0 { exact as f64 / total as f64 } else { 0.0 }; + let rtfx = if total_elapsed_s > 0.0 { total_audio_s / total_elapsed_s } else { 0.0 }; + let total_us = (total_elapsed_s * 1_000_000.0) as u64; + let pct = |us: u64| if total_us > 0 { us as f64 / total_us as f64 * 100.0 } else { 0.0 }; + let nn_us = pre_us + enc_us + dec_us + joint_us; + let host_us = total_us.saturating_sub(nn_us); + mp.suspend(|| { + println!( + "{}\t{:.4}\t{:.1}\t{:.1}\t{:.1}\t{:.1}\t{:.1}\t{:.1}", + label, + epr, + rtfx, + pct(pre_us), + pct(enc_us), + pct(dec_us), + pct(joint_us), + pct(host_us) + ) + }); + + cfg_bar.inc(1); + } + + cfg_bar.finish(); + file_bar.finish(); + Ok(()) +} + +// ─── write_gt ──────────────────────────────────────────────────────────────── + +fn write_gt(model: &TdtModel, wavs: &[PathBuf], no_details: bool) -> anyhow::Result<()> { + let gt_cfg = beam::BeamConfig { beam_size: 10, beam_dur_k: 5 }; + // Warmup + if let Some(first) = wavs.first() { + let (wav, _) = load_wav(first)?; + beam::transcribe_beam(model, &wav, >_cfg)?; + } + let pb = if no_details { + let pb = ProgressBar::new(wavs.len() as u64); + pb.set_style(ProgressStyle::with_template("Writing GT {bar:40} {pos:>3}/{len}").unwrap()); + Some(pb) + } else { + None + }; + for wav_path in wavs { + let (wav, _) = load_wav(wav_path)?; + let (transcript, _) = beam::transcribe_beam(model, &wav, >_cfg)?; + let txt_path = wav_path.with_extension("txt"); + std::fs::write(&txt_path, clean(&transcript))?; + if let Some(ref pb) = pb { + pb.inc(1); + } else { + eprintln!("{} -> {}", wav_path.display(), txt_path.display()); + } + } + if let Some(pb) = pb { + pb.finish(); + } + eprintln!("wrote {} transcript(s)", wavs.len()); + Ok(()) +} + +fn main() -> anyhow::Result<()> { + let args = Args::parse(); let nnef = tract_rs::nnef()?.with_tract_core()?.with_tract_transformers()?; let gpu = ["cuda", "metal", "default"] .iter() .find_map(|rt| tract_rs::runtime_for_name(rt).ok()) .unwrap(); - let preprocessor = nnef.load("assets/model/preprocessor.nnef.tgz")?.into_runnable()?; - - let mut encoder = nnef.load("assets/model/encoder.nnef.tgz")?; - encoder.transform("transformers-detect-all")?; - let encoder = gpu.prepare(encoder)?; - - let decoder = nnef.load("assets/model/decoder.nnef.tgz")?; - let decoder = gpu.prepare(decoder)?; - - let joint = nnef.load("assets/model/joint.nnef.tgz")?; - let joint = gpu.prepare(joint)?; - - let wav: Vec = hound::WavReader::open("assets/2086-149220-0033.wav")? - .samples::() - .map(|x| x.unwrap() as f32) - .collect(); - let samples: Value = Value::from_slice(&[1, wav.len()], &wav)?; - let len: Value = arr1(&[wav.len() as i64]).try_into()?; - - let [features, feat_len] = preprocessor.run([samples, len])?.try_into().unwrap(); - let [encoded, _lens] = encoder.run([features, feat_len])?.try_into().unwrap(); - - let encoded: ArrayD = encoded.view()?.into_owned(); - - let max_frames = encoded.shape()[2]; - let max_len = max_frames * 6 + 10; - - let mut hyp = vec![]; - let mut frame_ix = 0; - let mut token = Value::from_slice(&[1, 1], &[0i32])?; - let mut state_0: Value = Array3::::zeros([2, 1, 640]).try_into()?; - let mut state_1: Value = Array3::::zeros([2, 1, 640]).try_into()?; - - [token, state_0, state_1] = decoder.run([token, state_0, state_1])?.try_into().unwrap(); - while hyp.len() < max_len && frame_ix < max_frames { - let frame: Value = - encoded.slice_axis(Axis(2), (frame_ix..frame_ix + 1).into()).try_into()?; - let [logits] = joint.run([frame, token.clone()])?.try_into().unwrap(); - let logits = logits.view::()?; - let logits = logits.as_slice().unwrap(); - let token_id = argmax(&logits[0..blank_id + 1]).unwrap(); - if token_id == blank_id { - frame_ix += argmax(&logits[blank_id + 1..]).unwrap_or(0).max(1); + let model = TdtModel::load(Path::new("assets/model"), &nnef, &gpu)?; + let wavs = collect_wavs(&args.inputs); + + if args.write_gt { + return write_gt(&model, &wavs, args.no_details); + } + + if args.param_search { + return param_search(&model, &wavs); + } + + let label = decoder_label(&args); + if !args.no_details { + eprintln!("{} files={}", label, wavs.len()); + } + + // Warmup on first file + if let Some(first) = wavs.first() { + let (wav, _) = load_wav(first)?; + run_decoder(&model, &wav, &args)?; + } + + let pb = if args.no_details { + let pb = ProgressBar::new(wavs.len() as u64); + pb.set_style(ProgressStyle::with_template("Decoding {bar:40} {pos:>3}/{len}").unwrap()); + Some(pb) + } else { + None + }; + + let mut total_audio_s = 0.0f64; + let mut total_elapsed_s = 0.0f64; + let mut exact = 0usize; + + for wav_path in &wavs { + let (wav, audio_s) = load_wav(wav_path)?; + let reference = + std::fs::read_to_string(wav_path.with_extension("txt")).with_context(|| { + format!( + "no ground-truth transcript for {} (run with --write-gt first)", + wav_path.display() + ) + })?; + + let t = Instant::now(); + let (transcript, stats) = run_decoder(&model, &wav, &args)?; + let elapsed = t.elapsed().as_secs_f64(); + + total_audio_s += audio_s; + total_elapsed_s += elapsed; + + let ok = clean(&transcript) == reference.trim_end_matches('\n'); + if ok { + exact += 1; + } + + if let Some(ref pb) = pb { + pb.inc(1); } else { - hyp.push(token_id); - token = Value::from_slice(&[1, 1], &[token_id as i32])?; - [token, state_0, state_1] = decoder.run([token, state_0, state_1])?.try_into().unwrap(); + let mark = if ok { "\x1b[32m✓\x1b[0m" } else { "\x1b[31m✗\x1b[0m" }; + let elapsed_ms = elapsed * 1000.0; + let nn_ms = stats.nn_ms(); + eprintln!( + "{} {mark} {audio_s:.1}s {elapsed_ms:.1}ms RTFx={:.1} {}", + wav_path.display(), + audio_s / elapsed, + clean(&transcript) + ); + if !ok { + eprintln!(" ref: {}", clean(&reference)); + eprintln!(" got: {}", clean(&transcript)); + } + if args.stats { + eprintln!( + " {elapsed_ms:.1}ms RTFx={:.1} nn={nn_ms:.1}ms host={:.1}ms", + audio_s / elapsed, + elapsed_ms - nn_ms + ); + eprintln!(" [preprocessor] {:?}", stats.preprocessor); + eprintln!(" [encoder] {:?}", stats.encoder); + eprintln!(" [decoder] {:?}", stats.decoder); + eprintln!(" [joint] {:?}", stats.joint); + } } } - let transcript = hyp.into_iter().map(|t| vocab[t]).join(""); - println!("Transcript: {transcript}"); - assert_eq!( - transcript, - "▁Well,▁I▁don't▁wish▁to▁see▁it▁any▁more,▁observed▁Phoebe,▁turning▁away▁her▁eyes." + if let Some(pb) = pb { + pb.finish(); + } + eprintln!( + "{} {}/{} exact RTFx={:.1}", + label, + exact, + wavs.len(), + total_audio_s / total_elapsed_s ); + Ok(()) }