Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions trtx-sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ better_enum!(AttentionNormalizationOp);
better_enum!(SeekPosition);
better_enum!(WeightsRole);
better_enum!(TripLimit);
// Enums that only exist in the TensorRT RTX v1.4 API surface:
// MoE activation functions and distributed collective operations.
#[cfg(feature = "v_1_4")]
better_enum!(MoEActType);
#[cfg(feature = "v_1_4")]
better_enum!(CollectiveOperation);

pub use enums::ErrorCode;

use autocxx::prelude::*;
Expand Down
276 changes: 276 additions & 0 deletions trtx/src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ use std::ffi::{CStr, CString};
use std::marker::PhantomData;
use std::pin::Pin;
use trtx_sys::nvinfer1::{IConcatenationLayer, INetworkDefinition, ITensor};
#[cfg(feature = "v_1_4")]
use trtx_sys::CollectiveOperation;
#[cfg(feature = "v_1_4")]
use trtx_sys::MoEActType;
use trtx_sys::ReduceOperation;
use trtx_sys::{nvinfer1, LayerType, Weights};
use trtx_sys::{AsLayer, AsLayerTyped};
use trtx_sys::{DataType, MatrixOperation, ScaleMode, TopKOperation};
Expand Down Expand Up @@ -579,6 +584,227 @@ impl NormalizationLayer<'_> {
}
}

