Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/crates.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ jobs:
if [ "$FULL" == "true" ]
then
echo 'os=["ubuntu-latest", "macos-latest"]' >> $GITHUB_OUTPUT
echo 'rust=["1.89.0", "stable", "beta", "nightly"]' >> $GITHUB_OUTPUT
echo 'rust=["1.91.0", "stable", "beta", "nightly"]' >> $GITHUB_OUTPUT
else
echo ::notice::Skipping macOS checks on PR and commit. Dispatch workflow manually if needed.
echo 'os=["ubuntu-latest"]' >> $GITHUB_OUTPUT
echo 'rust=["1.89.0"]' >> $GITHUB_OUTPUT
echo 'rust=["1.91.0"]' >> $GITHUB_OUTPUT
fi

crates:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/cross-platform.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ on:
env:
CARGO_INCREMENTAL: false
FORCE_JAVASCRIPT_ACTIONS_TO_NODE20: true
RUSTUP_TOOLCHAIN: 1.89.0
RUSTUP_TOOLCHAIN: 1.91.0

jobs:
linux:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/examples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ on:
env:
CARGO_INCREMENTAL: false
FORCE_JAVASCRIPT_ACTIONS_TO_NODE20: true
RUSTUP_TOOLCHAIN: 1.89.0
RUSTUP_TOOLCHAIN: 1.91.0

jobs:
examples:
Expand Down
2 changes: 1 addition & 1 deletion .travis/ci-system-setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ set -e

if [ -z "$RUSTUP_TOOLCHAIN" ]
then
export RUSTUP_TOOLCHAIN=1.89.0
export RUSTUP_TOOLCHAIN=1.91.0
fi

export RUSTUP_TOOLCHAIN
Expand Down
2 changes: 1 addition & 1 deletion .travis/native.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ set -ex

if [ -z "$RUSTUP_TOOLCHAIN" ]
then
export RUSTUP_TOOLCHAIN=1.89.0
export RUSTUP_TOOLCHAIN=1.91.0
fi

rustup update
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Unreleased

* [Breaking][MSRV] MSRV bumped to 1.91.0 (for `const TypeId::of`).

# 0.23.0-dev.3 — 2026-03-20

### Breaking changes
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ default-members = [
]

[workspace.package]
rust-version = "1.89"
rust-version = "1.91"

[workspace.dependencies]
accelerate-src = "0.3"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
![tract-logo](assets/tract-logo/PNG/tract-horizontal-blue.png)

