From c1bb9ea54c1c3159bafcd4fce1fb9cc9d077dbe1 Mon Sep 17 00:00:00 2001 From: Julien Balian Date: Sat, 28 Mar 2026 00:36:19 +0100 Subject: [PATCH 1/2] feat: add tract_moe_ffn operator for Mixture-of-Experts FFN Implements the tract_moe_ffn operator in the tract_transformers extension, enabling inference of MoE-based models (Mixtral, GPT-OSS, Qwen MoE) exported via torch_to_nnef. The operator encapsulates the full MoE FFN block: - Router: x @ wg.T -> top-k expert selection with softmax gating - Token grouping: batch tokens per expert for efficient GEMM - Expert FFN: SwiGLU (silu(x@w1) * (x@w3)) @ w2 with BLAS-backed matmul - Weighted scatter-add of expert outputs Real conditional compute: unused experts are fully skipped. Handles both 2D [T,D] and 3D [B,S,D] input shapes. Verified bit-exact against PyTorch on TitanML/tiny-mixtral (8 experts, top-2, 246M params). --- transformers/src/lib.rs | 1 + transformers/src/ops/mod.rs | 1 + transformers/src/ops/moe_ffn.rs | 238 ++++++++++++++++++++++++++++++++ 3 files changed, 240 insertions(+) create mode 100644 transformers/src/ops/moe_ffn.rs diff --git a/transformers/src/lib.rs b/transformers/src/lib.rs index 6e66a1f087..d8200bd931 100644 --- a/transformers/src/lib.rs +++ b/transformers/src/lib.rs @@ -24,6 +24,7 @@ pub fn register(registry: &mut Registry) { ops::apply_rope::register(registry); ops::scaled_masked_softmax::register(registry); ops::sdpa::register(registry); + ops::moe_ffn::register(registry); } pub trait WithTractTransformers { diff --git a/transformers/src/ops/mod.rs b/transformers/src/ops/mod.rs index 0214a8e47a..d6deb383b5 100644 --- a/transformers/src/ops/mod.rs +++ b/transformers/src/ops/mod.rs @@ -2,6 +2,7 @@ pub mod apply_rope; pub mod dyn_kv_cache; pub mod flash_sdpa; pub mod gelu_approximate; +pub mod moe_ffn; pub mod rms_norm; pub mod scaled_masked_softmax; pub mod sdpa; diff --git a/transformers/src/ops/moe_ffn.rs b/transformers/src/ops/moe_ffn.rs new file mode 100644 index 0000000000..77eaf41b54 --- 
/dev/null +++ b/transformers/src/ops/moe_ffn.rs @@ -0,0 +1,238 @@ +use tract_ndarray::{s, Array2, ArrayView2, Axis}; +use tract_nnef::internal::*; + +pub fn register(registry: &mut Registry) { + registry.register_primitive( + "tract_moe_ffn", + &[ + TypeName::Scalar.tensor().named("x"), + TypeName::Scalar.tensor().named("wg"), + TypeName::Scalar.tensor().named("w1"), + TypeName::Scalar.tensor().named("w2"), + TypeName::Scalar.tensor().named("w3"), + TypeName::Integer.named("k"), + TypeName::String.named("activation"), + TypeName::Logical.named("normalize_gates"), + ], + &[ + ("output", TypeName::Scalar.tensor()), + ("router_logits", TypeName::Scalar.tensor()), + ], + deser_moe_ffn, + ); +} + +fn deser_moe_ffn( + builder: &mut ModelBuilder, + invocation: &ResolvedInvocation, +) -> TractResult { + let x = invocation.named_arg_as(builder, "x")?; + let wg = invocation.named_arg_as(builder, "wg")?; + let w1 = invocation.named_arg_as(builder, "w1")?; + let w2 = invocation.named_arg_as(builder, "w2")?; + let w3: Option = invocation.get_named_arg_as(builder, "w3")?; + let k: i64 = invocation.named_arg_as(builder, "k")?; + let activation: String = invocation.named_arg_as(builder, "activation")?; + let normalize_gates: bool = invocation.named_arg_as(builder, "normalize_gates")?; + + let mut inputs = vec![x, wg, w1, w2]; + let has_w3 = w3.is_some(); + if let Some(w3) = w3 { + inputs.push(w3); + } + + builder.wire( + MoeFfn { + k: k as usize, + activation, + normalize_gates, + has_w3, + }, + &inputs, + ) +} + +#[derive(Clone, Debug, Hash)] +pub struct MoeFfn { + pub k: usize, + pub activation: String, + pub normalize_gates: bool, + pub has_w3: bool, +} + +impl Op for MoeFfn { + fn name(&self) -> StaticName { + "MoeFfn".to_string().into() + } + op_as_typed_op!(); +} + +impl EvalOp for MoeFfn { + fn is_stateless(&self) -> bool { + true + } + + fn eval(&self, inputs: TVec) -> TractResult> { + // inputs: x [T,D], wg [E,D] or [1,E,D], w1 [E,D,H], w2 [E,H,D], [w3 [E,D,H]] + let x = 
inputs[0].to_array_view::()?; + let wg_raw = inputs[1].to_array_view::()?; + let w1 = inputs[2].to_array_view::()?; + let w2 = inputs[3].to_array_view::()?; + let w3 = if self.has_w3 { + Some(inputs[4].to_array_view::()?) + } else { + None + }; + + // Normalize wg to 2D [E, D] (may be [1, E, D] from unsqueeze) + let wg: ArrayView2 = if wg_raw.ndim() == 3 { + wg_raw.index_axis(Axis(0), 0).into_dimensionality()? + } else { + wg_raw.into_dimensionality()? + }; + + // Normalize x to 2D [T, D] (may be [B, S, D] with B=1) + let x_ndim = x.ndim(); + let x_orig_shape: Vec = x.shape().to_vec(); + let x: ArrayView2 = if x_ndim == 3 { + x.into_shape_with_order((x_orig_shape[0] * x_orig_shape[1], x_orig_shape[2]))?.into_dimensionality()? + } else { + x.into_dimensionality()? + }; + + let t_tokens = x.shape()[0]; + let d_model = x.shape()[1]; + let num_experts = wg.shape()[0]; + let _d_hidden = w1.shape()[2]; + + // ---- Step 1: Router ---- + // logits = x @ wg.T [T, D] @ [D, E] -> [T, E] + let router_logits: Array2 = x.dot(&wg.t()); + + // ---- Step 2: Top-k selection + gate weights per token ---- + // assignments[token] = Vec<(expert_id, gate_weight)> + let mut assignments: Vec> = Vec::with_capacity(t_tokens); + for t in 0..t_tokens { + let row = router_logits.row(t); + let mut scores: Vec<(usize, f32)> = + row.iter().enumerate().map(|(e, &s)| (e, s)).collect(); + scores.sort_unstable_by(|a, b| b.1.total_cmp(&a.1)); + scores.truncate(self.k); + + let gate_weights: Vec = if self.normalize_gates { + let max_s = scores.iter().map(|(_, s)| *s).fold(f32::NEG_INFINITY, f32::max); + let exps: Vec = scores.iter().map(|(_, s)| (s - max_s).exp()).collect(); + let sum: f32 = exps.iter().sum(); + exps.iter().map(|e| e / sum).collect() + } else { + scores.iter().map(|(_, s)| *s).collect() + }; + + assignments.push( + scores + .iter() + .zip(gate_weights) + .map(|((eid, _), gw)| (*eid, gw)) + .collect(), + ); + } + + // ---- Step 3: Group tokens per expert ---- + // expert_tokens[eid] = 
Vec<(token_idx, gate_weight)> + let mut expert_tokens: Vec> = vec![Vec::new(); num_experts]; + for (t, token_experts) in assignments.iter().enumerate() { + for &(eid, gw) in token_experts { + expert_tokens[eid].push((t, gw)); + } + } + + // ---- Step 4: Batched expert computation (conditional!) ---- + let mut output = Array2::::zeros((t_tokens, d_model)); + + for eid in 0..num_experts { + let tokens = &expert_tokens[eid]; + if tokens.is_empty() { + continue; // Skip unused experts entirely + } + let n = tokens.len(); + + // Gather: build x_batch [n, D] from selected tokens + let mut x_batch = Array2::::zeros((n, d_model)); + for (i, &(t, _)) in tokens.iter().enumerate() { + x_batch.row_mut(i).assign(&x.row(t)); + } + + // Expert weight slices for this expert + let w1_e = w1.slice(s![eid, .., ..]); // [D, H] + let w2_e = w2.slice(s![eid, .., ..]); // [H, D] + + // h = x_batch @ w1_e -> [n, H] (BLAS-backed GEMM) + let mut h: Array2 = x_batch.dot(&w1_e); + + if let Some(ref w3) = w3 { + // SwiGLU: h = silu(h) * (x_batch @ w3_e) + let w3_e = w3.slice(s![eid, .., ..]); // [D, H] + let gate: Array2 = x_batch.dot(&w3_e); // [n, H] + + h.iter_mut().zip(gate.iter()).for_each(|(h_val, &g_val)| { + let silu = *h_val / (1.0 + (-*h_val).exp()); + *h_val = silu * g_val; + }); + } else { + // Simple silu activation + h.iter_mut().for_each(|h_val| { + *h_val = *h_val / (1.0 + (-*h_val).exp()); + }); + } + + // y_expert = h @ w2_e -> [n, D] (BLAS-backed GEMM) + let y_expert: Array2 = h.dot(&w2_e); + + // ---- Step 5: Scatter-add weighted results back ---- + for (i, &(t, gw)) in tokens.iter().enumerate() { + let y_row = y_expert.row(i); + let mut out_row = output.row_mut(t); + out_row.scaled_add(gw, &y_row); + } + } + + // Restore original rank if input was 3D + let output_tensor = if x_ndim == 3 { + output.into_shape_with_order((x_orig_shape[0], x_orig_shape[1], d_model))?.into_tensor() + } else { + output.into_tensor() + }; + let router_tensor = if x_ndim == 3 { + 
router_logits.into_shape_with_order((x_orig_shape[0], x_orig_shape[1], num_experts))?.into_tensor() + } else { + router_logits.into_tensor() + }; + Ok(tvec![output_tensor.into_tvalue(), router_tensor.into_tvalue()]) + } +} + +impl TypedOp for MoeFfn { + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + // Output 0: same shape as input x + let x_fact = inputs[0]; + let output_fact = x_fact.datum_type.fact(x_fact.shape.clone()); + + // Output 1: router_logits — same leading dims as x, last dim = E + let wg_fact = inputs[1]; + let e_dim = if wg_fact.rank() == 3 { + wg_fact.shape[1].clone() + } else { + wg_fact.shape[0].clone() + }; + let mut router_shape: TVec = x_fact.shape.iter().cloned().collect(); + // Replace last dim (D) with E + if let Some(last) = router_shape.last_mut() { + *last = e_dim; + } + let router_fact = x_fact.datum_type.fact(router_shape); + + Ok(tvec!(output_fact, router_fact)) + } + + as_op!(); +} From 8ab07aae71990bf797d541ecc821e09fdbbaf79c Mon Sep 17 00:00:00 2001 From: Julien Balian Date: Fri, 3 Apr 2026 11:50:32 +0200 Subject: [PATCH 2/2] fix: add moe revamp with tests/cleaner ops --- .../moe-ffn/qwen3-tiny/graph.nnef | 39 ++ .../nnef-test-cases/moe-ffn/qwen3-tiny/io.npz | Bin 0 -> 900 bytes .../moe-ffn/qwen3-tiny/moe.gate.weight.dat | Bin 0 -> 384 bytes .../moe-ffn/qwen3-tiny/output_0_w1.dat | Bin 0 -> 8320 bytes .../moe-ffn/qwen3-tiny/output_0_w2.dat | Bin 0 -> 8320 bytes .../moe-ffn/qwen3-tiny/output_0_w3.dat | Bin 0 -> 8320 bytes .../moe-ffn/qwen3-tiny/runme.sh | 8 + transformers/Cargo.toml | 3 + transformers/src/ops/moe_ffn.rs | 619 ++++++++++++++++++ 9 files changed, 669 insertions(+) create mode 100644 harness/nnef-test-cases/moe-ffn/qwen3-tiny/graph.nnef create mode 100644 harness/nnef-test-cases/moe-ffn/qwen3-tiny/io.npz create mode 100644 harness/nnef-test-cases/moe-ffn/qwen3-tiny/moe.gate.weight.dat create mode 100644 harness/nnef-test-cases/moe-ffn/qwen3-tiny/output_0_w1.dat create mode 100644 
harness/nnef-test-cases/moe-ffn/qwen3-tiny/output_0_w2.dat create mode 100644 harness/nnef-test-cases/moe-ffn/qwen3-tiny/output_0_w3.dat create mode 100755 harness/nnef-test-cases/moe-ffn/qwen3-tiny/runme.sh diff --git a/harness/nnef-test-cases/moe-ffn/qwen3-tiny/graph.nnef b/harness/nnef-test-cases/moe-ffn/qwen3-tiny/graph.nnef new file mode 100644 index 0000000000..2cf056ba33 --- /dev/null +++ b/harness/nnef-test-cases/moe-ffn/qwen3-tiny/graph.nnef @@ -0,0 +1,39 @@ +version 1.0; + +extension tract_registry tract_core; +extension tract_registry tract_transformers; + +fragment tract_core_properties( +) -> (properties: (string, tensor)[]) +{ + properties = [ + ("tract_target_version", "0.22.0"), + ("torch_to_nnef_version", "0.21.0"), + ("torch_version", "2.6.0"), + ("transformers_version", "5.5.0"), + ("os", "Darwin SNS009332 24.6.0 Darwin Kernel Version 24.6.0: Mon Jul 14 11:30:40 PDT 2025; root:xnu-11417.140.69~1/RELEASE_ARM64_T6041 arm64 Darwin"), + ("hostname", "SNS009332"), + ("user", "julien.balian"), + ("py_version", "3.12.10 (main, Apr 9 2025, 03:49:38) [Clang 20.1.0 ] (64-bit runtime)"), + ("export_date", "2026-04-03 11:27:03.768935"), + ("exported_py_class", "Qwen3TinyMoE"), + ("export_cmd", "scripts/export_moe_test_asset.py /Users/julien.balian/SONOS/src/tract/harness/nnef-test-cases/moe-ffn/qwen3-tiny") + ]; +} + + + + + + + +graph network(input_0) -> (output_0) +{ + input_0 = tract_core_external(shape = [1, 3, 16], datum_type = 'f32'); + moe_gate_weight = variable(label = 'moe.gate.weight', shape = [4, 16]); + output_0_w1 = variable(label = 'output_0_w1', shape = [4, 16, 32]); + output_0_w2 = variable(label = 'output_0_w2', shape = [4, 32, 16]); + output_0_w3 = variable(label = 'output_0_w3', shape = [4, 16, 32]); + moe_gate_weight_aligned_rank_expanded = unsqueeze(moe_gate_weight, axes = [0]); + output_0, _router_logits = tract_moe_ffn(input_0, moe_gate_weight_aligned_rank_expanded, output_0_w1, output_0_w2, output_0_w3, k = 2, activation = 'swiglu', 
normalize_gates = true); +} diff --git a/harness/nnef-test-cases/moe-ffn/qwen3-tiny/io.npz b/harness/nnef-test-cases/moe-ffn/qwen3-tiny/io.npz new file mode 100644 index 0000000000000000000000000000000000000000..bc22dcafa08e83619f3499936ab5c7aa8b627acb GIT binary patch literal 900 zcmWIWW@gc4fB;2?JAUo#|Dk}JL4+YQub{Lf-as#}ppub6fWd(gq7X(;_6zk5h-73a zW2jb7Ni9w;Qnyl2w@EWm*HKVU%P%S^O3aJTFG@)TiMu7{6sH2ki!%}nQh|I8LmdTU z9R))(O&x_=1+oCwiQQ}LthvqheM_Ic?@znf{=}M1cG5hW_G?}w?CVlGv`;b7#s0-(b?+H}zANy;1Bb`>Kv* z_CmIc?Ux<%-B)Dz&(6{Dp0RTOr>$mo znf)9GR&H*ycX)nmpSNND{`39+_Z5hD?~l(;v|s)1nElI41^d^VjQ8<*p0vHxuzzoW zZ-?EzBcl5QPy#>7Wx*s^;PU_jKfknOP=`JP!-4vtj03zGnRJ;^6F(?rgD^KVOMs{b naH>bj3g{X^5eL)AgQBq(s1-RH1H4(;Knj?Euoy`HW&-g53M>d( literal 0 HcmV?d00001 diff --git a/harness/nnef-test-cases/moe-ffn/qwen3-tiny/moe.gate.weight.dat b/harness/nnef-test-cases/moe-ffn/qwen3-tiny/moe.gate.weight.dat new file mode 100644 index 0000000000000000000000000000000000000000..04591c6a15b4a7b65a2e47570ec7048fd7adc83b GIT binary patch literal 384 gcmeZ!&&a?4L`)0}3@kt_0K~X}0xpSBak4`I04Znzb^rhX literal 0 HcmV?d00001 diff --git a/harness/nnef-test-cases/moe-ffn/qwen3-tiny/output_0_w1.dat b/harness/nnef-test-cases/moe-ffn/qwen3-tiny/output_0_w1.dat new file mode 100644 index 0000000000000000000000000000000000000000..da2fe89bbf1c73ad631eec6bc79e8df927b226c7 GIT binary patch literal 8320 zcmb7}U613)b%r~LuXL54P-epzYM>@;Uj^-Gv6EfTcrHBN-RVILj2oO1S(Ivj6&AHFb=6Orq0RlJ|DRhezKk9tF=RK!NAO7N5_vte|098I!eSVGsI%y#dB<5JtEun&tJ z?^joElFLh~swy5Q-;RFn4rY)VyB}TWJGVA|HiKQpkT=HgAfn#s{+@|*G#$6+t@-V0 zeY-OLer0Bq;jfnby4$V$cr@9^V7evgjn-^UJ7?SSVCoKUQ!w@ibQqa9ugx|lV>N#^ zL%e0S4>7ph?wRz_S07#e?o+e+{x|>itH8uE&yOxh*2}2ZE5je7cdON^mp;bXeR3h= z9{-tp*;SjqZtKo9_)?H2E;PE@PGnF4@lA^5Vg1YNCH-!g9<9 zKj-oMBKm(F6C-bHu2D^++I zniOs3Pr*mU(!5D~`Df=6fmwpzhgfXHV7;#AHRQA4Ggb*B!4u086P%e$HFL2HKR5IA zm%G|Zwubzn`5$pBF7J4`|Lpry`0V9!Gvg55M=8@0*yQ>KJ+}^n2NU9DmW$7XLxnwi zJ1Qo^TG$hPJdO&u$EMNe`rvFYtMvd7%$i^EcFoEFTXRn!f!~bFcbfnEUw-+;7H9aO 
z=>3*hVsHj8GQn_E9teko-T0p!`OrJ!+-IAeROsOx%tM7Ta zlW>A1{d+XcX+s>%o@){DzZ7TumBl7)6#{^)6XuF=zN3FjuH>Q;GKOI;_k};@=BhZrFlS79@5sZr?5=Qt z4_WtRiSQsF$tJudrJdg21BkWREjr;|8jPMvm2k#)uQ-S=4o&YQBbyqGr3(BJZN1XV z=L0hJ#Rv;&aVz;=`XIeQWThOEk0cKl80DXG8`hDjMMX;Xs`|kRXW}eiyB1#n zMQ1pMyuPg<0vN^GN9zaE?MPODNaAY9p$t-me87$;o0M3nUVMZp_`W(S|D<2+U8@{a z{DT=@Nk`PUa0h?>^0q`@<#Qv2{;) zup47=sOH80YuRy&k@n|vD^5`!I`JMDsn7-gxVbJAKgv$IxJC!zRk=^bJGmm)D^so3 z1UL@9pxg|?6?cly9WTl5f^kURjbub|T(Lv+lTp7B=gaSotoq>K4{zsh7yI-4iTNZtGxy#xYPiN!WQ90;uZ4pb>e6VA^oVdNEc&(3b6IJaF~ z^o}qt&XtVfb5x>4=&Q#)?P}GwnE>XwwQ`2$DDQGp#FOweWy?*fmk3erp$u`W(VBsk0o!UgBypn+sQR{3y>J8@G_GpzF(*4^irH zt*AaUp6k>A6N^4GjEe~Gjxn{c1BWvqK(pkCzw%cqcz3GS_ULXTQ@JO5FP!!!>0H>> z5jo$KJjJJuW+%8WUl47~U#1d^lL~}}Vg`Yc|DmJS!6q)_bh$))YLC^1HiwS1cc{=| zCRPq!GL>?-=98$x?ICs5qJnJ^BG@wK&1hAxL5=Pq+S4xT(kQ<(RVP1`e7rcHTj^11 z<+<93PbCj%_9y%t7YaV=1=rs{^Jrqe&&q4PBZOmmt@%!LcOzt{K2JVWtwQuTey*}) z3Z%ecrVMQhzhC%-z6_!Qj*?3&2>?L@4wX+;55$?qTo)Y3@ua~9`YpYj@C_N9*#h;(sd;-q(!2O<7s-AUO( zqbB1}PI8q~wDl{gs=hx3<)fMJAS!oztNyQ%t(rq~9nnMbBwLD)Lo>r)uIrX#@dKeF z3_Ma*rO)SCf$-M6{b%~SGT#Or9DJnm0;CFush)iU0A};kQf#DQ`pAjuM8iqRCG;x! 
zr#Z0!oj0j#C*gh>FL76I!fomkJUVXgl0^U@{Qd7B7hKkU)0EmbkwhZ(0ASqchiZd0 zr+lN_J2tD8-9z5TxuhiS6h>N7ziE)G`MFjLPdJi(j5+%DDF5rcn$HBh9qD`4&b*nS zqKP(cz-6OcDZgtyr5=g|Dn3v=FxeUw9-3dNW`Q80@9)8oatZpE+?DXGUMUcJ)F}=Z z2cUymN!uM#6wfG5ir3@J@1EDcP z2u<+;kYY9#D6*seTdJ$)r&iJ$jKdwllrCy8b)Rvq1a#)tTj5WLFZrkL z|BAMlhr##1)YbdIs#%lO(nI}x>779`=Bur}_<*@ZJ7ADkKj!?Z&-T(Eh8PrGR$I0c!(QTR#;27QVT>+0OrQ}WQm9V9bpoqML}t96k4 zlRZw!PM4mX7cxs|4rS2kx7AzgU3-kIBCQBgeG-lB)DEXraxG7z!Pb0+LnsZd=Mgpx z9$l(g%Hvl$<&i2lD@>7_RGw%k`s1*^)&2mDu5UYA_;8GST5s&%(`Pc+g(&+CAk9zU zpp8pLTsN}H*M0Yt{2$D>XWZjlU_C^~ndm=S$}c97<8^L_WFP$0%IBmJ=83nRwik@{ z2LjGxYL@XZQWzpVUYHKeq|FI(w1t%?)pIbq=h5v0h(AOYLX1pM2}{I1eeofLf{+@2 zxh5?^_2Ed&_KQvOc{H@;bF;Gf!d3GB+}Nq7P7{GG?K=w>W>5k!psX&%e#sPA&Cl|A zMh_mKt7SNR44Ed^3}sC@*}Z*j(>~jJU2~$^MNv(Cs~&Y9SZS4cS-%Y8D#@Q*=x42k zp6hOyiI!>^(?87B8c**79ntOI!ry=*!8=$;K9X;9@v-gC^wwow8jJl5Cn|S1QErtz z9@e;vd=65|zJsUetM+Ptpk`sE9D;PqdXUn{;(Rw^sirZwuJ}6qX09SgnoxKFm{d~w zbRK{h97=OWK#=e(Vp$^3M>}+ZgB%St03~23WWdg3eX9z&O>1Zc1yWYpvt81($wULg*eYaVRgUv8G2W z=X}WL+Gp$zy*9!UNyN$^?eFD$Lw@&EYBuSfsGOu($wwAEEO`?Nyr=1*Z;=m>O@-en zIv@V^mVOR4prS}k?Y`+b4aNIvJkz9dpCO+Y4kTa1>2up|F#TTl5baz|5)m?NSxB?D zl8tLwONG7iQqO59&zC}NEEpjpX-`*tB7G;FYkv^)9H!mqJ6-pKV$LZy55A3f)U<=C z*XvEOo$K_t(gBK_Lf_mf%_JK?v?^Z8OOcX3m4{GZiTz944TX*eB_o~Q0Jh+cQtnBgQIBF%!*>|kdr-ncwc@PQ4g?)Wytx2(Fk#$M$AV2kq7QwAP<&g8Rwkgd>u7Ka1D-KlwhiBHScbGroA zU?@D4JT?Q-rz1Pk3o)%?j0J42XuwK3kY*bKQi*p92u<~^g_LxUl-rrz844u9_HZ%2 z*JZ5t)1iH`-9v*b6rU$;8jaL;Hjd)Pyk_qt%!z>Mc@0z+4f^;n*ZNNF`LbhzGtX3< z&?3zpWnTAAQ!JdcfFJQ(!TB{ZuubqM)m3NS(#|70 zuo}bBuQK_b!E=jwZ_>Bo*3h7qX$n`A2wHu(AIyX~QA`z_tBdcC&G#ZR*P#ikXcC0)n( zV zKGDegrb1s>oAH2R~b z^)A%AQbXCR!(!K5@Jc=98>_lo8bEdrx`;O#ekic=d!cSTfav3K7sqeey*=Pd$zXfS z6MMi-r98)lzU3@DlzS(70ya(6SM>y()Pq%AP!485fC`V&!T>g{a%@Kp4`!KW<4{xCw!#9y{DBMYX@iiV;5NJi?ZHY8f-BOxxnkyu(Itqrct92fQqloO&K>uR_F zB@o>-ci;f(fn^G(?QilqKqQKzSB+8SA+_AN{k?JtTQ|ya`86M6xHfM|PZc0g(%PBy zn@+?gxg8rB$b4gVu+$3H>sTsB(x6uEbI6G9*yV&p+~xNsY?z#W%cF)5>pIiGst;ny zC`==L)B&? 
zKa_B!FQuk+F+qJ^GM^7l`%5irE>AK4l#qACFR+Me+p;9ED4OXl%8a+%a*owPBBmWoa+bP;-D3*fYw cn=y%ZpF>hGEj@%|Q7{-RrtRnYtp#BJ4P@Z=TL1t6 literal 0 HcmV?d00001 diff --git a/harness/nnef-test-cases/moe-ffn/qwen3-tiny/output_0_w2.dat b/harness/nnef-test-cases/moe-ffn/qwen3-tiny/output_0_w2.dat new file mode 100644 index 0000000000000000000000000000000000000000..6ea08496fd0099c3f5eb1f731db3e73d1899ef63 GIT binary patch literal 8320 zcmb`M+iu)QddC%HbLve1KSH5}V7na-N1K}f0y=B2WshXc*q$9Pb^}4+l*l5RBc8%4 zN@@=Vm`5=31oH?2a=x2KnCo24Ma~CAcS$yf&F^2WHs)fp&czb8$R=6!o&Mkd`>Oct zUw>(D&+P5}^4{LwxA@QH-aY>Rti$}Qo4iK# zAN=!4qA@!3T<6AyO!lm^+hHH0yv6^o@&7b-#%%iOS>L&JT3&C*m(6be zx1GywaUZ9Hp?#CSH(HG+g!Vy;joAOi-<~8l`2RKb-}@iV`2XfFiNkL`O}<8cjMwCA zKMljMUf}x<@c7fW$=y9q%>DaaWuA=Ad%3=FW>udh<`(;9`>o{vFO+#tQ>=p7|I5hOf#0gFnG; za!N9)hh}XD_cpn^=aN~6H0W&V#Ms=e{p#g- z0PmO%jFbN_2C#cFLdm-$DzXANcCiVy9-Od*;i2n`v)4Po8$GWQh@^IEsdvxz@a(B=5?mPVuI@?6u zQv#uT_k?}8Zx8Of75`xW@EV=|XzTw2xbNO$$scXq)~|eBK}6VcI=|F;H~-ttwRdo} zDJ^^-$2apcC86wmjrNzVMj0?t9rMZq6@Uv@u+-#$_ z_923rv9>GBg*PiVU(VL==6zw-RyMlxsV+f52t>M|Qa#MQgw zyLB+7IqrLn;Zft6k#aFLMV712fOC9)SAk_u1^J2WiXv{uJIRsQIX;cHy%1T z?wWGZhwCK1r#ihm_RO~aw|(ZJ^Uepz;2dETM!nYA^CnOphETT}!+TLLF7cOZqWzDr zEoGpu%)=LzIk)74$=r41WtSBj^a}MW7jXt&Ip@7>1_IB1%Fzyu8!4tIX7chitX}nF z$V^dtVccB@W4(_IliQ~0F5Sq8jM6vLc=JPV7dX>a&J-rM^`$dqx2jN`5G-e{?d%w2 z^}f-)kd>o+C3|dBzpC7!dZB&At-;$pklfj?cPaahc{QNf>_bmk{3F@Da35!c^a1 zc#7ebN-A!06%ObwYr$^J!*#|SNg!dshH;B?E2V_z(6=@v#->02WKH+W<6wPF8&+D< z3v$Sw<){_CfbsaIP>lIk2+yI;i4HN&g8_M0*x^T%sGF>GVKJYY=7{#HlD^b;Z_YQS zNX$B|=%2{PrrFF=cr5Ry#({leGc!|H5dP`q^B+8NWQ?OP(D*fk<1TF@TgjwW`s=>( zyudIYO!N!wd;tF0@zEbVv*;ei#W_^np`V~*S9puGbkbkroxoQ$Xb|dsuq_;wwr(DN z@I2$u4A-Mq4M&(9=;RlEnT#l)eAIXy=e9AVLCp6b{@B>8?3PC-A3fAFTyLNosJgL= zj|O#+kDe9OPo0&y@`EFybd+no^ZoZ9&lPw5k$j-9_u~6f8P#UE{i<-KakhwlSoc>B zHGB0&iJ8xZImJqvMlg_y;zr5bH!q*Bp~lc}vc5C}<1DF};K!ccC|jDVS@S5(8Z-ayJm4qBZ#n0DGaC!Dz+*u5!7K2QxAV>NMLhWN}?W zCYK!`1Tw3All6{B$*10i;Oj!}sg&dNvuoj>bcCJ(vB-%0?Z_2O=#*gMP+^}z8IZE0 
zY3ZG|{U>g})LNpSJUEzI$Sq&6IEZ`*0m^0K&?21~mE2~@Yl2Z|o) zmko_tfZU|$WUhL61bW3c;!knHIa#)Eo)FJCkCw|GK8SOYenU4m{RO3vspm~$+ELgd zdGJpvPEbw7zxA+z-ju5(;I<;@}g*BcFNyDa;K`QV*)YHj>k1;MQ!(s(P1_^pNpJpAi$|`z(FZJ zW%MwtYa*}kfO!L+*!I@62#AiHPnS=0|drIBQ=Z()+-g&SIQ!i!d0Aop2$2vG0Nk7JCS*o}{%K-G_3H zK+I!}x7o#2H+wMPmZkG9^n=fYiwsiCm_YH+dleU?`=x$EIlb&->P-G5CDw|+oHH;~ zT>{iS^Ejvr+csMqBId-sD=3Gn=Mz(bD#p$4xJhzQQ1ZAPdjjdx<47oRMn&l$mhPxG zSHL!eEd`Q{>r_mtV|{4ej~91;q!0NRf90%G76@?W%1X!Wab*QAT=HD9YahUHwX5w> zpMKi*3WVung+rj8JrHNm-ZrwC7<^+kD| zzSl&}*IqHyn6U1&AN0j^H*Y-iO7C4VV_PoBo^|cNr#bJ!XHxn+T+$=+1Kn?2 z#EiMaj=}xg-X$?t!rz1Fs=T!Kq_#Ku6Y7V43(qqF)@ovA2m;ZN5>n*NG^DQz!t zNiW+KYiePp_z62oc3Qw#dudLnR7yAJ(UrLY@S;q{$t`ioTDB|J+J991geS>hU$KZ7 zAXT-#q9aF0w7)ypF#prj@}i9 zbe5mwYhCl&5;_dTfWD2?~lcqMCj;nOEa>^MZ;-izx%}HqM@$vX@;?} zuhI4;P7DN9xxXTaWa=o+0ajTNIUS;-hy-eMM*mSzNGUzJ;|EJFZNrH=&BJb^5@31^ z6z9B&KNTJQpglS8opEd7*Be;{v7y7Lu<__F%?)D%`n)DO{rMsl=a%Z_H zmkY3p`OZAv;!Sn%?si3od!_ke&qJ3k{q!whQ8w&uZ5yn>1_#7H&}PRNN8VwKYM+9j z@z7t&_OtpK<&9@MbzHQTwM0di+EBOQk=(^~V-RjQV(p`U(O$qyS_f!`+;rEl;S=pq zqE4|gej>f~xtDmbmO~vZB-Lo{e_xHF^KhM!$m=4lo&QVXTJ)$m6qsQ!66L3nME(rvZDRuDIr>Fr1bbUWk?bvJ5b^^=8hZ?>%QA}0mO1lg>K6!2Boqj=s(4O8RZ@pcBc; zo*TV{GC(Nw!U=lBGnd=-Dml>z7wUuGFbYqD>g}5v(jGR!A4&&dp2wysQ{qmK>JDdw z)2R>EeU9kNE0)~i5!MmxT%Z>KAl`nNX0qebU1c^dJorI4!>jdm%p;%|eQ7u593DyK zFZCed^dZEny@EPey52|)B!UFe^021qVhz)`2HjJgo6B(=>OnR*Uu$)+)T$Q{r;7vh z+>;!?w=@xRtYfx$tN6!*neKX!tf*V@n|cAwEDq2sl(bp1O>46{1(y)!i1YMJ`Mc{u zV*YTi1gLr@4$wMI*8>3(3+Y;bQV+t_6XmaRMmewi;hv3h;UsYYtIEiWZoFizltm5| zmv!tJOiF-JNF$d3$ww}nvK_$s&Y!Z9BBS7d!fkVv^(j7$wMJaVytVrgeMnAEk7D#d z0|U+iomhQ+Hsi|%-BnhiSCmPsJf|FS%?#W z5A*q;wX}Kw@tycZ5Uj7+9gF}*KMI<@Rr0kn>Kh%XdGQB5O%pz+m77Ec7_!sSACU7s7+m{Tj(KTAPA5FdL!D}wy^{vd4+s`u{n`}9Uz1>>?W≪4_mwW|8ffB(a?XXe?ne|+}r*+23BpZU+%zkB^( z&;Rb=|K~%)<;}Y{e*5BLTc7x6w^P%Yk7IZ8Ch+=IS9vp^UmP`dy=vb$(+zV!Rj%hx zv@XuxeKw}`uD3Catxdr@jjN{fSh?A4nCz|V8GCPv!4>(Xo8;@gQ@o2kIGDF!i8-o#|yJSYz zlb+?s`_^$T9nuCHVKah^`#brK6)0|DyjRN0xx7low 
z#nN?+P*+uFZFZ`QNOKx5LTiDSJ_`Q|y&+|N?wWb*}&FL0!VRyN^)dOPD81%l( zhK1z^pUm`PcRY@jq~l0DvUzTZp&zmg063g`7U%}yI%9x^DlC`n;uJZ78yS>}pk!eX zLxye|ooriH&LMCbe=c^wD8WrrXv&S*lLOOe{1TUKqsY$yOV}3>C6Y%;MoHGU_kO&} zEF9RCT~HDbWmW#@zaMPd?{3YD3Q2Ks4jwowuD0)@@=oz&Az@!@6`0w5An#!|ubHAX<+k#8Ot%|e z#XSMyl?lO*{iErGJ@F)pBmoK$qf2p58ju5~6|u>Hbx!5&BwaK|D&haGi?$Q@jLif( z@4Iquj>45_O}WjKM)PN1F&eLvA<6vmljceeh}caA4#(DJ+KVm?Dk}v73GPLK#rqTz z9~v#n#hlCsstuAOf9VbBosP;e;R;#{W%SCeKs=hx*NyJOb4*r=*2F7YY<%a_*?bQv5TKq6 zkVzxli3)qwOwCVXehCLx{Ijy1 zZ?^%`FXdV09*-|BJHGt%vD}EC&b%buWJ0e$!i$_Q+0fcqpRw&dniR;g`tRXkjngBB z8K78e{=&y_H^tgm>)}jWU(8_&U8`sEp`Ce0La3hV`|(IQ;C}V*tI^lkWa`Z#ejPoqB`?HxX+LmB5#vxV%>bKx%DQNb=35p6 zF^5cpiW$$EvW~LTVLYL7`V1YJ4$Z2LT8+ZWQ!2!$-bAvp4!p?x!SWgt;aJsG`3Ln5 znD+^{RD^nii49>jeX}PJUVVx31n+eeuJklK6>biJTD5_fzR(y$0|>2f60Cn zV2s&qr01!NW@4gLelv}p(v)kzd@M}j7A<8kF>{>!)|AU`aO2y_5?leuJ6sMw6CGVy1`G;biAavfP!2%nZ-q3Fb^dY z^)<%ZqaF%h>Te$R=3W}o?u}wq`E5XFH*6!Z{kiR^^XQ`YvQs|3HRH4(!qla`%>Hty zr8ufSpn~zC+^B8}8_G+~1AdQ&fV@Z6LzhmHQ&Iz6f6jna+KYCTC(=(o2yUhiB!iVM zOq88~aGzi%jgd`f6<_X>$N0JKV(u*wmB&bT&6~xKJ99%OXMRyT1?^`4n=Vs-UR&7I z@)-Hng$XdYflw+*Ui%AJMZwtJWxUF!oTMzU)k49%VumxX$WWYt*a%eb@Q>=0`7=f| zG+4M}E8i>D_&41AA@lv z%P;gT97^M?_4^c5SGv(Gb(DwnMp1a&_C)dWQ4m?BNjDI9iVkm*6C3Fd<@>NVXMZtUaf6HzCiG4+=pVDLT6cjGIW4Er z7Z=4lUr{fFSFM4{n@1W@c1yXo^birtvZWlh^SOMlJeJ%W+97QKrN3;U;HyR3aic+i zl?aU9DZYdTzi+973#j!Bcx+^=au{(2LRGJ@bZ&g7!H0FT+ulPZR8-`pkC|ejTyS1Kl)uPG` z7Bf>#=)bxe_n8LgM1mXhc77FlSS*Aoytc33uEXY2*Z($HRw}I|xXAHj>-K zhIp(Op|!MJT;0fihGzRb#%reMF}Watk-T5klQ#a?v9w`h^!&|-ch~h@1V(CBb8oD1 z0qTTcvVGfX!FshfHaPXdOX2BSflt*LQsc$}g-e!mjO*pAY2Zmx$qieT)S)IOBe{lw6e|h^eeZ2N@7#_FNj0zNkzJ+`v#tKJ^uC;^3Ac;rDaA4%_4rJ~+UgI1kloO=P zg?dgQz=`9OEcH32zt_KqP0wLBpZ~S&PWKsAWz_mo9jCHCP%0u0LAz_UnE~InjXl3 zZ+#z8de6y+o?99L;cs!E9)K3c>78;ndBguTF=#k4lY8mFpxJ z4j&#@uHPZAg<70)o+DWHfrk1N zUb`dN;AfyIGWF)8FaP?Vnxfu3=PtEPIiE-ZHthM#RcCHoi!A(8UNOYDC*JoJ=g%`| z3)c(&A1+@xhvwD-iR_Qpkev@B6(@ZA=GBjXTj!McPn@HipTc$%iQxJUY-HQypInD{ 
zk+sqQu&NuoEA>c!;EUP~8JpeEebE>$x1f!lb*i(zG#WVSw8W}8N14l(C8|fK9|34& z4WHFl=z>x@w)E`-2j_-A;pe3f$uFUT#!Oc~xa@B|M@BxgiEX{+O6+~i{fT@Xfo2|8 zFEP`evg8cYg~+LXjQX7AbJuYe+m@M{%9+#H=KrPvjqWopV>i}nedmUUbuD_KbI+qE zeP#r&{+ohr0(DpNM+$LXhdV*9(S>9I(|`@wlDDk zIcxoR??N4G*XA>7&5_NLK`5?Ea=c_+=2P75wI-*`IS&WF+v)ID^B~S4P(;md8}pa% zzE_{x&7{KhJoi5o4t|UXnC4qf6iMI1;XQF=FX5I`V1atE85!#pIW${!!QCU>)q$uq zxHjWGEknER3n#+W()3Al;;Ox8_75=3@el?MY-C=({QRK0Is0!%wj@s&T-RCGHvL^Q zuWrzS>U)~P;)WVLKE;b40Y-IqEkv=-)uAnAEpuR3sVA)2i?wvpr&77d&E9W!`zuTd zcg%hGI$wAE^yCPwt@AmW&K5UVj^Us38Pw-|180(ftyFt>3zgx09-Jo5BiIuoO{^|F z#PeOnn5UdwLqH12XU>s~nyc$~#%pa%Z%Dr+RY>0}Uo}%I2_ntq$=tz*%WpJ2r_R7n z^7HQ@9IGxhuDi?mraa-~aoLM}wN_v&S@VmXV~>;<2Nk-@{qno&y)c!rk0JbcGINBJ z^P~Fw;MUyZm#|rO;)1-dX~ykK$zj*7b8iZN`Q6&%t}BHB4x%VD)a?*OOgPeEaY*KjPUm2x+5Us_B#lV$+OwUq#>9Ra0 zzfnft;nT_Hyc?lyn``olZQ(fWINW${-be>=g`7n?WPd|HDDBC(moIhhsxchJjPj#C zU-`$}HSH6}=((u6M#V%XHwF3PGQpmi?Q5Rg=@BQA`-d?>wAw z@e3oN2`$a0<_Zhp$0>7)_(xo~#$eqieT_dxc4eG<`K2JHTv4@$GC_Ilh6tF=1LrW8 zs+Z2viyhK!t`j&4NZK}XhOf4oL~X{+vwVWDq~*MuQy6}~T&^sg#UGv&itUI*e)AJ< z@};t;ewR|ym)6#@wVuWoKt9_fy(C9N-R`?UhfPCgT=_P@x6M7Y{IoI+dfu66+>2n>+?{yd2 zfM0grs?Qn&Vd>@@x@*XXff+;MKealxym$Oeo%lH~Tn%@c~Me;$+?Wh0}Uke MpZZ_P!u TractResult> { + // Only optimize if all weights are constants + let wg_const = model.node(node.inputs[1].node).op_as::(); + let w1_const = model.node(node.inputs[2].node).op_as::(); + let w2_const = model.node(node.inputs[3].node).op_as::(); + let w3_const = if self.has_w3 { + let c = model.node(node.inputs[4].node).op_as::(); + if c.is_none() { + return Ok(None); + } + c + } else { + None + }; + + if wg_const.is_none() || w1_const.is_none() || w2_const.is_none() { + return Ok(None); + } + + let wg_tensor = wg_const.unwrap().val().clone(); + let w1_tensor = w1_const.unwrap().val().clone(); + let w2_tensor = w2_const.unwrap().val().clone(); + let w3_tensor = w3_const.map(|c| c.val().clone()); + + let dt = model.outlet_fact(node.inputs[0])?.datum_type; + + // Bail if the activation is not 
supported by the optimized path + if activation_op(&self.activation, self.has_w3).is_none() { + return Ok(None); + } + + let num_experts = w1_tensor.shape()[0]; + let d_model = w1_tensor.shape()[1]; + let d_hidden = w1_tensor.shape()[2]; + + // Build router plan: x [T, D] @ wg.T -> [T, E] + let router_plan = + build_router_plan(&wg_tensor, dt, &model.symbols).context("Building router plan")?; + + // Build per-expert plans + let mut expert_plans = Vec::with_capacity(num_experts); + for eid in 0..num_experts { + let w1_e = w1_tensor.slice(0, eid, eid + 1)?.into_shape(&[d_model, d_hidden])?; + let w2_e = w2_tensor.slice(0, eid, eid + 1)?.into_shape(&[d_hidden, d_model])?; + let w3_e = if let Some(ref w3) = w3_tensor { + Some(w3.slice(0, eid, eid + 1)?.into_shape(&[d_model, d_hidden])?) + } else { + None + }; + + let plan = build_expert_plan( + &w1_e, + &w2_e, + w3_e.as_ref(), + &self.activation, + dt, + &model.symbols, + ) + .with_context(|| format!("Building expert plan for expert {eid}"))?; + expert_plans.push(plan); + } + + let opt_op = OptMoeFfn { + k: self.k, + normalize_gates: self.normalize_gates, + num_experts, + d_model, + d_hidden, + router_plan, + expert_plans, + }; + + let mut patch = TypedModelPatch::default(); + let x_tap = patch.tap_model(model, node.inputs[0])?; + let wires = patch.wire_node(&node.name, opt_op, &[x_tap])?; + patch.shunt_outside(model, OutletId::new(node.id, 0), wires[0])?; + patch.shunt_outside(model, OutletId::new(node.id, 1), wires[1])?; + Ok(Some(patch)) + } + as_op!(); } + +// --------------------------------------------------------------------------- +// Activation helper +// --------------------------------------------------------------------------- + +fn activation_op(name: &str, has_w3: bool) -> Option> { + match name { + "silu" => Some(Box::new(Silu)), + // SwiGLU: the inner activation is silu, w3 provides the gate branch + "swiglu" if has_w3 => Some(Box::new(Silu)), + "gelu" => Some(Box::new(GeluApproximate { fast_impl: false 
})), + "relu" => Some(Box::new(tract_nnef::tract_core::ops::nn::leaky_relu(0.0))), + _ => None, + } +} + +// --------------------------------------------------------------------------- +// Sub-model builders +// --------------------------------------------------------------------------- + +fn build_router_plan( + wg: &Arc, + dt: DatumType, + symbols: &SymbolScope, +) -> TractResult> { + let mut model = TypedModel::default(); + model.symbols = symbols.clone(); + let n_sym = symbols.sym("moe_t"); + + // wg is [E, D] or [1, E, D] — normalize to [E, D] + let wg_2d = if wg.rank() == 3 { + wg.slice(0, 0, 1)?.into_shape(&[wg.shape()[1], wg.shape()[2]])? + } else { + (**wg).clone() + }; + + let d_model = wg_2d.shape()[1]; + let _num_experts = wg_2d.shape()[0]; + + // x: [T, D] + let x = model.add_source("x", dt.fact(&[n_sym.to_dim(), d_model.to_dim()]))?; + // wg: [E, D] as constant + let wg_const = model.add_const("wg", wg_2d)?; + + // router_logits = x @ wg.T -> [T, E] + // EinSum: "ij,kj->ik" means i=T, j=D (contracted), k=E + let axes: AxesMapping = "ij,kj->ik".parse()?; + let logits = model.wire_node("router_logits", EinSum::new(axes, dt), &[x, wg_const])?[0]; + + model.set_output_outlets(&[logits])?; + SimplePlan::new(model.into_optimized()?) 
+} + +fn build_expert_plan( + w1: &Tensor, + w2: &Tensor, + w3: Option<&Tensor>, + activation: &str, + dt: DatumType, + symbols: &SymbolScope, +) -> TractResult> { + let mut model = TypedModel::default(); + model.symbols = symbols.clone(); + let n_sym = symbols.sym("moe_n"); + + let d_model = w1.shape()[0]; // w1: [D, H] + let _d_hidden = w1.shape()[1]; + + // Input: x_batch [n, D] + let x = model.add_source("x", dt.fact(&[n_sym.to_dim(), d_model.to_dim()]))?; + + // w1 matmul: x_batch [n,D] @ w1 [D,H] -> [n,H] + let w1_const = model.add_const("w1", w1.clone())?; + let axes_mm: AxesMapping = "ij,jk->ik".parse()?; + let h = model.wire_node("w1_matmul", EinSum::new(axes_mm.clone(), dt), &[x, w1_const])?[0]; + + // Activation (caller guarantees activation_op returns Some via codegen check) + let act_op = activation_op(activation, w3.is_some()) + .ok_or_else(|| format_err!("Unsupported activation: {activation}"))?; + let h = model.wire_node("activation", act_op, &[h])?[0]; + + // Optional SwiGLU: gate = x @ w3, h = h * gate + let h = if let Some(w3) = w3 { + let w3_const = model.add_const("w3", w3.clone())?; + let gate = + model.wire_node("w3_matmul", EinSum::new(axes_mm.clone(), dt), &[x, w3_const])?[0]; + model.wire_node("swiglu_mul", mul(), &[h, gate])?[0] + } else { + h + }; + + // w2 matmul: h [n,H] @ w2 [H,D] -> [n,D] + let w2_const = model.add_const("w2", w2.clone())?; + let y = model.wire_node("w2_matmul", EinSum::new(axes_mm, dt), &[h, w2_const])?[0]; + + model.set_output_outlets(&[y])?; + SimplePlan::new(model.into_optimized()?) 
} // closes the preceding item (definition continues from above this chunk)