// Bindings for TensorRT's Mixture-of-Experts layer (`IMoELayer`), available
// only from the v1.4 API — hence the feature gate.
//
// Every method takes a `NetworkDefinition` reference solely so that
// `check_network!` can verify the layer (and any tensor arguments) belong to
// that network before the underlying C++ object is touched.
#[cfg(feature = "v_1_4")]
impl MoELayer<'_> {
    /// Sets the gate/up/down expert weight tensors together with the gated
    /// activation function in a single call.
    ///
    /// See [`trtx_sys::nvinfer1::IMoELayer::setGatedWeights`].
    pub fn set_gated_weights(
        &mut self,
        network: &mut NetworkDefinition,
        fc_gate_weights: &Tensor,
        fc_up_weights: &Tensor,
        fc_down_weights: &Tensor,
        activation_type: MoEActType,
    ) -> Result<()> {
        // Validate ownership of the layer and every tensor up front, before
        // any FFI call is made.
        crate::check_network!(network, self);
        crate::check_network!(network, fc_gate_weights);
        crate::check_network!(network, fc_up_weights);
        crate::check_network!(network, fc_down_weights);
        self.inner.as_mut().setGatedWeights(
            fc_gate_weights.pin_mut(),
            fc_up_weights.pin_mut(),
            fc_down_weights.pin_mut(),
            activation_type.into(),
        );
        Ok(())
    }

    /// Sets the gate/up/down expert bias tensors.
    ///
    /// See [`trtx_sys::nvinfer1::IMoELayer::setGatedBiases`].
    pub fn set_gated_biases(
        &mut self,
        network: &mut NetworkDefinition,
        fc_gate_biases: &Tensor,
        fc_up_biases: &Tensor,
        fc_down_biases: &Tensor,
    ) -> Result<()> {
        crate::check_network!(network, self);
        crate::check_network!(network, fc_gate_biases);
        crate::check_network!(network, fc_up_biases);
        crate::check_network!(network, fc_down_biases);
        self.inner.as_mut().setGatedBiases(
            fc_gate_biases.pin_mut(),
            fc_up_biases.pin_mut(),
            fc_down_biases.pin_mut(),
        );
        Ok(())
    }

    /// Sets the expert activation function.
    ///
    /// See [`trtx_sys::nvinfer1::IMoELayer::setActivationType`].
    pub fn set_activation_type(
        &mut self,
        network: &mut NetworkDefinition,
        activation_type: MoEActType,
    ) -> Result<()> {
        crate::check_network!(network, self);
        self.inner
            .as_mut()
            .setActivationType(activation_type.into());
        Ok(())
    }

    /// Returns the currently configured expert activation function.
    ///
    /// See [`trtx_sys::nvinfer1::IMoELayer::getActivationType`].
    pub fn activation_type(&self, network: &NetworkDefinition) -> MoEActType {
        crate::check_network!(network, self);
        self.inner.as_ref().getActivationType().into()
    }

    /// Configures static quantization with a per-layer activation scale for
    /// the down projection.
    ///
    /// See [`trtx_sys::nvinfer1::IMoELayer::setQuantizationStatic`].
    pub fn set_quantization_static(
        &mut self,
        network: &mut NetworkDefinition,
        fc_down_activation_scale: &Tensor,
        data_type: DataType,
    ) -> Result<()> {
        crate::check_network!(network, self);
        crate::check_network!(network, fc_down_activation_scale);
        self.inner
            .as_mut()
            .setQuantizationStatic(fc_down_activation_scale.pin_mut(), data_type.into());
        Ok(())
    }

    /// Configures dynamic double quantization.
    ///
    /// `block_shape` is converted to a TensorRT `Dims` and passed by
    /// reference; the temporary only needs to live for the duration of the
    /// call.
    ///
    /// See [`trtx_sys::nvinfer1::IMoELayer::setQuantizationDynamicDblQ`].
    pub fn set_quantization_dynamic_dbl_q(
        &mut self,
        network: &mut NetworkDefinition,
        fc_down_activation_dbl_q_scale: &Tensor,
        data_type: DataType,
        block_shape: &[i64],
        dyn_q_output_scale_type: DataType,
    ) -> Result<()> {
        crate::check_network!(network, self);
        crate::check_network!(network, fc_down_activation_dbl_q_scale);
        let block = trtx_sys::Dims::from_slice(block_shape);
        self.inner.as_mut().setQuantizationDynamicDblQ(
            fc_down_activation_dbl_q_scale.pin_mut(),
            data_type.into(),
            &block,
            dyn_q_output_scale_type.into(),
        );
        Ok(())
    }

    /// Sets the target data type for quantization.
    ///
    /// See [`trtx_sys::nvinfer1::IMoELayer::setQuantizationToType`].
    pub fn set_quantization_to_type(
        &mut self,
        network: &mut NetworkDefinition,
        type_: DataType,
    ) -> Result<()> {
        crate::check_network!(network, self);
        self.inner.as_mut().setQuantizationToType(type_.into());
        Ok(())
    }

    /// Returns the target data type for quantization.
    ///
    /// See [`trtx_sys::nvinfer1::IMoELayer::getQuantizationToType`].
    pub fn quantization_to_type(&self, network: &NetworkDefinition) -> DataType {
        crate::check_network!(network, self);
        self.inner.as_ref().getQuantizationToType().into()
    }

    /// Sets the quantization block shape.
    ///
    /// See [`trtx_sys::nvinfer1::IMoELayer::setQuantizationBlockShape`].
    pub fn set_quantization_block_shape(
        &mut self,
        network: &mut NetworkDefinition,
        block_shape: &[i64],
    ) -> Result<()> {
        crate::check_network!(network, self);
        let block = trtx_sys::Dims::from_slice(block_shape);
        self.inner.as_mut().setQuantizationBlockShape(&block);
        Ok(())
    }

    /// Returns the quantization block shape as a plain vector.
    ///
    /// See [`trtx_sys::nvinfer1::IMoELayer::getQuantizationBlockShape`].
    pub fn quantization_block_shape(&self, network: &NetworkDefinition) -> Vec<i64> {
        crate::check_network!(network, self);
        let d = self.inner.as_ref().getQuantizationBlockShape();
        // NOTE(review): TensorRT's `Dims` conventionally uses nbDims == -1 as
        // an "unset" sentinel; `as usize` would wrap and the slice below would
        // panic in that case. Confirm getQuantizationBlockShape can never
        // return a negative nbDims.
        d.d[..d.nbDims as usize].to_vec()
    }

    /// Sets the output scale type used for dynamic quantization.
    ///
    /// See [`trtx_sys::nvinfer1::IMoELayer::setDynQOutputScaleType`].
    pub fn set_dyn_q_output_scale_type(
        &mut self,
        network: &mut NetworkDefinition,
        type_: DataType,
    ) -> Result<()> {
        crate::check_network!(network, self);
        self.inner.as_mut().setDynQOutputScaleType(type_.into());
        Ok(())
    }

    /// Returns the output scale type used for dynamic quantization.
    ///
    /// See [`trtx_sys::nvinfer1::IMoELayer::getDynQOutputScaleType`].
    pub fn dyn_q_output_scale_type(&self, network: &NetworkDefinition) -> DataType {
        crate::check_network!(network, self);
        self.inner.as_ref().getDynQOutputScaleType().into()
    }

    /// Sets all three SwiGLU parameters (limit, alpha, beta) in one call.
    ///
    /// See [`trtx_sys::nvinfer1::IMoELayer::setSwigluParams`].
    pub fn set_swiglu_params(
        &mut self,
        network: &mut NetworkDefinition,
        limit: f32,
        alpha: f32,
        beta: f32,
    ) -> Result<()> {
        crate::check_network!(network, self);
        self.inner.as_mut().setSwigluParams(limit, alpha, beta);
        Ok(())
    }

    /// Sets only the SwiGLU limit parameter.
    ///
    /// See [`trtx_sys::nvinfer1::IMoELayer::setSwigluParamLimit`].
    pub fn set_swiglu_param_limit(
        &mut self,
        network: &mut NetworkDefinition,
        limit: f32,
    ) -> Result<()> {
        crate::check_network!(network, self);
        self.inner.as_mut().setSwigluParamLimit(limit);
        Ok(())
    }

    /// Returns the SwiGLU limit parameter.
    ///
    /// See [`trtx_sys::nvinfer1::IMoELayer::getSwigluParamLimit`].
    pub fn swiglu_param_limit(&self, network: &NetworkDefinition) -> f32 {
        crate::check_network!(network, self);
        self.inner.as_ref().getSwigluParamLimit()
    }

    /// Sets only the SwiGLU alpha parameter.
    ///
    /// See [`trtx_sys::nvinfer1::IMoELayer::setSwigluParamAlpha`].
    pub fn set_swiglu_param_alpha(
        &mut self,
        network: &mut NetworkDefinition,
        alpha: f32,
    ) -> Result<()> {
        crate::check_network!(network, self);
        self.inner.as_mut().setSwigluParamAlpha(alpha);
        Ok(())
    }

    /// Returns the SwiGLU alpha parameter.
    ///
    /// See [`trtx_sys::nvinfer1::IMoELayer::getSwigluParamAlpha`].
    pub fn swiglu_param_alpha(&self, network: &NetworkDefinition) -> f32 {
        crate::check_network!(network, self);
        self.inner.as_ref().getSwigluParamAlpha()
    }

    /// Sets only the SwiGLU beta parameter.
    ///
    /// See [`trtx_sys::nvinfer1::IMoELayer::setSwigluParamBeta`].
    pub fn set_swiglu_param_beta(
        &mut self,
        network: &mut NetworkDefinition,
        beta: f32,
    ) -> Result<()> {
        crate::check_network!(network, self);
        self.inner.as_mut().setSwigluParamBeta(beta);
        Ok(())
    }

    /// Returns the SwiGLU beta parameter.
    ///
    /// See [`trtx_sys::nvinfer1::IMoELayer::getSwigluParamBeta`].
    pub fn swiglu_param_beta(&self, network: &NetworkDefinition) -> f32 {
        crate::check_network!(network, self);
        self.inner.as_ref().getSwigluParamBeta()
    }
}

