From 575e6c3ef92b7fa512e22f3c2975988ed73c380b Mon Sep 17 00:00:00 2001 From: Stephan Seitz Date: Thu, 2 Apr 2026 10:12:26 +0200 Subject: [PATCH] feat: add support for v_1_4 layers (IMoELayer, IDistCollectiveLayer) --- trtx-sys/src/lib.rs | 5 + trtx/src/network.rs | 276 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 281 insertions(+) diff --git a/trtx-sys/src/lib.rs b/trtx-sys/src/lib.rs index cbc1fbe..cbff218 100644 --- a/trtx-sys/src/lib.rs +++ b/trtx-sys/src/lib.rs @@ -89,6 +89,11 @@ better_enum!(AttentionNormalizationOp); better_enum!(SeekPosition); better_enum!(WeightsRole); better_enum!(TripLimit); +#[cfg(feature = "v_1_4")] +better_enum!(MoEActType); +#[cfg(feature = "v_1_4")] +better_enum!(CollectiveOperation); + pub use enums::ErrorCode; use autocxx::prelude::*; diff --git a/trtx/src/network.rs b/trtx/src/network.rs index ff0692d..4c89a2d 100644 --- a/trtx/src/network.rs +++ b/trtx/src/network.rs @@ -10,6 +10,11 @@ use std::ffi::{CStr, CString}; use std::marker::PhantomData; use std::pin::Pin; use trtx_sys::nvinfer1::{IConcatenationLayer, INetworkDefinition, ITensor}; +#[cfg(feature = "v_1_4")] +use trtx_sys::CollectiveOperation; +#[cfg(feature = "v_1_4")] +use trtx_sys::MoEActType; +use trtx_sys::ReduceOperation; use trtx_sys::{nvinfer1, LayerType, Weights}; use trtx_sys::{AsLayer, AsLayerTyped}; use trtx_sys::{DataType, MatrixOperation, ScaleMode, TopKOperation}; @@ -579,6 +584,227 @@ impl NormalizationLayer<'_> { } } +#[cfg(feature = "v_1_4")] +impl MoELayer<'_> { + /// See [`trtx_sys::nvinfer1::IMoELayer::setGatedWeights`]. 
+    pub fn set_gated_weights(
+        &mut self,
+        network: &mut NetworkDefinition,
+        fc_gate_weights: &Tensor,
+        fc_up_weights: &Tensor,
+        fc_down_weights: &Tensor,
+        activation_type: MoEActType,
+    ) -> Result<()> {
+        crate::check_network!(network, self); // NOTE(review): macro is defined outside this patch; presumably it checks that the handle belongs to `network` — confirm
+        crate::check_network!(network, fc_gate_weights);
+        crate::check_network!(network, fc_up_weights);
+        crate::check_network!(network, fc_down_weights);
+        self.inner.as_mut().setGatedWeights(
+            fc_gate_weights.pin_mut(),
+            fc_up_weights.pin_mut(),
+            fc_down_weights.pin_mut(),
+            activation_type.into(),
+        );
+        Ok(()) // whatever the C++ setter returns is not inspected here
+    }
+
+    /// Forwards the gate, up and down bias tensors to [`trtx_sys::nvinfer1::IMoELayer::setGatedBiases`]; `check_network!` is applied to the layer and to each tensor first.
+    pub fn set_gated_biases(
+        &mut self,
+        network: &mut NetworkDefinition,
+        fc_gate_biases: &Tensor,
+        fc_up_biases: &Tensor,
+        fc_down_biases: &Tensor,
+    ) -> Result<()> {
+        crate::check_network!(network, self);
+        crate::check_network!(network, fc_gate_biases);
+        crate::check_network!(network, fc_up_biases);
+        crate::check_network!(network, fc_down_biases);
+        self.inner.as_mut().setGatedBiases(
+            fc_gate_biases.pin_mut(),
+            fc_up_biases.pin_mut(),
+            fc_down_biases.pin_mut(),
+        );
+        Ok(()) // whatever the C++ setter returns is not inspected here
+    }
+
+    /// Converts `activation_type` with `into()` and hands it to [`trtx_sys::nvinfer1::IMoELayer::setActivationType`].
+    pub fn set_activation_type(
+        &mut self,
+        network: &mut NetworkDefinition,
+        activation_type: MoEActType,
+    ) -> Result<()> {
+        crate::check_network!(network, self);
+        self.inner
+            .as_mut()
+            .setActivationType(activation_type.into());
+        Ok(())
+    }
+
+    /// Returns the result of [`trtx_sys::nvinfer1::IMoELayer::getActivationType`] converted to [`MoEActType`] with `into()`.
+    pub fn activation_type(&self, network: &NetworkDefinition) -> MoEActType {
+        crate::check_network!(network, self);
+        self.inner.as_ref().getActivationType().into()
+    }
+
+    /// Forwards `fc_down_activation_scale` and the converted `data_type` to [`trtx_sys::nvinfer1::IMoELayer::setQuantizationStatic`].
+    pub fn set_quantization_static(
+        &mut self,
+        network: &mut NetworkDefinition,
+        fc_down_activation_scale: &Tensor,
+        data_type: DataType,
+    ) -> Result<()> {
+        crate::check_network!(network, self); // NOTE(review): macro is defined outside this patch; presumably it checks that the handle belongs to `network` — confirm
+        crate::check_network!(network, fc_down_activation_scale);
+        self.inner
+            .as_mut()
+            .setQuantizationStatic(fc_down_activation_scale.pin_mut(), data_type.into());
+        Ok(()) // whatever the C++ setter returns is not inspected here
+    }
+
+    /// Builds a `trtx_sys::Dims` from `block_shape` and forwards it, the scale tensor and both converted data types to [`trtx_sys::nvinfer1::IMoELayer::setQuantizationDynamicDblQ`].
+    pub fn set_quantization_dynamic_dbl_q(
+        &mut self,
+        network: &mut NetworkDefinition,
+        fc_down_activation_dbl_q_scale: &Tensor,
+        data_type: DataType,
+        block_shape: &[i64],
+        dyn_q_output_scale_type: DataType,
+    ) -> Result<()> {
+        crate::check_network!(network, self);
+        crate::check_network!(network, fc_down_activation_dbl_q_scale);
+        let block = trtx_sys::Dims::from_slice(block_shape); // local Dims value, passed by reference below
+        self.inner.as_mut().setQuantizationDynamicDblQ(
+            fc_down_activation_dbl_q_scale.pin_mut(),
+            data_type.into(),
+            &block,
+            dyn_q_output_scale_type.into(),
+        );
+        Ok(())
+    }
+
+    /// Converts `type_` with `into()` and hands it to [`trtx_sys::nvinfer1::IMoELayer::setQuantizationToType`].
+    pub fn set_quantization_to_type(
+        &mut self,
+        network: &mut NetworkDefinition,
+        type_: DataType,
+    ) -> Result<()> {
+        crate::check_network!(network, self);
+        self.inner.as_mut().setQuantizationToType(type_.into());
+        Ok(())
+    }
+
+    /// Returns the result of [`trtx_sys::nvinfer1::IMoELayer::getQuantizationToType`] converted to [`DataType`] with `into()`.
+    pub fn quantization_to_type(&self, network: &NetworkDefinition) -> DataType {
+        crate::check_network!(network, self);
+        self.inner.as_ref().getQuantizationToType().into()
+    }
+
+    /// Builds a `trtx_sys::Dims` from `block_shape` and passes it by reference to [`trtx_sys::nvinfer1::IMoELayer::setQuantizationBlockShape`].
+    pub fn set_quantization_block_shape(
+        &mut self,
+        network: &mut NetworkDefinition,
+        block_shape: &[i64],
+    ) -> Result<()> {
+        crate::check_network!(network, self); // NOTE(review): macro is defined outside this patch; presumably it checks that the handle belongs to `network` — confirm
+        let block = trtx_sys::Dims::from_slice(block_shape); // local Dims value, passed by reference below
+        self.inner.as_mut().setQuantizationBlockShape(&block);
+        Ok(()) // whatever the C++ setter returns is not inspected here
+    }
+
+    /// Returns the leading `nbDims` extents reported by [`trtx_sys::nvinfer1::IMoELayer::getQuantizationBlockShape`]; a negative `nbDims` yields an empty vector and the count is capped at the length of `d`.
+    pub fn quantization_block_shape(&self, network: &NetworkDefinition) -> Vec<i64> {
+        crate::check_network!(network, self);
+        let d = self.inner.as_ref().getQuantizationBlockShape(); // TensorRT signals invalid dims with a negative `nbDims` (see the `Dims` reference)
+        d.d[..(d.nbDims.max(0) as usize).min(d.d.len())].to_vec() // clamp so the slice can never panic on a bad rank
+    }
+
+    /// Converts `type_` with `into()` and hands it to [`trtx_sys::nvinfer1::IMoELayer::setDynQOutputScaleType`].
+    pub fn set_dyn_q_output_scale_type(
+        &mut self,
+        network: &mut NetworkDefinition,
+        type_: DataType,
+    ) -> Result<()> {
+        crate::check_network!(network, self);
+        self.inner.as_mut().setDynQOutputScaleType(type_.into());
+        Ok(())
+    }
+
+    /// Returns the result of [`trtx_sys::nvinfer1::IMoELayer::getDynQOutputScaleType`] converted to [`DataType`] with `into()`.
+    pub fn dyn_q_output_scale_type(&self, network: &NetworkDefinition) -> DataType {
+        crate::check_network!(network, self);
+        self.inner.as_ref().getDynQOutputScaleType().into()
+    }
+
+    /// Forwards `limit`, `alpha` and `beta` unchanged, in that order, to [`trtx_sys::nvinfer1::IMoELayer::setSwigluParams`].
+    pub fn set_swiglu_params(
+        &mut self,
+        network: &mut NetworkDefinition,
+        limit: f32,
+        alpha: f32,
+        beta: f32,
+    ) -> Result<()> {
+        crate::check_network!(network, self);
+        self.inner.as_mut().setSwigluParams(limit, alpha, beta);
+        Ok(())
+    }
+
+    /// Forwards `limit` unchanged to [`trtx_sys::nvinfer1::IMoELayer::setSwigluParamLimit`].
+    pub fn set_swiglu_param_limit(
+        &mut self,
+        network: &mut NetworkDefinition,
+        limit: f32,
+    ) -> Result<()> {
+        crate::check_network!(network, self);
+        self.inner.as_mut().setSwigluParamLimit(limit);
+        Ok(())
+    }
+
+    /// Returns the `f32` reported by [`trtx_sys::nvinfer1::IMoELayer::getSwigluParamLimit`] as is.
+    pub fn swiglu_param_limit(&self, network: &NetworkDefinition) -> f32 {
+        crate::check_network!(network, self); // NOTE(review): macro is defined outside this patch; presumably it checks that the handle belongs to `network` — confirm
+        self.inner.as_ref().getSwigluParamLimit()
+    }
+
+    /// Forwards `alpha` unchanged to [`trtx_sys::nvinfer1::IMoELayer::setSwigluParamAlpha`].
+    pub fn set_swiglu_param_alpha(
+        &mut self,
+        network: &mut NetworkDefinition,
+        alpha: f32,
+    ) -> Result<()> {
+        crate::check_network!(network, self);
+        self.inner.as_mut().setSwigluParamAlpha(alpha);
+        Ok(()) // whatever the C++ setter returns is not inspected here
+    }
+
+    /// Returns the `f32` reported by [`trtx_sys::nvinfer1::IMoELayer::getSwigluParamAlpha`] as is.
+    pub fn swiglu_param_alpha(&self, network: &NetworkDefinition) -> f32 {
+        crate::check_network!(network, self);
+        self.inner.as_ref().getSwigluParamAlpha()
+    }
+
+    /// Forwards `beta` unchanged to [`trtx_sys::nvinfer1::IMoELayer::setSwigluParamBeta`].
+    pub fn set_swiglu_param_beta(
+        &mut self,
+        network: &mut NetworkDefinition,
+        beta: f32,
+    ) -> Result<()> {
+        crate::check_network!(network, self);
+        self.inner.as_mut().setSwigluParamBeta(beta);
+        Ok(())
+    }
+
+    /// Returns the `f32` reported by [`trtx_sys::nvinfer1::IMoELayer::getSwigluParamBeta`] as is.
+    pub fn swiglu_param_beta(&self, network: &NetworkDefinition) -> f32 {
+        crate::check_network!(network, self);
+        self.inner.as_ref().getSwigluParamBeta()
+    }
+}
+
+/// Intentionally empty: per the original author, `IDistCollectiveLayer` adds no methods beyond [`ILayer`](trtx_sys::nvinfer1::ILayer); use [`Layer`] helpers.
+#[cfg(feature = "v_1_4")]
+impl DistCollectiveLayer<'_> {}
+
 impl Tensor<'_> {
     pub fn name(&self, network: &NetworkDefinition) -> Result {
         crate::check_network!(network, self);
@@ -1514,6 +1740,56 @@ impl<'network> NetworkDefinition<'network> {
             network: self.inner.as_ptr(),
         })
     }