// ---------------------------------------------------------------------------
// OptMoeFfn — optimized MoE FFN with pre-compiled expert sub-plans
// ---------------------------------------------------------------------------

/// Optimized Mixture-of-Experts FFN operator.
///
/// Holds one pre-compiled sub-plan for the router and one per expert so that
/// `eval` only runs the experts that actually receive tokens (real
/// conditional compute: experts with no assigned tokens are skipped).
///
/// Outputs: `[0]` the MoE FFN result (same shape as the input `x`),
/// `[1]` the raw router logits (`[..., num_experts]`).
#[derive(Clone, Debug)]
pub struct OptMoeFfn {
    /// Number of experts selected per token (top-k).
    pub k: usize,
    /// When true, gate weights are softmax-normalized over the selected top-k.
    pub normalize_gates: bool,
    pub num_experts: usize,
    pub d_model: usize,
    pub d_hidden: usize,
    /// Router sub-plan: `[T, D] -> [T, num_experts]` logits.
    // NOTE(review): the generic parameters on the plan/state types below were
    // lost in the patch text; they are reconstructed from tract conventions
    // (`TypedSimplePlan` = `SimplePlan<TypedFact, Box<dyn TypedOp>, TypedModel>`)
    // — TODO confirm against the surrounding crate.
    pub router_plan: Arc<TypedSimplePlan<TypedModel>>,
    /// One FFN sub-plan per expert: `[n, D] -> [n, D]`.
    pub expert_plans: Vec<Arc<TypedSimplePlan<TypedModel>>>,
}