/// `IDistCollectiveLayer` adds no methods beyond [`ILayer`](trtx_sys::nvinfer1::ILayer); use [`Layer`] helpers.
///
/// The impl is intentionally empty: all configuration of this layer happens
/// through the generic layer helpers, so nothing layer-specific is exposed.
#[cfg(feature = "v_1_4")]
impl DistCollectiveLayer<'_> {}

impl Tensor<'_> {
pub fn name(&self, network: &NetworkDefinition) -> Result<String> {
crate::check_network!(network, self);
Expand Down Expand Up @@ -1514,6 +1740,56 @@ impl<'network> NetworkDefinition<'network> {
network: self.inner.as_ptr(),
})
}

/// Adds a Mixture-of-Experts layer to this network (v1.4+ API only).
///
/// The expert routing is supplied externally via
/// `selected_experts_for_tokens` and `scores_for_selected_experts`; refer to
/// the TensorRT C++ documentation for the required shapes and dtypes.
///
/// See [`trtx_sys::nvinfer1::INetworkDefinition::addMoE`].
#[cfg(feature = "v_1_4")]
pub fn add_moe(
    &mut self,
    hidden_states: &Tensor,
    selected_experts_for_tokens: &Tensor,
    scores_for_selected_experts: &Tensor,
) -> Result<MoELayer<'network>> {
    // Every input tensor must already belong to this network.
    crate::check_network!(self, hidden_states);
    crate::check_network!(self, selected_experts_for_tokens);
    crate::check_network!(self, scores_for_selected_experts);
    let layer_ptr = self.inner.pin_mut().addMoE(
        hidden_states.pin_mut(),
        selected_experts_for_tokens.pin_mut(),
        scores_for_selected_experts.pin_mut(),
    );
    // Wrap the raw layer pointer; presumably `MoELayer::new` rejects a null
    // pointer from the C++ side — confirm against its definition.
    MoELayer::new(self.inner.as_ptr(), layer_ptr)
}