![Rust](https://img.shields.io/badge/rust-%23000000.svg?style=for-the-badge&logo=rust&logoColor=white)
![rustc >= 1.89.0](https://img.shields.io/badge/rustc-%3E%3D1.89.0-brightgreen)
![rustc >= 1.91.0](https://img.shields.io/badge/rustc-%3E%3D1.91.0-brightgreen)
![MIT/Apache 2](https://img.shields.io/crates/l/tract)
[![Native Linux test status](https://github.com/snipsco/tract/workflows/Native%20Linux/badge.svg)](https://github.com/snipsco/tract/actions)
[![Embedded targets status](https://github.com/snipsco/tract/workflows/Embedded%20targets/badge.svg)](https://github.com/snipsco/tract/actions)
Expand Down
21 changes: 21 additions & 0 deletions cuda/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,27 @@ impl DeviceContext for TractCudaContext {
) -> TractResult<Box<dyn OwnedDeviceTensor>> {
Ok(Box::new(CudaTensor::uninitialized_exotic(exotic_fact)?))
}

/// `DeviceContext` hook for a strided N-dimensional copy between device tensors.
///
/// Delegates straight to the CUDA copy-nd kernel dispatcher. Strides may be
/// negative (they are `isize`), which is presumably how reversed/transposed
/// views are expressed — TODO confirm against the kernel.
/// NOTE(review): offsets look like element (not byte) offsets from each
/// tensor's base — verify against `cuda_copy_nd_dispatch`.
fn copy_nd(
    &self,
    input: &DeviceTensor,
    input_offset: usize,
    input_strides: &[isize],
    output: &DeviceTensor,
    output_offset: usize,
    output_shape: &[usize],
    output_strides: &[isize],
) -> TractResult<()> {
    crate::kernels::array::cuda_copy_nd_dispatch(
        input,
        input_offset,
        input_strides,
        output,
        output_offset,
        output_shape,
        output_strides,
    )
}
}

/// A recorded GPU kernel timing entry: start/end events tagged with a node_id.
Expand Down
8 changes: 8 additions & 0 deletions cuda/src/kernels/array/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ impl Cast {
}
}

/// Run the CUDA `Cast` kernel, converting `input`'s values into `output`.
///
/// Acquires the current CUDA stream via `with_cuda_stream` for the duration
/// of the kernel dispatch.
pub fn cuda_cast_dispatch(input: &DeviceTensor, output: &DeviceTensor) -> TractResult<()> {
    crate::with_cuda_stream(|stream| {
        let kernel = Cast;
        kernel.dispatch_eval(stream, input, output)
    })
}

// Register the CUDA replacement for core `Cast` nodes.
// `cuda_cast_new` returns an `Option`: `None` declines the node (presumably an
// unsupported target dtype — TODO confirm) and leaves it on the default path.
crate::register_cuda_op!(tract_core::ops::cast::Cast, |_source, _node, op| {
    Ok(crate::transform::cuda_cast_new(op.to).map(|c| Box::new(c) as _))
});

#[cfg(test)]
mod tests {

Expand Down
2 changes: 2 additions & 0 deletions cuda/src/kernels/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ mod dispatch;
mod rotate_half;

pub use cast::Cast;
pub use cast::cuda_cast_dispatch;
pub use copy::Memcpy;
pub use dispatch::cuda_copy_nd_dispatch;
pub use rotate_half::RotateHalf;
pub use rotate_half::cuda_rotate_half_dispatch;

pub fn all_functions() -> Vec<String> {
use std::collections::HashSet;
Expand Down
12 changes: 12 additions & 0 deletions cuda/src/kernels/array/rotate_half.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,18 @@ impl RotateHalf {
}
}

/// Execute the CUDA `RotateHalf` kernel on `input`, writing into `output`.
pub fn cuda_rotate_half_dispatch(input: &DeviceTensor, output: &DeviceTensor) -> TractResult<()> {
    crate::with_cuda_stream(|stream| {
        let kernel = RotateHalf;
        kernel.dispatch_eval(stream, input, output)
    })
}

// Map the transformers `RotateHalf` op onto the backend-generic GPU wrapper,
// backed by the CUDA dispatch function above. `rule_if!` declines the node
// when the first input's dtype is not supported by the kernel.
crate::register_cuda_op!(tract_transformers::ops::apply_rope::RotateHalf, |source, node, _op| {
    rule_if!(RotateHalf::is_supported_dt(source.node_input_facts(node.id)?[0].datum_type));
    Ok(Some(Box::new(tract_gpu::ops::rotate_half::GpuRotateHalf::new(
        "Cuda",
        cuda_rotate_half_dispatch,
    ))))
});

#[cfg(test)]
mod tests {

Expand Down
13 changes: 13 additions & 0 deletions cuda/src/kernels/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,19 @@ pub fn cuda_bin_op_dispatch(
crate::with_cuda_stream(|stream| dispatch_eval(stream, mini_op, lhs, rhs, output))
}

/// Wrap a core binary mini-op into its CUDA-backed GPU binary op.
///
/// The returned op carries the "Cuda" backend tag and routes evaluation
/// through `cuda_bin_op_dispatch`.
pub fn cuda_bin_op(mini_op: Box<dyn BinMiniOp>) -> tract_gpu::ops::binary::GpuBinOp {
    use tract_gpu::ops::binary::GpuBinOp;
    GpuBinOp { backend_name: "Cuda", mini_op, dispatch: cuda_bin_op_dispatch }
}

// Route core binary ops to CUDA when the (mini-op, input dtype) pair is
// supported; `rule_if!` declines otherwise.
crate::register_cuda_op!(tract_core::ops::binary::TypedBinOp, |source, node, op| {
    rule_if!(is_supported(&*op.0, source.node_input_facts(node.id)?[0].datum_type));
    Ok(Some(Box::new(cuda_bin_op(op.0.clone()))))
});

#[cfg(test)]
mod tests {
use tract_gpu::tensor::IntoDevice;
Expand Down
16 changes: 16 additions & 0 deletions cuda/src/kernels/element_wise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,22 @@ pub fn cuda_element_wise_dispatch(
crate::with_cuda_stream(|stream| dispatch_eval(stream, mini_op, input, output))
}

/// Wrap a core element-wise mini-op into its CUDA-backed GPU element-wise op.
pub fn cuda_element_wise_op(
    mini_op: Box<dyn ElementWiseMiniOp>,
) -> tract_gpu::ops::element_wise::GpuElementWise {
    use tract_gpu::ops::element_wise::GpuElementWise;
    GpuElementWise { backend_name: "Cuda", mini_op, dispatch: cuda_element_wise_dispatch }
}

// Generic element-wise fallback — checked after LeakyRelu, GeluApproximate.
// (Those two register under the same ElementWiseOp TypeId with narrower
// downcast rules; this entry catches every remaining supported
// mini-op/dtype combination.)
crate::register_cuda_op!(tract_core::ops::element_wise::ElementWiseOp, |source, node, op| {
    rule_if!(is_supported(&*op.0, source.node_input_facts(node.id)?[0].datum_type));
    Ok(Some(Box::new(cuda_element_wise_op(op.0.clone()))))
});

#[cfg(test)]
mod tests {
use super::*;
Expand Down
52 changes: 51 additions & 1 deletion cuda/src/kernels/iff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl Iff {
}

pub fn kernel_name(&self, dt: DatumType, variant: &str) -> TractResult<String> {
Ok(format!("iff_{variant}_{}", DeviceTensor::tname(dt)?))
Ok(format!("iff_{variant}_{}", tract_gpu::utils::BroadcastKind::copy_tname(dt)))
}

pub fn eval(
Expand Down Expand Up @@ -124,3 +124,53 @@ impl Iff {
Ok(())
}
}

/// Launch the generic CUDA `iff` (select) kernel:
/// `out[i] = cond[i] ? then[i] : else[i]`, with broadcasting expressed
/// through per-operand strides against a common `output_shape`.
pub fn cuda_iff_dispatch(
    cond: &DeviceTensor,
    then_value: &DeviceTensor,
    else_value: &DeviceTensor,
    cond_strides: &[isize],
    then_strides: &[isize],
    else_strides: &[isize],
    output: &DeviceTensor,
    output_shape: &[usize],
    output_strides: &[isize],
) -> TractResult<()> {
    crate::with_cuda_stream(|stream| {
        // One thread per output element, 128 threads per block.
        let total_elems: usize = output_shape.iter().product();
        let block_dim = (128_u32, 1, 1);
        let grid_dim = (total_elems.div_ceil(block_dim.0 as usize) as u32, 1, 1);
        // NOTE(review): an empty output gives grid_dim.0 == 0 — confirm the
        // launch path tolerates a zero-sized grid.

        // Kernel variant is keyed on the output datum type via copy_tname —
        // presumably a size-class name shared with the copy kernels; confirm.
        let kernel_name = format!(
            "iff_generic_{}",
            tract_gpu::utils::BroadcastKind::copy_tname(output.datum_type())
        );
        let func = cuda_context().load_pipeline(LibraryName::Binary, kernel_name)?;
        let cfg = LaunchConfig { grid_dim, block_dim, shared_mem_bytes: 0 };

        let cond_view = get_cuda_view(cond);
        let then_view = get_cuda_view(then_value);
        let else_view = get_cuda_view(else_value);
        let o_view = get_cuda_view(output);

        // Push order must match the iff_generic_* kernel signature:
        // cond, then, else, out buffers, then shape, then the four stride arrays.
        let mut launch_args = TractLaunchArgs::new(stream, &func);
        launch_args.push_view(&cond_view);
        launch_args.push_view(&then_view);
        launch_args.push_view(&else_view);
        launch_args.push_view(&o_view);
        launch_args.push_slice_i32(output_shape);
        launch_args.push_slice_i32(cond_strides);
        launch_args.push_slice_i32(then_strides);
        launch_args.push_slice_i32(else_strides);
        launch_args.push_slice_i32(output_strides);

        launch_args.launch(cfg)
    })
}

// Iff (select) is claimed unconditionally for CUDA — no dtype rule_if! gate.
// NOTE(review): this assumes an iff_generic_* kernel exists for every datum
// type that can reach this point — confirm.
crate::register_cuda_op!(tract_core::ops::logic::Iff, |_source, _node, _op| {
    Ok(Some(Box::new(tract_gpu::ops::iff::GpuIff {
        backend_name: "Cuda",
        dispatch: cuda_iff_dispatch,
    })))
});
17 changes: 17 additions & 0 deletions cuda/src/kernels/nn/apply_rope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,23 @@ impl ApplyRope {
}
}

/// Run the CUDA `ApplyRope` kernel: applies rotary positional embedding to
/// `input` using the precomputed `cos`/`sin` tables, writing into `output`.
pub fn cuda_apply_rope_dispatch(
    input: &DeviceTensor,
    cos: &DeviceTensor,
    sin: &DeviceTensor,
    output: &DeviceTensor,
) -> TractResult<()> {
    let kernel = ApplyRope;
    crate::with_cuda_stream(|stream| kernel.dispatch_eval(stream, input, cos, sin, output))
}

// Map the transformers `ApplyRope` op onto the backend-generic GPU wrapper,
// gated by `rule_if!` on the first input's dtype being kernel-supported.
crate::register_cuda_op!(tract_transformers::ops::apply_rope::ApplyRope, |source, node, _op| {
    rule_if!(ApplyRope::is_supported_dt(source.node_input_facts(node.id)?[0].datum_type));
    Ok(Some(Box::new(tract_gpu::ops::apply_rope::GpuApplyRope {
        backend_name: "Cuda",
        dispatch: cuda_apply_rope_dispatch,
    })))
});

#[cfg(test)]
mod tests {
use std::f32::consts::PI;
Expand Down
23 changes: 23 additions & 0 deletions cuda/src/kernels/nn/gelu_approximate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,29 @@ impl GeluApproximate {
}
}

/// Run the CUDA approximate-GELU kernel on `input`, writing into `output`.
///
/// `fast_impl` selects between the two GeluApproximate kernel variants.
pub fn cuda_gelu_approximate_dispatch(
    fast_impl: bool,
    input: &DeviceTensor,
    output: &DeviceTensor,
) -> TractResult<()> {
    let op = GeluApproximate { fast_impl };
    crate::with_cuda_stream(|stream| op.dispatch_eval(stream, input, output))
}

// GeluApproximate is an ElementWiseMiniOp, so we register under ElementWiseOp's TypeId.
// rule_if_some! declines unless the wrapped mini-op downcasts to
// GeluApproximate; rule_if! then gates on kernel dtype support.
crate::register_cuda_op!(tract_core::ops::element_wise::ElementWiseOp, |source, node, op| {
    rule_if_some!(
        ew = op.0.downcast_ref::<tract_transformers::ops::gelu_approximate::GeluApproximate>()
    );
    rule_if!(GeluApproximate::is_supported_dt(source.node_input_facts(node.id)?[0].datum_type));
    Ok(Some(Box::new(tract_gpu::ops::gelu_approximate::GpuGeluApproximate::new(
        ew.fast_impl,
        "Cuda",
        cuda_gelu_approximate_dispatch,
    ))))
});

#[cfg(test)]
mod tests {

Expand Down
18 changes: 18 additions & 0 deletions cuda/src/kernels/nn/leaky_relu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,21 @@ impl LeakyRelu {
Ok(())
}
}

/// Run the CUDA LeakyRelu kernel with slope `alpha` on `input`, into `output`.
pub fn cuda_leaky_relu_dispatch(
    alpha: f32,
    input: &DeviceTensor,
    output: &DeviceTensor,
) -> TractResult<()> {
    let kernel = LeakyRelu;
    crate::with_cuda_stream(|stream| kernel.dispatch_eval(stream, input, alpha, output))
}

// LeakyRelu is an ElementWiseMiniOp, so we register under ElementWiseOp's TypeId.
// rule_if_some! declines unless the mini-op downcasts to core LeakyRelu.
// NOTE(review): unlike GeluApproximate there is no dtype rule_if! gate here —
// confirm the LeakyRelu kernel supports every dtype that can reach it.
crate::register_cuda_op!(tract_core::ops::element_wise::ElementWiseOp, |_source, _node, op| {
    rule_if_some!(leaky = op.0.downcast_ref::<tract_core::ops::nn::LeakyRelu>());
    Ok(Some(Box::new(tract_gpu::ops::leaky_relu::GpuLeakyRelu::new(
        leaky.alpha,
        "Cuda",
        cuda_leaky_relu_dispatch,
    ))))
});
8 changes: 6 additions & 2 deletions cuda/src/kernels/nn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@ mod rms_norm;
mod scaled_masked_softmax;
mod softmax;

pub use apply_rope::ApplyRope;
pub use apply_rope::{ApplyRope, cuda_apply_rope_dispatch};
pub use gelu_approximate::GeluApproximate;
pub use gelu_approximate::cuda_gelu_approximate_dispatch;
pub use leaky_relu::LeakyRelu;
pub use leaky_relu::cuda_leaky_relu_dispatch;
pub use reduce::{Reducer, cuda_reduce_launch};
pub use rms_norm::RmsNorm;
pub use scaled_masked_softmax::ScaledMaskedSoftmax;
pub use rms_norm::cuda_rms_norm_dispatch;
pub use scaled_masked_softmax::{ScaledMaskedSoftmax, cuda_scaled_masked_softmax_dispatch};
pub use softmax::Softmax;
pub use softmax::cuda_softmax_dispatch;

use crate::kernels::{BroadcastKind, MAX_THREADS};

Expand Down
12 changes: 12 additions & 0 deletions cuda/src/kernels/nn/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,18 @@ pub fn cuda_reduce_launch(
})
}

// Convert core Reduce nodes to the GPU reduce op when construction succeeds
// AND the chosen reducer supports the input dtype; otherwise decline (Ok(None)).
// NOTE(review): a from_tract_core Err is swallowed and treated as "decline",
// not propagated — presumably an intentional fallback to the default backend;
// confirm.
crate::register_cuda_op!(tract_core::ops::nn::Reduce, |source, node, op| {
    let dt = source.node_input_facts(node.id)?[0].datum_type;
    if let Ok(gpu_op) =
        tract_gpu::ops::reduce::GpuReduce::from_tract_core(op, "Cuda", cuda_reduce_launch)
    {
        if gpu_op.reducer.is_supported_dt(dt) {
            return Ok(Some(Box::new(gpu_op)));
        }
    }
    Ok(None)
});

#[cfg(test)]
mod tests {

Expand Down
19 changes: 19 additions & 0 deletions cuda/src/kernels/nn/rms_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,25 @@ impl RmsNorm {
}
}

/// Run the CUDA RmsNorm kernel: normalizes `input` along `axis` with epsilon
/// `eps`, writing into `output`.
pub fn cuda_rms_norm_dispatch(
    input: &DeviceTensor,
    axis: usize,
    eps: &Tensor,
    output: &DeviceTensor,
) -> TractResult<()> {
    let kernel = RmsNorm;
    crate::with_cuda_stream(|stream| kernel.dispatch_eval(stream, input, axis, eps, output))
}

// Map the transformers RmsNorm op onto the backend-generic GPU wrapper,
// carrying over the node's axis and epsilon; gated by rule_if! on dtype support.
crate::register_cuda_op!(tract_transformers::ops::rms_norm::RmsNorm, |source, node, op| {
    rule_if!(RmsNorm::is_supported_dt(source.node_input_facts(node.id)?[0].datum_type));
    Ok(Some(Box::new(tract_gpu::ops::rms_norm::GpuRmsNorm::new(
        op.axis,
        op.eps.clone(),
        "Cuda",
        cuda_rms_norm_dispatch,
    ))))
});

#[cfg(test)]
mod tests {
use tract_gpu::tensor::IntoDevice;
Expand Down
Loading
Loading