impl Hash for OptMoeFfn {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Sub-plans are not hashable; hashing the structural parameters is
        // sufficient to discriminate distinct configurations.
        self.k.hash(state);
        self.normalize_gates.hash(state);
        self.num_experts.hash(state);
        self.d_model.hash(state);
        self.d_hidden.hash(state);
    }
}

impl Op for OptMoeFfn {
    fn name(&self) -> StaticName {
        // A 'static str converts directly; no intermediate String allocation.
        "OptMoeFfn".into()
    }
    op_as_typed_op!();
}

impl EvalOp for OptMoeFfn {
    fn is_stateless(&self) -> bool {
        // Stateful: sub-plan states are spawned once and reused across calls.
        false
    }

    fn state(
        &self,
        _session: &TurnState,
        _node_id: usize,
    ) -> TractResult<Option<Box<dyn OpState>>> {
        // Spawn router and per-expert states up front so eval never pays
        // plan-state construction on the hot path.
        let router_state = self.router_plan.spawn()?;
        let expert_states = self
            .expert_plans
            .iter()
            .map(|p| p.spawn())
            .collect::<TractResult<Vec<_>>>()?;
        Ok(Some(Box::new(OptMoeFfnState {
            op: self.clone(),
            router_state,
            expert_states,
        })))
    }
}