/// Adds a distributed collective-communication layer (v1.4+ API only).
///
/// See [`trtx_sys::nvinfer1::INetworkDefinition::addDistCollective`].
///
/// Pass an empty `groups` slice so all ranks participate (`groups == nullptr`, `groupSize == 0` in C++).
#[cfg(feature = "v_1_4")]
pub fn add_dist_collective(
    &mut self,
    input: &Tensor,
    dist_collective_op: CollectiveOperation,
    reduce_op: ReduceOperation,
    root: i64,
    groups: &[i64],
) -> Result<DistCollectiveLayer<'network>> {
    crate::check_network!(self, input);
    // Translate the slice into the (pointer, length) pair the C++ API
    // expects; an empty slice maps to the "all ranks" sentinel (nullptr, 0).
    let (groups_ptr, group_size) = if groups.is_empty() {
        (std::ptr::null_mut(), 0i64)
    } else {
        // `cast_mut()` makes the const->mut pointer conversion explicit
        // (instead of a silent `as` cast). The C++ signature takes a
        // non-const pointer but is expected to only read the group list and
        // copy it before returning, since the borrow ends with this call —
        // TODO(review): confirm both against the TensorRT headers.
        (groups.as_ptr().cast_mut(), groups.len() as i64)
    };
    // SAFETY: `groups_ptr` is either null (with `group_size == 0`) or points
    // to `group_size` valid, initialized `i64`s borrowed from `groups`, which
    // outlives this call.
    let layer_ptr = unsafe {
        self.inner.pin_mut().addDistCollective(
            input.pin_mut(),
            dist_collective_op.into(),
            reduce_op.into(),
            root,
            groups_ptr,
            group_size,
        )
    };
    DistCollectiveLayer::new(self.inner.as_ptr(), layer_ptr)
}
}

// --- Attention: get_output ---
Expand Down
Loading