+
+    /// Adds a layer through [`trtx_sys::nvinfer1::INetworkDefinition::addMoE`] from the three input tensors and wraps the returned pointer with `MoELayer::new`.
+    #[cfg(feature = "v_1_4")]
+    pub fn add_moe(
+        &mut self,
+        hidden_states: &Tensor,
+        selected_experts_for_tokens: &Tensor,
+        scores_for_selected_experts: &Tensor,
+    ) -> Result> { // NOTE(review): return type looks truncated in this copy of the patch (presumably `Result<MoELayer<...>>`) — restore from the upstream commit
+        crate::check_network!(self, hidden_states); // NOTE(review): macro is defined outside this patch; presumably it checks that the tensor belongs to this network — confirm
+        crate::check_network!(self, selected_experts_for_tokens);
+        crate::check_network!(self, scores_for_selected_experts);
+        let layer_ptr = self.inner.pin_mut().addMoE(
+            hidden_states.pin_mut(),
+            selected_experts_for_tokens.pin_mut(),
+            scores_for_selected_experts.pin_mut(),
+        );
+        MoELayer::new(self.inner.as_ptr(), layer_ptr) // `new` produces the `Result`; how it treats a null `layer_ptr` is not visible here
+    }
+
+    /// Adds a layer through [`trtx_sys::nvinfer1::INetworkDefinition::addDistCollective`] and wraps the returned pointer with `DistCollectiveLayer::new`.
+    ///
+    /// Pass an empty `groups` slice so all ranks participate (`groups == nullptr`, `groupSize == 0` in C++); otherwise the slice's pointer and length are passed through as given.
+    #[cfg(feature = "v_1_4")]
+    pub fn add_dist_collective(
+        &mut self,
+        input: &Tensor,
+        dist_collective_op: CollectiveOperation,
+        reduce_op: ReduceOperation,
+        root: i64,
+        groups: &[i64],
+    ) -> Result> { // NOTE(review): return type looks truncated in this copy of the patch (presumably `Result<DistCollectiveLayer<...>>`) — restore from the upstream commit
+        crate::check_network!(self, input);
+        let (groups_ptr, group_size) = if groups.is_empty() {
+            (std::ptr::null_mut(), 0i64)
+        } else {
+            (groups.as_ptr() as *mut i64, groups.len() as i64) // NOTE(review): drops the constness of `groups`; sound only if the C++ side never writes through the pointer — confirm
+        };
+        let layer_ptr = unsafe { // SAFETY (to confirm): `groups_ptr` is null with size 0, or points at `groups.len()` readable `i64`s borrowed for this call; verify the pointer is not retained afterwards
+            self.inner.pin_mut().addDistCollective(
+                input.pin_mut(),
+                dist_collective_op.into(),
+                reduce_op.into(),
+                root,
+                groups_ptr,
+                group_size,
+            )
+        };
+        DistCollectiveLayer::new(self.inner.as_ptr(), layer_ptr)
+    }
 }
 
 // --- Attention: get_output ---