// ---------------------------------------------------------------------------
// OptMoeFfnState — pre-spawned plan states, reused across eval calls
// ---------------------------------------------------------------------------

#[derive(Clone, Debug)]
struct OptMoeFfnState {
    op: OptMoeFfn,
    router_state: TypedSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>,
    expert_states: Vec<TypedSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>>,
}

#[derive(Clone, Debug)]
struct FrozenOptMoeFfnState {
    op: OptMoeFfn,
    router_state: TypedFrozenSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>,
    expert_states: Vec<TypedFrozenSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>>,
}

impl OpStateFreeze for OptMoeFfnState {
    fn freeze(&self) -> Box<dyn FrozenOpState> {
        Box::new(FrozenOptMoeFfnState {
            op: self.op.clone(),
            router_state: self.router_state.freeze(),
            expert_states: self.expert_states.iter().map(|s| s.freeze()).collect(),
        })
    }
}

impl FrozenOpState for FrozenOptMoeFfnState {
    fn unfreeze(&self) -> Box<dyn OpState> {
        Box::new(OptMoeFfnState {
            op: self.op.clone(),
            router_state: self.router_state.unfreeze(),
            expert_states: self.expert_states.iter().map(|s| s.unfreeze()).collect(),
        })
    }
}

impl OpState for OptMoeFfnState {
    fn eval(
        &mut self,
        _session: &mut TurnState,
        _op: &dyn Op,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        let op = &self.op;
        let x_input = &inputs[0];
        // f32-only path: the view extraction below fixes the dtype.
        let x_view = x_input.to_array_view::<f32>()?;
        let x_ndim = x_view.ndim();
        let x_orig_shape: Vec<usize> = x_view.shape().to_vec();

        // Normalize x to 2D [T, D]: a 3D [B, S, D] input is flattened to
        // [B*S, D]; anything else is assumed already 2D.
        let x: ArrayView2<f32> = if x_ndim == 3 {
            x_view
                .into_shape_with_order((x_orig_shape[0] * x_orig_shape[1], x_orig_shape[2]))?
                .into_dimensionality()?
        } else {
            x_view.into_dimensionality()?
        };

        let t_tokens = x.shape()[0];
        let d_model = x.shape()[1];
        let dt = x_input.datum_type();

        // ---- Step 1: Router via pre-spawned state ----
        // NOTE(review): the 3D branch copies into a fresh tensor; tract
        // tensors are contiguous, so a metadata-only reshape of a cloned
        // tensor may be possible instead — left as-is to preserve behavior.
        let x_2d_tensor = if x_ndim == 3 {
            let mut t = Tensor::zero_dt(dt, &[t_tokens, d_model])?;
            t.as_slice_mut::<f32>()?.copy_from_slice(
                x.as_slice().context("x not contiguous for router")?,
            );
            t
        } else {
            (*x_input).clone().into_tensor()
        };

        let router_result = self.router_state.run(tvec![x_2d_tensor.into_tvalue()])?;
        let router_logits_tv = &router_result[0];
        let router_logits = router_logits_tv.to_array_view::<f32>()?;
        let router_logits: ArrayView2<f32> = router_logits.into_dimensionality()?;

        // ---- Step 2: Top-k selection + gate weights ----
        // Per token: sort expert logits descending, keep top-k, and either
        // softmax the surviving logits (normalize_gates) or use them raw.
        let mut assignments: Vec<Vec<(usize, f32)>> = Vec::with_capacity(t_tokens);
        for t in 0..t_tokens {
            let row = router_logits.row(t);
            let mut scores: Vec<(usize, f32)> =
                row.iter().enumerate().map(|(e, &s)| (e, s)).collect();
            scores.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
            scores.truncate(op.k);

            let gate_weights: Vec<f32> = if op.normalize_gates {
                // Numerically stable softmax over the top-k logits.
                let max_s =
                    scores.iter().map(|(_, s)| *s).fold(f32::NEG_INFINITY, f32::max);
                let exps: Vec<f32> =
                    scores.iter().map(|(_, s)| (s - max_s).exp()).collect();
                let sum: f32 = exps.iter().sum();
                exps.iter().map(|e| e / sum).collect()
            } else {
                scores.iter().map(|(_, s)| *s).collect()
            };

            assignments.push(
                scores
                    .iter()
                    .zip(gate_weights)
                    .map(|((eid, _), gw)| (*eid, gw))
                    .collect(),
            );
        }

        // ---- Step 3: Group tokens per expert ----
        // Invert the (token -> experts) map so each expert sees one batch.
        let mut expert_tokens: Vec<Vec<(usize, f32)>> = vec![Vec::new(); op.num_experts];
        for (t, token_experts) in assignments.iter().enumerate() {
            for &(eid, gw) in token_experts {
                expert_tokens[eid].push((t, gw));
            }
        }

        // ---- Step 4: Per-expert computation via pre-spawned states ----
        let mut output = Array2::<f32>::zeros((t_tokens, d_model));

        for eid in 0..op.num_experts {
            let tokens = &expert_tokens[eid];
            if tokens.is_empty() {
                // Conditional compute: unused experts are fully skipped.
                continue;
            }
            let n = tokens.len();

            // Gather: build x_batch [n, D]
            let mut x_batch = Tensor::zero_dt(dt, &[n, d_model])?;
            {
                let x_batch_slice = x_batch.as_slice_mut::<f32>()?;
                for (i, &(t, _)) in tokens.iter().enumerate() {
                    let src = x.row(t);
                    x_batch_slice[i * d_model..(i + 1) * d_model].copy_from_slice(
                        // A row of a standard-layout 2D view is contiguous.
                        src.as_slice().expect("row of standard-layout view is contiguous"),
                    );
                }
            }

            // Run expert plan (reusing pre-spawned state)
            let y_expert = self.expert_states[eid].run(tvec![x_batch.into_tvalue()])?;

            // Scatter-add weighted results back to each token's output row.
            let y_view = y_expert[0].to_array_view::<f32>()?;
            let y_view: ArrayView2<f32> = y_view.into_dimensionality()?;
            for (i, &(t, gw)) in tokens.iter().enumerate() {
                let y_row = y_view.row(i);
                let mut out_row = output.row_mut(t);
                out_row.scaled_add(gw, &y_row);
            }
        }

        // ---- Restore shapes ----
        // Re-expand [T, D] back to the caller's [B, S, D] when needed.
        let output_tensor = if x_ndim == 3 {
            output
                .into_shape_with_order((x_orig_shape[0], x_orig_shape[1], d_model))?
                .into_tensor()
        } else {
            output.into_tensor()
        };
        let router_tensor = if x_ndim == 3 {
            let rl = router_logits_tv.clone().into_tensor();
            rl.into_shape(&[x_orig_shape[0], x_orig_shape[1], op.num_experts])?
        } else {
            router_logits_tv.clone().into_tensor()
        };

        Ok(tvec![output_tensor.into_tvalue(), router_tensor.into_tvalue()])
    }
}

impl TypedOp for OptMoeFfn {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let x_fact = inputs[0];
        // Output 0: same dtype and shape as x.
        let output_fact = x_fact.datum_type.fact(x_fact.shape.clone());

        // Output 1: router logits — x's shape with the last dim replaced by
        // the expert count.
        let mut router_shape: TVec<TDim> = x_fact.shape.iter().cloned().collect();
        if let Some(last) = router_shape.last_mut() {
            *last = self.num_experts.to_dim();
        }
        let router_fact = x_fact.datum_type.fact(router_shape);

        Ok(tvec!(output_fact, router_fact))
    }

    as_op!();
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small MoE model with deterministic pseudo-random constant
    /// weights, plus a matching input tensor. `has_w3` toggles SwiGLU.
    fn make_moe_model(
        t_tokens: usize,
        d_model: usize,
        d_hidden: usize,
        num_experts: usize,
        k: usize,
        has_w3: bool,
    ) -> TractResult<(TypedModel, Tensor)> {
        let mut model = TypedModel::default();

        let x = model.add_source("x", f32::datum_type().fact(&[t_tokens, d_model]))?;

        // Deterministic pseudo-random weights (LCG, fixed seed) in [-1, 1).
        let mut rng_state: u64 = 42;
        let mut next_f32 = || -> f32 {
            rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
            ((rng_state >> 33) as f32 / (1u64 << 31) as f32) - 1.0
        };

        let make_tensor = |shape: &[usize], rng: &mut dyn FnMut() -> f32| -> Tensor {
            let n: usize = shape.iter().product();
            let data: Vec<f32> = (0..n).map(|_| rng()).collect();
            tract_ndarray::ArrayD::from_shape_vec(shape, data).unwrap().into_tensor()
        };

        let wg_data = make_tensor(&[num_experts, d_model], &mut next_f32);
        let w1_data = make_tensor(&[num_experts, d_model, d_hidden], &mut next_f32);
        let w2_data = make_tensor(&[num_experts, d_hidden, d_model], &mut next_f32);

        let wg = model.add_const("wg", wg_data)?;
        let w1 = model.add_const("w1", w1_data)?;
        let w2 = model.add_const("w2", w2_data)?;

        let mut inputs = vec![x, wg, w1, w2];

        if has_w3 {
            let w3_data = make_tensor(&[num_experts, d_model, d_hidden], &mut next_f32);
            let w3 = model.add_const("w3", w3_data)?;
            inputs.push(w3);
        }

        let op = MoeFfn {
            k,
            activation: "silu".to_string(),
            normalize_gates: true,
            has_w3,
        };
        let outputs = model.wire_node("moe", op, &inputs)?;
        model.set_output_outlets(&outputs)?;

        // Create input tensor (drawn from the same deterministic stream).
        let x_data = make_tensor(&[t_tokens, d_model], &mut next_f32);

        Ok((model, x_data))
    }

    #[test]
    fn test_opt_moe_ffn_matches_reference() -> TractResult<()> {
        // Test with SwiGLU (has_w3=true)
        let (model, x_data) = make_moe_model(8, 16, 32, 4, 2, true)?;

        // Run reference (unoptimized)
        let ref_plan = SimplePlan::new(model.clone())?;
        let ref_result = ref_plan.spawn()?.run(tvec![x_data.clone().into_tvalue()])?;

        // Run optimized
        let opt_model = model.into_optimized()?;

        // Verify MoeFfn was replaced with OptMoeFfn
        let has_opt = opt_model.nodes().iter().any(|n| n.op_is::<OptMoeFfn>());
        assert!(has_opt, "Expected OptMoeFfn in optimized model");

        let opt_plan = SimplePlan::new(opt_model)?;
        let opt_result = opt_plan.spawn()?.run(tvec![x_data.into_tvalue()])?;

        // Compare both outputs (FFN result and router logits).
        ref_result[0].close_enough(&opt_result[0], Approximation::Approximate)?;
        ref_result[1].close_enough(&opt_result[1], Approximation::Approximate)?;

        Ok(())
    }

    #[test]
    fn test_opt_moe_ffn_no_w3() -> TractResult<()> {
        // Test without SwiGLU (has_w3=false)
        let (model, x_data) = make_moe_model(8, 16, 32, 4, 2, false)?;

        let ref_plan = SimplePlan::new(model.clone())?;
        let ref_result = ref_plan.spawn()?.run(tvec![x_data.clone().into_tvalue()])?;

        let opt_model = model.into_optimized()?;
        let opt_plan = SimplePlan::new(opt_model)?;
        let opt_result = opt_plan.spawn()?.run(tvec![x_data.into_tvalue()])?;

        ref_result[0].close_enough(&opt_result[0], Approximation::Approximate)?;
        ref_result[1].close_enough(&opt_result[1], Approximation::Approximate)?;

        Ok(())
    }

    #[test]
    fn test_opt_moe_ffn_top1() -> TractResult<()> {
        // Degenerate routing: top-1 of 8 experts.
        let (model, x_data) = make_moe_model(16, 8, 16, 8, 1, true)?;

        let ref_plan = SimplePlan::new(model.clone())?;
        let ref_result = ref_plan.spawn()?.run(tvec![x_data.clone().into_tvalue()])?;

        let opt_model = model.into_optimized()?;
        let opt_plan = SimplePlan::new(opt_model)?;
        let opt_result = opt_plan.spawn()?.run(tvec![x_data.into_tvalue()])?;

        ref_result[0].close_enough(&opt_result[0], Approximation::Approximate)?;
        ref_result[1].close_enough(&opt_result[1], Approximation::Approximate)?;

        Ok(())
    }

    #[test]
    fn test_codegen_fallback_on_non_const_weights() -> TractResult<()> {
        // When weights are inputs (not constants), codegen should not fire
        let mut model = TypedModel::default();
        let x = model.add_source("x", f32::datum_type().fact(&[4, 8]))?;
        let wg = model.add_source("wg", f32::datum_type().fact(&[2, 8]))?;
        let w1 = model.add_source("w1", f32::datum_type().fact(&[2, 8, 16]))?;
        let w2 = model.add_source("w2", f32::datum_type().fact(&[2, 16, 8]))?;

        let op = MoeFfn {
            k: 1,
            activation: "silu".to_string(),
            normalize_gates: true,
            has_w3: false,
        };
        let outputs = model.wire_node("moe", op, &[x, wg, w1, w2])?;
        model.set_output_outlets(&outputs)?;

        let opt_model = model.into_optimized()?;

        // Should still have MoeFfn (not OptMoeFfn)
        let has_moe = opt_model.nodes().iter().any(|n| n.op_is::<MoeFfn>());
        assert!(has_moe, "Expected MoeFfn to remain when weights are not constants");

        Ok(())
    }

    #[test]
    fn test_e2e_nnef_qwen3_moe() -> TractResult<()> {
        use crate::WithTractTransformers;
        use std::io::Cursor;

        // Load the Qwen3 MoE model exported from transformers
        let model_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("../harness/nnef-test-cases/moe-ffn/qwen3-tiny");

        let nnef = tract_nnef::nnef().with_tract_transformers();
        let model = nnef.model_for_path(&model_path)?;
        let model = model.into_optimized()?;

        // Verify OptMoeFfn is present after optimization
        let has_opt = model.nodes().iter().any(|n: &TypedNode| n.op_is::<OptMoeFfn>());
        assert!(has_opt, "Expected OptMoeFfn in optimized model");

        let plan = SimplePlan::new(model)?;

        // Load input and expected output from io.npz
        let npz_path = model_path.join("io.npz");
        let npz_bytes = std::fs::read(&npz_path)?;
        let mut npz = ndarray_npy::NpzReader::new(Cursor::new(npz_bytes))?;

        let input: tract_ndarray::ArrayD<f32> = npz.by_name("input_0.npy")?;
        let expected_output: tract_ndarray::ArrayD<f32> = npz.by_name("output_0.npy")?;

        // Run inference
        let result = plan.spawn()?.run(tvec![input.into_tensor().into_tvalue()])?;

        // Compare against PyTorch reference output
        result[0].close_enough(&expected_output.into_tensor(), Approximation::Approximate)?;

        Ok(())
    }
}