diff --git a/.github/workflows/crates.yml b/.github/workflows/crates.yml index 98645c0ef1..ff7c35fca4 100644 --- a/.github/workflows/crates.yml +++ b/.github/workflows/crates.yml @@ -25,11 +25,11 @@ jobs: if [ "$FULL" == "true" ] then echo 'os=["ubuntu-latest", "macos-latest"]' >> $GITHUB_OUTPUT - echo 'rust=["1.89.0", "stable", "beta", "nightly"]' >> $GITHUB_OUTPUT + echo 'rust=["1.91.0", "stable", "beta", "nightly"]' >> $GITHUB_OUTPUT else echo ::notice::Skipping macOS checks on PR and commit. Dispatch workflow manually if needed. echo 'os=["ubuntu-latest"]' >> $GITHUB_OUTPUT - echo 'rust=["1.89.0"]' >> $GITHUB_OUTPUT + echo 'rust=["1.91.0"]' >> $GITHUB_OUTPUT fi crates: diff --git a/.github/workflows/cross-platform.yml b/.github/workflows/cross-platform.yml index 1002fc2b37..00e27f0f1f 100644 --- a/.github/workflows/cross-platform.yml +++ b/.github/workflows/cross-platform.yml @@ -9,7 +9,7 @@ on: env: CARGO_INCREMENTAL: false FORCE_JAVASCRIPT_ACTIONS_TO_NODE20: true - RUSTUP_TOOLCHAIN: 1.89.0 + RUSTUP_TOOLCHAIN: 1.91.0 jobs: linux: diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index e432ea4a90..f03a302173 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -8,7 +8,7 @@ on: env: CARGO_INCREMENTAL: false FORCE_JAVASCRIPT_ACTIONS_TO_NODE20: true - RUSTUP_TOOLCHAIN: 1.89.0 + RUSTUP_TOOLCHAIN: 1.91.0 jobs: examples: diff --git a/.travis/ci-system-setup.sh b/.travis/ci-system-setup.sh index fdac566ef2..0aa6d2cf92 100755 --- a/.travis/ci-system-setup.sh +++ b/.travis/ci-system-setup.sh @@ -5,7 +5,7 @@ set -e if [ -z "$RUSTUP_TOOLCHAIN" ] then - export RUSTUP_TOOLCHAIN=1.89.0 + export RUSTUP_TOOLCHAIN=1.91.0 fi export RUSTUP_TOOLCHAIN diff --git a/.travis/native.sh b/.travis/native.sh index eedf584667..56748b17c5 100755 --- a/.travis/native.sh +++ b/.travis/native.sh @@ -4,7 +4,7 @@ set -ex if [ -z "$RUSTUP_TOOLCHAIN" ] then - export RUSTUP_TOOLCHAIN=1.89.0 + export RUSTUP_TOOLCHAIN=1.91.0 fi rustup update diff --git a/CHANGELOG.md b/CHANGELOG.md index bc421c0667..39b3190a29 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# Unreleased + +* [Breaking][MSRV] MSRV bumped to 1.91.0 (for `const TypeId::of`). + # 0.23.0-dev.3 — 2026-03-20 ### Breaking changes diff --git a/Cargo.toml b/Cargo.toml index 95e04ee9b5..f059e36639 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -117,7 +117,7 @@ default-members = [ ] [workspace.package] -rust-version = "1.89" +rust-version = "1.91" [workspace.dependencies] accelerate-src = "0.3" diff --git a/README.md b/README.md index ef8ebfa940..a281af93e2 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ ![tract-logo](assets/tract-logo/PNG/tract-horizontal-blue.png) ![Rust](https://img.shields.io/badge/rust-%23000000.svg?style=for-the-badge&logo=rust&logoColor=white) -![rustc >= 1.89.0](https://img.shields.io/badge/rustc-%3E%3D1.89.0-brightgreen) +![rustc >= 1.91.0](https://img.shields.io/badge/rustc-%3E%3D1.91.0-brightgreen) ![MIT/Apache 2](https://img.shields.io/crates/l/tract) [![Native Linux test status](https://github.com/snipsco/tract/workflows/Native%20Linux/badge.svg)](https://github.com/snipsco/tract/actions) [![Embedded targets status](https://github.com/snipsco/tract/workflows/Embedded%20targets/badge.svg)](https://github.com/snipsco/tract/actions) diff --git a/cuda/src/context.rs b/cuda/src/context.rs index 4961c7b167..24091aa403 100644 --- a/cuda/src/context.rs +++ b/cuda/src/context.rs @@ -294,6 +294,27 @@ impl DeviceContext for TractCudaContext { ) -> TractResult> { Ok(Box::new(CudaTensor::uninitialized_exotic(exotic_fact)?)) } + + fn copy_nd( + &self, + input: &DeviceTensor, + input_offset: usize, + input_strides: &[isize], + output: &DeviceTensor, + output_offset: usize, + output_shape: &[usize], + output_strides: &[isize], + ) -> TractResult<()> { + crate::kernels::array::cuda_copy_nd_dispatch( + input, + input_offset, + input_strides, + output, + output_offset, + output_shape, + output_strides, + ) + } } /// A recorded GPU kernel timing entry: start/end events tagged with a node_id. diff --git a/cuda/src/kernels/array/cast.rs b/cuda/src/kernels/array/cast.rs index 3ec570e1e9..93a1110aaa 100644 --- a/cuda/src/kernels/array/cast.rs +++ b/cuda/src/kernels/array/cast.rs @@ -89,6 +89,14 @@ impl Cast { } } +pub fn cuda_cast_dispatch(input: &DeviceTensor, output: &DeviceTensor) -> TractResult<()> { + crate::with_cuda_stream(|stream| Cast.dispatch_eval(stream, input, output)) +} + +crate::register_cuda_op!(tract_core::ops::cast::Cast, |_source, _node, op| { + Ok(crate::transform::cuda_cast_new(op.to).map(|c| Box::new(c) as _)) +}); + #[cfg(test)] mod tests { diff --git a/cuda/src/kernels/array/mod.rs b/cuda/src/kernels/array/mod.rs index 2f52ee31a5..309b21d06f 100644 --- a/cuda/src/kernels/array/mod.rs +++ b/cuda/src/kernels/array/mod.rs @@ -4,9 +4,11 @@ mod dispatch; mod rotate_half; pub use cast::Cast; +pub use cast::cuda_cast_dispatch; pub use copy::Memcpy; pub use dispatch::cuda_copy_nd_dispatch; pub use rotate_half::RotateHalf; +pub use rotate_half::cuda_rotate_half_dispatch; pub fn all_functions() -> Vec { use std::collections::HashSet; diff --git a/cuda/src/kernels/array/rotate_half.rs b/cuda/src/kernels/array/rotate_half.rs index e297e57bfd..62cf3e5322 100644 --- a/cuda/src/kernels/array/rotate_half.rs +++ b/cuda/src/kernels/array/rotate_half.rs @@ -82,6 +82,18 @@ impl RotateHalf { } } +pub fn cuda_rotate_half_dispatch(input: &DeviceTensor, output: &DeviceTensor) -> TractResult<()> { + crate::with_cuda_stream(|stream| RotateHalf.dispatch_eval(stream, input, output)) +} + +crate::register_cuda_op!(tract_transformers::ops::apply_rope::RotateHalf, |source, node, _op| { + rule_if!(RotateHalf::is_supported_dt(source.node_input_facts(node.id)?[0].datum_type)); + Ok(Some(Box::new(tract_gpu::ops::rotate_half::GpuRotateHalf::new( + "Cuda", + cuda_rotate_half_dispatch, + )))) +}); + #[cfg(test)] mod tests { diff --git a/cuda/src/kernels/binary.rs b/cuda/src/kernels/binary.rs index 14f7f0a4a1..8027ccf449 100644 --- a/cuda/src/kernels/binary.rs +++ b/cuda/src/kernels/binary.rs @@ -123,6 +123,19 @@ pub fn cuda_bin_op_dispatch( crate::with_cuda_stream(|stream| dispatch_eval(stream, mini_op, lhs, rhs, output)) } +pub fn cuda_bin_op(mini_op: Box) -> tract_gpu::ops::binary::GpuBinOp { + tract_gpu::ops::binary::GpuBinOp { + backend_name: "Cuda", + mini_op, + dispatch: cuda_bin_op_dispatch, + } +} + +crate::register_cuda_op!(tract_core::ops::binary::TypedBinOp, |source, node, op| { + rule_if!(is_supported(&*op.0, source.node_input_facts(node.id)?[0].datum_type)); + Ok(Some(Box::new(cuda_bin_op(op.0.clone())))) +}); + #[cfg(test)] mod tests { use tract_gpu::tensor::IntoDevice; diff --git a/cuda/src/kernels/element_wise.rs b/cuda/src/kernels/element_wise.rs index 0f464291ee..c3c4f3b25e 100644 --- a/cuda/src/kernels/element_wise.rs +++ b/cuda/src/kernels/element_wise.rs @@ -99,6 +99,22 @@ pub fn cuda_element_wise_dispatch( crate::with_cuda_stream(|stream| dispatch_eval(stream, mini_op, input, output)) } +pub fn cuda_element_wise_op( + mini_op: Box, +) -> tract_gpu::ops::element_wise::GpuElementWise { + tract_gpu::ops::element_wise::GpuElementWise { + backend_name: "Cuda", + mini_op, + dispatch: cuda_element_wise_dispatch, + } +} + +// Generic element-wise fallback — checked after LeakyRelu, GeluApproximate. +crate::register_cuda_op!(tract_core::ops::element_wise::ElementWiseOp, |source, node, op| { + rule_if!(is_supported(&*op.0, source.node_input_facts(node.id)?[0].datum_type)); + Ok(Some(Box::new(cuda_element_wise_op(op.0.clone())))) +}); + #[cfg(test)] mod tests { use super::*; diff --git a/cuda/src/kernels/iff.rs b/cuda/src/kernels/iff.rs index a04d3b845d..2a347df108 100644 --- a/cuda/src/kernels/iff.rs +++ b/cuda/src/kernels/iff.rs @@ -30,7 +30,7 @@ impl Iff { } pub fn kernel_name(&self, dt: DatumType, variant: &str) -> TractResult { - Ok(format!("iff_{variant}_{}", DeviceTensor::tname(dt)?)) + Ok(format!("iff_{variant}_{}", tract_gpu::utils::BroadcastKind::copy_tname(dt))) } pub fn eval( @@ -124,3 +124,53 @@ impl Iff { Ok(()) } } + +pub fn cuda_iff_dispatch( + cond: &DeviceTensor, + then_value: &DeviceTensor, + else_value: &DeviceTensor, + cond_strides: &[isize], + then_strides: &[isize], + else_strides: &[isize], + output: &DeviceTensor, + output_shape: &[usize], + output_strides: &[isize], +) -> TractResult<()> { + crate::with_cuda_stream(|stream| { + let total_elems: usize = output_shape.iter().product(); + let block_dim = (128_u32, 1, 1); + let grid_dim = (total_elems.div_ceil(block_dim.0 as usize) as u32, 1, 1); + + let kernel_name = format!( + "iff_generic_{}", + tract_gpu::utils::BroadcastKind::copy_tname(output.datum_type()) + ); + let func = cuda_context().load_pipeline(LibraryName::Binary, kernel_name)?; + let cfg = LaunchConfig { grid_dim, block_dim, shared_mem_bytes: 0 }; + + let cond_view = get_cuda_view(cond); + let then_view = get_cuda_view(then_value); + let else_view = get_cuda_view(else_value); + let o_view = get_cuda_view(output); + + let mut launch_args = TractLaunchArgs::new(stream, &func); + launch_args.push_view(&cond_view); + launch_args.push_view(&then_view); + launch_args.push_view(&else_view); + launch_args.push_view(&o_view); + launch_args.push_slice_i32(output_shape); + launch_args.push_slice_i32(cond_strides); + launch_args.push_slice_i32(then_strides); + launch_args.push_slice_i32(else_strides); + launch_args.push_slice_i32(output_strides); + + launch_args.launch(cfg) + }) +} + +crate::register_cuda_op!(tract_core::ops::logic::Iff, |_source, _node, _op| { + Ok(Some(Box::new(tract_gpu::ops::iff::GpuIff { + backend_name: "Cuda", + dispatch: cuda_iff_dispatch, + }))) +}); diff --git a/cuda/src/kernels/nn/apply_rope.rs b/cuda/src/kernels/nn/apply_rope.rs index bf68119868..e49c1ea8e3 100644 --- a/cuda/src/kernels/nn/apply_rope.rs +++ b/cuda/src/kernels/nn/apply_rope.rs @@ -130,6 +130,23 @@ impl ApplyRope { } } +pub fn cuda_apply_rope_dispatch( + input: &DeviceTensor, + cos: &DeviceTensor, + sin: &DeviceTensor, + output: &DeviceTensor, +) -> TractResult<()> { + crate::with_cuda_stream(|stream| ApplyRope.dispatch_eval(stream, input, cos, sin, output)) +} + +crate::register_cuda_op!(tract_transformers::ops::apply_rope::ApplyRope, |source, node, _op| { + rule_if!(ApplyRope::is_supported_dt(source.node_input_facts(node.id)?[0].datum_type)); + Ok(Some(Box::new(tract_gpu::ops::apply_rope::GpuApplyRope { + backend_name: "Cuda", + dispatch: cuda_apply_rope_dispatch, + }))) +}); + #[cfg(test)] mod tests { use std::f32::consts::PI; diff --git a/cuda/src/kernels/nn/gelu_approximate.rs b/cuda/src/kernels/nn/gelu_approximate.rs index 043de3f2fe..cbeea04c96 100644 --- a/cuda/src/kernels/nn/gelu_approximate.rs +++ b/cuda/src/kernels/nn/gelu_approximate.rs @@ -74,6 +74,29 @@ impl GeluApproximate { } } +pub fn cuda_gelu_approximate_dispatch( + fast_impl: bool, + input: &DeviceTensor, + output: &DeviceTensor, +) -> TractResult<()> { + crate::with_cuda_stream(|stream| { + GeluApproximate { fast_impl }.dispatch_eval(stream, input, output) + }) +} + +// GeluApproximate is an ElementWiseMiniOp, so we register under ElementWiseOp's TypeId. +crate::register_cuda_op!(tract_core::ops::element_wise::ElementWiseOp, |source, node, op| { + rule_if_some!( + ew = op.0.downcast_ref::() + ); + rule_if!(GeluApproximate::is_supported_dt(source.node_input_facts(node.id)?[0].datum_type)); + Ok(Some(Box::new(tract_gpu::ops::gelu_approximate::GpuGeluApproximate::new( + ew.fast_impl, + "Cuda", + cuda_gelu_approximate_dispatch, + )))) +}); + #[cfg(test)] mod tests { diff --git a/cuda/src/kernels/nn/leaky_relu.rs b/cuda/src/kernels/nn/leaky_relu.rs index bb26b2c282..ce6447d8cf 100644 --- a/cuda/src/kernels/nn/leaky_relu.rs +++ b/cuda/src/kernels/nn/leaky_relu.rs @@ -62,3 +62,21 @@ impl LeakyRelu { Ok(()) } } + +pub fn cuda_leaky_relu_dispatch( + alpha: f32, + input: &DeviceTensor, + output: &DeviceTensor, +) -> TractResult<()> { + crate::with_cuda_stream(|stream| LeakyRelu.dispatch_eval(stream, input, alpha, output)) +} + +// LeakyRelu is an ElementWiseMiniOp, so we register under ElementWiseOp's TypeId. +crate::register_cuda_op!(tract_core::ops::element_wise::ElementWiseOp, |_source, _node, op| { + rule_if_some!(leaky = op.0.downcast_ref::()); + Ok(Some(Box::new(tract_gpu::ops::leaky_relu::GpuLeakyRelu::new( + leaky.alpha, + "Cuda", + cuda_leaky_relu_dispatch, + )))) +}); diff --git a/cuda/src/kernels/nn/mod.rs b/cuda/src/kernels/nn/mod.rs index b6923bfe9b..668fad0763 100644 --- a/cuda/src/kernels/nn/mod.rs +++ b/cuda/src/kernels/nn/mod.rs @@ -6,13 +6,17 @@ mod rms_norm; mod scaled_masked_softmax; mod softmax; -pub use apply_rope::ApplyRope; +pub use apply_rope::{ApplyRope, cuda_apply_rope_dispatch}; pub use gelu_approximate::GeluApproximate; +pub use gelu_approximate::cuda_gelu_approximate_dispatch; pub use leaky_relu::LeakyRelu; +pub use leaky_relu::cuda_leaky_relu_dispatch; pub use reduce::{Reducer, cuda_reduce_launch}; pub use rms_norm::RmsNorm; -pub use scaled_masked_softmax::ScaledMaskedSoftmax; +pub use rms_norm::cuda_rms_norm_dispatch; +pub use scaled_masked_softmax::{ScaledMaskedSoftmax, cuda_scaled_masked_softmax_dispatch}; pub use softmax::Softmax; +pub use softmax::cuda_softmax_dispatch; use crate::kernels::{BroadcastKind, MAX_THREADS}; diff --git a/cuda/src/kernels/nn/reduce.rs b/cuda/src/kernels/nn/reduce.rs index 5e27294d2c..2adea959b3 100644 --- a/cuda/src/kernels/nn/reduce.rs +++ b/cuda/src/kernels/nn/reduce.rs @@ -62,6 +62,18 @@ pub fn cuda_reduce_launch( }) } +crate::register_cuda_op!(tract_core::ops::nn::Reduce, |source, node, op| { + let dt = source.node_input_facts(node.id)?[0].datum_type; + if let Ok(gpu_op) = + tract_gpu::ops::reduce::GpuReduce::from_tract_core(op, "Cuda", cuda_reduce_launch) + { + if gpu_op.reducer.is_supported_dt(dt) { + return Ok(Some(Box::new(gpu_op))); + } + } + Ok(None) +}); + #[cfg(test)] mod tests { diff --git a/cuda/src/kernels/nn/rms_norm.rs b/cuda/src/kernels/nn/rms_norm.rs index 5178f8b0a9..1689cb05c2 100644 --- a/cuda/src/kernels/nn/rms_norm.rs +++ b/cuda/src/kernels/nn/rms_norm.rs @@ -77,6 +77,25 @@ impl RmsNorm { } } +pub fn cuda_rms_norm_dispatch( + input: &DeviceTensor, + axis: usize, + eps: &Tensor, + output: &DeviceTensor, +) -> TractResult<()> { + crate::with_cuda_stream(|stream| RmsNorm.dispatch_eval(stream, input, axis, eps, output)) +} + +crate::register_cuda_op!(tract_transformers::ops::rms_norm::RmsNorm, |source, node, op| { + rule_if!(RmsNorm::is_supported_dt(source.node_input_facts(node.id)?[0].datum_type)); + Ok(Some(Box::new(tract_gpu::ops::rms_norm::GpuRmsNorm::new( + op.axis, + op.eps.clone(), + "Cuda", + cuda_rms_norm_dispatch, + )))) +}); + #[cfg(test)] mod tests { use tract_gpu::tensor::IntoDevice; diff --git a/cuda/src/kernels/nn/scaled_masked_softmax.rs b/cuda/src/kernels/nn/scaled_masked_softmax.rs index dc29cefbfb..301de5522a 100644 --- a/cuda/src/kernels/nn/scaled_masked_softmax.rs +++ b/cuda/src/kernels/nn/scaled_masked_softmax.rs @@ -107,6 +107,32 @@ fn pad(vals: &[impl AsPrimitive], neutral: i32) -> [i32; 5] { it } +pub fn cuda_scaled_masked_softmax_dispatch( + input: &DeviceTensor, + scale: &Tensor, + mask: &DeviceTensor, + output: &DeviceTensor, +) -> TractResult<()> { + crate::with_cuda_stream(|stream| { + ScaledMaskedSoftmax.dispatch_eval(stream, input, scale, mask, output) + }) +} + +crate::register_cuda_op!( + tract_transformers::ops::scaled_masked_softmax::ScaledMaskedSoftmax, + |source, node, op| { + rule_if!(!op.post_softmax_mask); + rule_if!(ScaledMaskedSoftmax::is_supported_dt( + source.node_input_facts(node.id)?[0].datum_type + )); + Ok(Some(Box::new(tract_gpu::ops::scaled_masked_softmax::GpuScaledMaskedSoftmax { + scale: op.scale.clone(), + backend_name: "Cuda", + dispatch: cuda_scaled_masked_softmax_dispatch, + }))) + } +); + #[cfg(test)] mod tests { use tract_gpu::tensor::IntoDevice; diff --git a/cuda/src/kernels/nn/softmax.rs b/cuda/src/kernels/nn/softmax.rs index 355f37ce17..aa10025e22 100644 --- a/cuda/src/kernels/nn/softmax.rs +++ b/cuda/src/kernels/nn/softmax.rs @@ -73,6 +73,23 @@ impl Softmax { } } +pub fn cuda_softmax_dispatch( + input: &DeviceTensor, + axis: usize, + output: &DeviceTensor, +) -> TractResult<()> { + crate::with_cuda_stream(|stream| Softmax.dispatch_eval(stream, input, axis, output)) +} + +crate::register_cuda_op!(tract_core::ops::nn::Softmax, |source, node, op| { + rule_if!(Softmax::is_supported_dt(source.node_input_facts(node.id)?[0].datum_type)); + Ok(Some(Box::new(tract_gpu::ops::softmax::GpuSoftmax::from_tract_core( + op, + "Cuda", + cuda_softmax_dispatch, + )?))) +}); + #[cfg(test)] mod tests { diff --git a/cuda/src/ops/apply_rope.rs b/cuda/src/ops/apply_rope.rs deleted file mode 100644 index c83ffad7a0..0000000000 --- a/cuda/src/ops/apply_rope.rs +++ /dev/null @@ -1,57 +0,0 @@ -use crate::kernels::nn::ApplyRope; -use derive_new::new; -use tract_core::internal::*; -use tract_gpu::tensor::DeviceTensorExt; - -#[derive(Clone, Debug, new, Hash, PartialEq, Eq)] -pub struct CudaApplyRope; - -impl Op for CudaApplyRope { - fn name(&self) -> StaticName { - "CudaApplyRope".into() - } - - op_as_typed_op!(); -} - -impl EvalOp for CudaApplyRope { - fn is_stateless(&self) -> bool { - true - } - - fn eval_with_session( - &self, - node_id: usize, - session: &TurnState, - inputs: TVec, - ) -> TractResult> { - let (input_val, cos_val, sin_val) = args_3!(inputs); - let input = input_val.to_device_tensor()?; - let cos = cos_val.to_device_tensor()?; - let sin = sin_val.to_device_tensor()?; - let output = tract_gpu::session_handler::make_tensor_for_node( - session, - node_id, - input.datum_type(), - input.shape(), - )?; - - crate::with_cuda_stream(|stream| { - ApplyRope.dispatch_eval(stream, input, cos, sin, &output) - })?; - Ok(tvec!(output.into_tensor().into_tvalue())) - } -} - -impl TypedOp for CudaApplyRope { - fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - tract_gpu::utils::facts_to_device_facts(inputs, |facts| { - let dt = facts[0].datum_type; - let fact = dt.fact(facts[0].shape.clone()); - Ok(tvec!(fact)) - }) - .with_context(|| format!("Error while computing facts for {:?}", self.name())) - } - - as_op!(); -} diff --git a/cuda/src/ops/cast.rs b/cuda/src/ops/cast.rs deleted file mode 100644 index e9bbb4120f..0000000000 --- a/cuda/src/ops/cast.rs +++ /dev/null @@ -1,67 +0,0 @@ -use crate::kernels; -use tract_core::internal::*; -use tract_gpu::tensor::DeviceTensorExt; - -#[derive(Debug, Clone, Hash, PartialEq, Eq)] -pub struct CudaCast { - pub to: DatumType, -} - -impl CudaCast { - pub fn is_supported_dt(dt: DatumType) -> bool { - kernels::array::Cast::is_supported_dt(dt) - } - - pub fn new(to: DatumType) -> Option { - Self::is_supported_dt(to).then_some(Self { to }) - } -} - -impl Op for CudaCast { - fn name(&self) -> StaticName { - "CudaCast".into() - } - - op_as_typed_op!(); -} - -impl EvalOp for CudaCast { - fn is_stateless(&self) -> bool { - true - } - - fn eval_with_session( - &self, - node_id: usize, - session: &TurnState, - inputs: TVec, - ) -> TractResult> { - let input_value = args_1!(inputs); - let input = input_value.to_device_tensor()?; - if input.datum_type() == self.to { - Ok(tvec!(input_value)) - } else { - let output = tract_gpu::session_handler::make_tensor_for_node( - session, - node_id, - self.to, - input.shape(), - )?; - crate::with_cuda_stream(|stream| { - kernels::array::Cast.dispatch_eval(stream, input, &output) - })?; - Ok(tvec![output.into_tensor().into_tvalue()]) - } - } -} - -impl TypedOp for CudaCast { - fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - tract_gpu::utils::facts_to_device_facts(inputs, |facts| { - Ok(tvec!(self.to.fact(facts[0].shape.clone()))) - }) - .with_context(|| format!("Error while computing facts for {:?}", self.name())) - } - - as_op!(); -} diff --git a/cuda/src/ops/conv.rs b/cuda/src/ops/conv.rs index 946a41a704..2bf9df7791 100644 --- a/cuda/src/ops/conv.rs +++ b/cuda/src/ops/conv.rs @@ -39,16 +39,12 @@ pub fn wire_cuda_conv( needed_shape[data_shape.c_axis()] = op.pool_spec.output_channels.to_dim(); let reshaped = target.wire_node( format!("{prefix}.bias_reshaped"), - GpuAxisOp::new( - AxisOp::Reshape(0, bias.shape.to_tvec(), needed_shape), - "Cuda", - crate::kernels::array::cuda_copy_nd_dispatch, - ), + GpuAxisOp::new(AxisOp::Reshape(0, bias.shape.to_tvec(), needed_shape)), &[inputs[2]], )?[0]; conv_wire = target.wire_node( prefix, - crate::transform::cuda_bin_op(Box::new(tract_core::ops::math::Add)), + crate::kernels::binary::cuda_bin_op(Box::new(tract_core::ops::math::Add)), &[conv_wire, reshaped], )?[0]; } diff --git a/cuda/src/ops/gelu_approximate.rs b/cuda/src/ops/gelu_approximate.rs deleted file mode 100644 index 837b234bce..0000000000 --- a/cuda/src/ops/gelu_approximate.rs +++ /dev/null @@ -1,56 +0,0 @@ -use crate::kernels::nn::GeluApproximate; -use tract_core::internal::*; -use tract_gpu::tensor::DeviceTensorExt; - -#[derive(Clone, Debug, Default, Hash, PartialEq, Eq)] -pub struct CudaGeluApproximate { - pub fast_impl: bool, -} - -impl Op for CudaGeluApproximate { - fn name(&self) -> StaticName { - "CudaGeluApproximate".into() - } - - op_as_typed_op!(); -} - -impl EvalOp for CudaGeluApproximate { - fn is_stateless(&self) -> bool { - true - } - - fn eval_with_session( - &self, - node_id: usize, - session: &TurnState, - inputs: TVec, - ) -> TractResult> { - crate::with_cuda_stream(|stream| { - let input = args_1!(inputs); - let input_cuda = input.to_device_tensor()?; - let output = tract_gpu::session_handler::make_tensor_for_node( - session, - node_id, - input_cuda.datum_type(), - input_cuda.shape(), - )?; - GeluApproximate { fast_impl: self.fast_impl } - .dispatch_eval(stream, input_cuda, &output)?; - Ok(tvec!(output.into_tensor().into_tvalue())) - }) - } -} - -impl TypedOp for CudaGeluApproximate { - fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - tract_gpu::utils::facts_to_device_facts(inputs, |facts| { - let dt = facts[0].datum_type; - let fact = dt.fact(facts[0].shape.clone()); - Ok(tvec!(fact)) - }) - .with_context(|| format!("Error while computing facts for {:?}", self.name())) - } - - as_op!(); -} diff --git a/cuda/src/ops/leaky_relu.rs b/cuda/src/ops/leaky_relu.rs deleted file mode 100644 index bcf9a9889b..0000000000 --- a/cuda/src/ops/leaky_relu.rs +++ /dev/null @@ -1,57 +0,0 @@ -use tract_core::internal::*; -use tract_gpu::tensor::DeviceTensorExt; - -use crate::kernels::nn::LeakyRelu; - -#[derive(Debug, Clone, Default, PartialEq)] -pub struct CudaLeakyRelu { - pub alpha: f32, -} -impl Eq for CudaLeakyRelu {} - -impl Op for CudaLeakyRelu { - fn name(&self) -> StaticName { - "CudaLeakyRelu".into() - } - - op_as_typed_op!(); -} - -impl EvalOp for CudaLeakyRelu { - fn is_stateless(&self) -> bool { - true - } - - fn eval_with_session( - &self, - node_id: usize, - session: &TurnState, - inputs: TVec, - ) -> TractResult> { - crate::with_cuda_stream(|stream| { - let input = args_1!(inputs); - let input_cuda = input.to_device_tensor()?; - let output = tract_gpu::session_handler::make_tensor_for_node( - session, - node_id, - input_cuda.datum_type(), - input_cuda.shape(), - )?; - LeakyRelu.dispatch_eval(stream, input_cuda, self.alpha, &output)?; - Ok(tvec!(output.into_tensor().into_tvalue())) - }) - } -} - -impl TypedOp for CudaLeakyRelu { - fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - tract_gpu::utils::facts_to_device_facts(inputs, |facts| { - let dt = facts[0].datum_type; - let fact = dt.fact(facts[0].shape.clone()); - Ok(tvec!(fact)) - }) - .with_context(|| format!("Error while computing facts for {:?}", self.name())) - } - - as_op!(); -} diff --git a/cuda/src/ops/mod.rs b/cuda/src/ops/mod.rs index 5d2569d7d6..007469cba4 100644 --- a/cuda/src/ops/mod.rs +++ b/cuda/src/ops/mod.rs @@ -1,28 +1,12 @@ -mod apply_rope; -mod cast; mod conv; mod flash_attn; mod fused_axis_op; -mod gelu_approximate; mod gemm; mod iff; -mod leaky_relu; mod quant_q81; -mod rms_norm; -mod rotate_half; -mod scaled_masked_softmax; -mod softmax; -pub use apply_rope::CudaApplyRope; -pub use cast::CudaCast; pub use conv::{CudaConv, wire_cuda_conv}; pub use flash_attn::CudaFlashAttention; pub use fused_axis_op::CudaFusedAxisOp; -pub use gelu_approximate::CudaGeluApproximate; pub use gemm::CudaGgmlGemm; pub use iff::CudaIff; -pub use leaky_relu::CudaLeakyRelu; pub use quant_q81::{CudaGgmlQuantQ81, GgmlQuantQ81Fact}; -pub use rms_norm::CudaRmsNorm; -pub use rotate_half::CudaRotateHalf; -pub use scaled_masked_softmax::CudaScaledMaskedSoftmax; -pub use softmax::CudaSoftmax; diff --git a/cuda/src/ops/rms_norm.rs b/cuda/src/ops/rms_norm.rs deleted file mode 100644 index 1911043ca2..0000000000 --- a/cuda/src/ops/rms_norm.rs +++ /dev/null @@ -1,60 +0,0 @@ -use crate::kernels::nn::RmsNorm; -use derive_new::new; -use std::sync::Arc; -use tract_core::internal::*; -use tract_gpu::tensor::DeviceTensorExt; - -#[derive(Clone, Debug, new, Hash, PartialEq, Eq)] -pub struct CudaRmsNorm { - pub axis: usize, - pub eps: Arc, -} - -impl Op for CudaRmsNorm { - fn name(&self) -> StaticName { - "CudaRmsNorm".into() - } - fn info(&self) -> TractResult> { - Ok(vec![format!("axis: {:?}, eps: {:?}", self.axis, self.eps)]) - } - op_as_typed_op!(); -} - -impl EvalOp for CudaRmsNorm { - fn is_stateless(&self) -> bool { - true - } - - fn eval_with_session( - &self, - node_id: usize, - session: &TurnState, - inputs: TVec, - ) -> TractResult> { - crate::with_cuda_stream(|stream| { - let input_value = args_1!(inputs); - let input = input_value.to_device_tensor()?; - let output = tract_gpu::session_handler::make_tensor_for_node( - session, - node_id, - input.datum_type(), - input.shape(), - )?; - RmsNorm.dispatch_eval(stream, input, self.axis, &self.eps, &output)?; - Ok(tvec!(output.into_tensor().into_tvalue())) - }) - } -} - -impl TypedOp for CudaRmsNorm { - fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - tract_gpu::utils::facts_to_device_facts(inputs, |facts| { - let dt = facts[0].datum_type; - let fact = dt.fact(facts[0].shape.clone()); - Ok(tvec!(fact)) - }) - .with_context(|| format!("Error while computing facts for {:?}", self.name())) - } - - as_op!(); -} diff --git a/cuda/src/ops/rotate_half.rs b/cuda/src/ops/rotate_half.rs deleted file mode 100644 index 7759c4bc77..0000000000 --- a/cuda/src/ops/rotate_half.rs +++ /dev/null @@ -1,54 +0,0 @@ -use crate::kernels::array::RotateHalf; -use derive_new::new; -use tract_core::internal::*; -use tract_gpu::tensor::DeviceTensorExt; - -#[derive(Clone, Debug, new, Hash, PartialEq, Eq)] -pub struct CudaRotateHalf; - -impl Op for CudaRotateHalf { - fn name(&self) -> StaticName { - "CudaRotateHalf".into() - } - - op_as_typed_op!(); -} - -impl EvalOp for CudaRotateHalf { - fn is_stateless(&self) -> bool { - true - } - - fn eval_with_session( - &self, - node_id: usize, - session: &TurnState, - inputs: TVec, - ) -> TractResult> { - crate::with_cuda_stream(|stream| { - let input_value = args_1!(inputs); - let input = input_value.to_device_tensor()?; - let output = tract_gpu::session_handler::make_tensor_for_node( - session, - node_id, - input.datum_type(), - input.shape(), - )?; - RotateHalf.dispatch_eval(stream, input, &output)?; - Ok(tvec!(output.into_tensor().into_tvalue())) - }) - } -} - -impl TypedOp for CudaRotateHalf { - fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - tract_gpu::utils::facts_to_device_facts(inputs, |facts| { - let dt = facts[0].datum_type; - let fact = dt.fact(facts[0].shape.clone()); - Ok(tvec!(fact)) - }) - .with_context(|| format!("Error while computing facts for {:?}", self.name())) - } - - as_op!(); -} diff --git a/cuda/src/ops/scaled_masked_softmax.rs b/cuda/src/ops/scaled_masked_softmax.rs deleted file mode 100644 index 42718dcc5d..0000000000 --- a/cuda/src/ops/scaled_masked_softmax.rs +++ /dev/null @@ -1,64 +0,0 @@ -use crate::kernels::nn::ScaledMaskedSoftmax; -use derive_new::new; -use tract_core::internal::*; -use tract_gpu::tensor::DeviceTensorExt; - -/// A = SOFTMAX(INPUT * SCALE + MASK, AXIS=2) -/// Only input of rank of 3 is supported -#[derive(Clone, Debug, new, Hash, PartialEq, Eq)] -pub struct CudaScaledMaskedSoftmax { - pub scale: Arc, -} - -impl Op for CudaScaledMaskedSoftmax { - fn name(&self) -> StaticName { - "CudaScaledMaskedSoftmax".into() - } - - op_as_typed_op!(); -} - -impl EvalOp for CudaScaledMaskedSoftmax { - fn is_stateless(&self) -> bool { - true - } - - fn eval_with_session( - &self, - node_id: usize, - session: &TurnState, - inputs: TVec, - ) -> TractResult> { - crate::with_cuda_stream(|stream| { - let (input_val, mask_val) = args_2!(inputs); - let input = input_val.to_device_tensor()?; - let mask = mask_val.to_device_tensor()?; - let output = tract_gpu::session_handler::make_tensor_for_node( - session, - node_id, - input.datum_type(), - input.shape(), - )?; - ScaledMaskedSoftmax.dispatch_eval(stream, input, &self.scale, mask, &output)?; - Ok(tvec!(output.into_tensor().into_tvalue())) - }) - } -} - -impl TypedOp for CudaScaledMaskedSoftmax { - fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - tract_gpu::utils::facts_to_device_facts(inputs, |facts| { - ensure!(facts.len() == 2); - let dt = facts[0].datum_type; - ensure!(dt == facts[1].datum_type); - ensure!(facts[0].rank() <= 5); - ensure!(facts[0].rank() >= 2); - ensure!(facts[0].rank() == facts[1].rank()); - let fact = dt.fact(facts[0].shape.clone()); - Ok(tvec!(fact)) - }) - .with_context(|| format!("Error while computing facts for {:?}", self.name())) - } - - as_op!(); -} diff --git a/cuda/src/ops/softmax.rs b/cuda/src/ops/softmax.rs deleted file mode 100644 index 28125f2904..0000000000 --- a/cuda/src/ops/softmax.rs +++ /dev/null @@ -1,103 +0,0 @@ -use crate::kernels::nn::Softmax; -use std::fmt::Debug; -use tract_core::internal::*; -use tract_core::ops::nn as core_ops_nn; -use tract_gpu::tensor::DeviceTensorExt; - -#[derive(Debug, Clone, Hash, Default, PartialEq, Eq)] -pub struct CudaSoftmax { - pub axes: TVec, -} - -impl CudaSoftmax { - pub fn new(axes: TVec) -> TractResult { - ensure!(axes.len() == 1, "Only one axis of softmax is supported by CudaSoftmax"); - Ok(Self { axes }) - } - - pub fn from_tract_core(core_softmax: &core_ops_nn::Softmax) -> TractResult { - ensure!(core_softmax.quant_output_dt.is_none()); - Self::new(core_softmax.axes.clone()) - } -} - -impl Op for CudaSoftmax { - fn name(&self) -> StaticName { - "CudaSoftmax".into() - } - - fn info(&self) -> TractResult> { - Ok(vec![format!("axes: {:?}", self.axes)]) - } - - op_as_typed_op!(); -} - -impl EvalOp for CudaSoftmax { - fn is_stateless(&self) -> bool { - true - } - - fn eval_with_session( - &self, - node_id: usize, - session: &TurnState, - inputs: TVec, - ) -> TractResult> { - crate::with_cuda_stream(|stream| { - let input_value = args_1!(inputs); - let input = input_value.to_device_tensor()?; - let output = tract_gpu::session_handler::make_tensor_for_node( - session, - node_id, - input.datum_type(), - input.shape(), - )?; - Softmax.dispatch_eval(stream, input, self.axes[0], &output)?; - - Ok(tvec!(output.into_tensor().into_tvalue())) - }) - } -} - -impl TypedOp for CudaSoftmax { - fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - tract_gpu::utils::facts_to_device_facts(inputs, |facts| { - let dt = facts[0].datum_type; - let fact = dt.fact(facts[0].shape.clone()); - Ok(tvec!(fact)) - }) - .with_context(|| format!("Error while computing facts for {:?}", self.name())) - } - - fn axes_mapping( - &self, - inputs: &[&TypedFact], - outputs: &[&TypedFact], - ) -> TractResult { - AxesMapping::natural(inputs, outputs) - } - - fn change_axes( - &self, - model: &TypedModel, - node: &TypedNode, - _io: InOut, - change: &AxisOp, - ) -> TractResult> { - let axes: Option> = - self.axes.iter().map(|it| change.transform_axis(*it)).collect(); - if let Some(axes) = axes { - Ok(Some(AxisChangeConsequence::new( - model, - node, - Some(Box::new(CudaSoftmax { axes })), - change, - ))) - } else { - Ok(None) - } - } - - as_op!(); -} diff --git a/cuda/src/rewrite_rules/fuse_axis_op.rs b/cuda/src/rewrite_rules/fuse_axis_op.rs index 303fd0af25..cc76e2fb68 100644 --- a/cuda/src/rewrite_rules/fuse_axis_op.rs +++ b/cuda/src/rewrite_rules/fuse_axis_op.rs @@ -13,8 +13,8 @@ fn is_supported_axis_op(op: &GpuAxisOp) -> bool { fn can_fuse_move(model: &TypedModel, axis_node: &TypedNode) -> bool { model.single_succ(axis_node.id).unwrap().is_some_and(|node| { node.op_is::() - || node.op_is::() - || node.op_is::() + || node.op_is::() + || node.op_is::() || node.op_is::() || node.op_is::() || node.op_is::() @@ -194,12 +194,7 @@ pub fn fuse_move_axis( } // Reshape are always fusable. Change Move by Reshape if possible - let simpl_op = GpuAxisOp::simplify_axis_op( - axis_op.inner.clone(), - in_shape.dims(), - axis_op.backend_name, - axis_op.dispatch, - ); + let simpl_op = GpuAxisOp::simplify_axis_op(axis_op.inner.clone(), in_shape.dims()); if simpl_op != *axis_op { return Ok(Some(TypedModelPatch::replace_single_op( model, @@ -226,7 +221,7 @@ pub fn fuse_move_axis( let inputs = patch.taps(model, &axis_node.inputs)?; let out = patch.wire_node( format!("{axis_node_name}.fused_move_axis"), - GpuAxisOp::new(new_axis_ops[0].clone(), axis_op.backend_name, axis_op.dispatch), + GpuAxisOp::new(new_axis_ops[0].clone()), &inputs, )?; patch.shunt_outside(model, cursor.id.into(), out[0])?; @@ -243,11 +238,8 @@ pub fn fuse_move_axis( { let mut patch = TypedModelPatch::default(); let inputs = patch.taps(model, &cursor.inputs)?; - let out = patch.wire_node( - cursor.name.clone(), - GpuAxisOp::new(AxisOp::Add(to_1), axis_op.backend_name, axis_op.dispatch), - &inputs, - )?; + let out = + patch.wire_node(cursor.name.clone(), GpuAxisOp::new(AxisOp::Add(to_1)), &inputs)?; patch.shunt_outside(model, axis_node.id.into(), out[0])?; return Ok(Some(patch)); } diff --git a/cuda/src/transform.rs b/cuda/src/transform.rs index 95d036b810..ee588d1108 100644 --- a/cuda/src/transform.rs +++ b/cuda/src/transform.rs @@ -1,40 +1,58 @@ +use std::any::TypeId; +use std::collections::HashMap; +use std::sync::OnceLock; + use crate::context::cuda_context; -use crate::kernels::nn::cuda_reduce_launch; -use crate::ops::CudaIff; -use crate::ops::{CudaLeakyRelu, wire_cuda_conv}; +use crate::ops::wire_cuda_conv; use crate::{kernels, ops, rewrite_rules}; use DatumType::{F16, F32}; use tract_core::dyn_clone::clone_box; use tract_core::internal::*; use tract_core::model::translator::Translate; -use tract_core::ops::array::{MultiBroadcastTo, Slice, TypedConcat}; -use tract_core::ops::binary::{BinMiniOp, TypedBinOp}; -use tract_core::ops::cast::Cast; use tract_core::ops::cnn::conv::rewrite_kernel_conv_in_oihw; use tract_core::ops::cnn::{Conv, rewrite_conv_with_n_axis}; use tract_core::ops::einsum::prefix_matmul::{PrefixMatMul, rewrite_einsum_to_prefix_matmul}; -use tract_core::ops::element_wise::ElementWiseOp; use tract_core::ops::konst::Const; -use tract_core::ops::logic::Iff; -use tract_core::ops::nn::{LeakyRelu, Reduce, Softmax}; -use tract_core::tract_data::itertools::Itertools; +use tract_core::ops::nn::Reduce; use tract_core::tract_linalg::block_quant::Q4_0; use tract_core::transform::ModelTransform; use tract_gpu::fact::{DeviceFact, DeviceTypedFactExt}; -use tract_gpu::ops::reduce::GpuReduce; use tract_gpu::rewrite_rules::rewire_syncs::rewire_syncs; use tract_gpu::rewrite_rules::rms_norm::remove_rms_norm_cast; use tract_gpu::sync::{DeviceSyncKind, sync_inputs_if_required, sync_model_outputs_if_required}; use tract_gpu::tensor::{DeviceTensor, IntoDevice}; use tract_gpu::utils::as_quant_fact; -use tract_pulse_opl::ops::{Delay, PulsePad}; -use tract_transformers::ops::apply_rope::{ApplyRope, RotateHalf}; -use tract_transformers::ops::dyn_kv_cache::DynKeyValueCache; -use tract_transformers::ops::gelu_approximate::GeluApproximate; -use tract_transformers::ops::rms_norm::RmsNorm; -use tract_transformers::ops::scaled_masked_softmax::ScaledMaskedSoftmax; use tract_transformers::ops::sdpa::Sdpa; +/// A registered translator that can convert a core op into a CUDA GPU op. +/// Each kernel module submits one (or more) of these via [`register_cuda_op!`]. +pub struct CudaOpTranslator { + pub type_id: TypeId, + pub try_make: fn(&TypedModel, &TypedNode) -> TractResult>>, +} + +inventory::collect!(CudaOpTranslator); + +/// Register a translator for a core op type. The closure receives `(source, node, op)` +/// where `op` is already downcast to `$op_type`. Return `Ok(Some(gpu_op))` to translate, +/// `Ok(None)` to skip. +#[macro_export] +macro_rules! register_cuda_op { + ($op_type:ty, |$source:ident, $node:ident, $op:ident| $body:expr) => { + inventory::submit! { + $crate::transform::CudaOpTranslator { + type_id: std::any::TypeId::of::<$op_type>(), + try_make: |$source, $node| { + let Some($op) = $node.op_as::<$op_type>() else { + return Ok(None); + }; + $body + }, + } + } + }; +} + #[derive(Debug, Default)] pub struct CudaTransform; @@ -100,74 +118,40 @@ impl CudaTransform { } } -fn can_translate_to_cuda_op(source: &TypedModel, node: &TypedNode) -> TractResult { - let input_facts = source.node_input_facts(node.id)?.iter().map(|f| (*f).clone()).collect_vec(); - let input_dts = input_facts.iter().map(|f| f.datum_type).collect_vec(); - - let in_dts_compatible = - input_facts.iter().all(|fact| DeviceTensor::is_supported_dt(fact.datum_type)); - - Ok(in_dts_compatible - && (node - .op_as::() - .is_some_and(|op| DeviceTensor::is_supported_dt(op.val().datum_type())) - || node.op_as::().is_some_and(|op| op.0.is::()) - || node.op_as::().is_some_and(|op| { - crate::kernels::element_wise::is_supported(&*op.0, input_dts[0]) - }) - || node - .op_as::() - .is_some_and(|op| crate::kernels::binary::is_supported(&*op.0, input_dts[0])) - || node.op_is::() - || node - .op_as::() - .is_some_and(|op| DeviceTensor::is_supported_dt(op.val().datum_type())) - || node.op_as::().is_some_and(|op| { - ops::CudaCast::is_supported_dt(input_dts[0]) && ops::CudaCast::new(op.to).is_some() - }) - || node.op_is::() - || node.op_is::() - || node.op_is::() - || node.op_is::() - || node.op_is::() - || node.op_is::() - || node.op_is::() - || node.op_as::().is_some_and(|op| { - GpuReduce::from_tract_core(op, "Cuda", cuda_reduce_launch) - .is_ok_and(|op| op.reducer.is_supported_dt(input_dts[0])) - }) - || node.op_as::().is_some_and(|op| { - kernels::nn::Softmax::is_supported_dt(input_dts[0]) - && ops::CudaSoftmax::from_tract_core(op).is_ok() - }) - || node.op_as::().is_some_and(|op| { - !op.post_softmax_mask - && kernels::nn::ScaledMaskedSoftmax::is_supported_dt(input_dts[0]) - }) - || node - .op_as::() - .is_some_and(|_| kernels::nn::RmsNorm::is_supported_dt(input_dts[0])) - || node - .op_as::() - .is_some_and(|_| kernels::array::RotateHalf::is_supported_dt(input_dts[0])) - || node - .op_as::() - .is_some_and(|_| kernels::nn::ApplyRope::is_supported_dt(input_dts[0])) - || node.op_as::().is_some_and(|op| { - op.0.is::() - && kernels::nn::GeluApproximate::is_supported_dt(input_dts[0]) - }) - || node.op_as::().is_some() - || node.op_as::().is_some_and(|op| { - !op.transpose_c - && op.quantize_output.is_none() - && (can_convert_to_cuda_gemm(&input_facts) - || can_convert_to_cuda_gemm(&[ - input_facts[1].clone(), - input_facts[0].clone(), - ])) - }) - || (node.op_is::() && matches!(input_facts[0].datum_type, F16 | F32)))) +/// Looks up the node's op TypeId in the inventory of registered `CudaOpTranslator`s. +/// Returns `Some(gpu_op)` if a translator matches and succeeds, `None` otherwise. +fn try_make_cuda_op( + source: &TypedModel, + node: &TypedNode, +) -> TractResult>> { + type TranslateFn = fn(&TypedModel, &TypedNode) -> TractResult>>; + static MAP: OnceLock>> = OnceLock::new(); + let map = MAP.get_or_init(|| { + let mut m: HashMap> = HashMap::new(); + for t in inventory::iter:: { + m.entry(t.type_id).or_default().push(t.try_make); + } + m + }); + + let input_facts = source.node_input_facts(node.id)?; + if !input_facts.iter().all(|f| DeviceTensor::is_supported_dt(f.datum_type)) { + return Ok(None); + } + + // Copy-based ops are fully generic (no backend-specific dispatch needed). + if let Some(op) = tract_gpu::ops::copy_based::try_make_copy_based_op(source, node)? { + return Ok(Some(op)); + } + + if let Some(fns) = map.get(&(*node.op).type_id()) { + for f in fns { + if let Some(op) = f(source, node)? { + return Ok(Some(op)); + } + } + } + Ok(None) } fn convert_const(op: &Const) -> TractResult { @@ -182,24 +166,13 @@ fn convert_const(op: &Const) -> TractResult { Const::new_with_exotic_fact(cuda_const, Box::new(cuda_fact)) } -use tract_core::ops::element_wise::ElementWiseMiniOp; -use tract_gpu::ops::binary::GpuBinOp; -use tract_gpu::ops::element_wise::GpuElementWise; - -fn cuda_element_wise_op(mini_op: Box) -> GpuElementWise { - GpuElementWise { - backend_name: "Cuda", - mini_op, - dispatch: crate::kernels::element_wise::cuda_element_wise_dispatch, - } -} - -pub fn cuda_bin_op(mini_op: Box) -> GpuBinOp { - GpuBinOp { - backend_name: "Cuda", - mini_op, - dispatch: crate::kernels::binary::cuda_bin_op_dispatch, - } +pub(crate) fn cuda_cast_new(to: DatumType) -> Option { + tract_gpu::ops::cast::GpuCast::new( + to, + "Cuda", + kernels::array::cuda_cast_dispatch, + kernels::array::Cast::is_supported_dt, + ) } fn can_convert_to_cuda_gemm(facts: &[TypedFact]) -> bool { @@ -247,21 +220,18 @@ fn convert_matmul_to_cuda( if transpose_act { let rank = act_fact.rank(); - let perm_act_op = tract_gpu::ops::change_axes::GpuAxisOp::new( - AxisOp::Move(rank - 2, rank - 1), - "Cuda", - crate::kernels::array::cuda_copy_nd_dispatch, - ); + let perm_act_op = + tract_gpu::ops::change_axes::GpuAxisOp::new(AxisOp::Move(rank - 2, rank - 1)); let perm_act_name = node.name.clone() + ".perm_activs"; *act_outlet = target.wire_node(perm_act_name, perm_act_op, &[*act_outlet])?[0]; } if act_fact.datum_type == DatumType::F16 && as_quant_fact(weight_fact, &Q4_0).is_some() { - let in_cast_op = ops::CudaCast::new(DatumType::F32).unwrap(); + let in_cast_op = cuda_cast_new(DatumType::F32).unwrap(); *act_outlet = target.wire_node(node.name.clone() + ".in_cast", in_cast_op, &[*act_outlet])?[0]; } else if act_fact.datum_type == DatumType::F16 && weight_fact.datum_type == DatumType::F32 { - let in_cast_op = ops::CudaCast::new(DatumType::F16).unwrap(); + let in_cast_op = cuda_cast_new(DatumType::F16).unwrap(); *weights_outlet = target.wire_node(node.name.clone() + ".in_cast", in_cast_op, &[*weights_outlet])?[0]; } @@ -270,11 +240,8 @@ fn convert_matmul_to_cuda( ensure!(as_quant_fact(weight_fact, &Q4_0).is_none(), "Cannot transpose Q40 tensor"); let rank = weight_fact.rank(); - let perm_weights_op = tract_gpu::ops::change_axes::GpuAxisOp::new( - AxisOp::Move(rank - 2, rank - 1), - "Cuda", - crate::kernels::array::cuda_copy_nd_dispatch, - ); + let perm_weights_op = + tract_gpu::ops::change_axes::GpuAxisOp::new(AxisOp::Move(rank - 2, rank - 1)); let perm_weights_name = node.name.clone() + ".perm_weights"; *weights_outlet = target.wire_node(perm_weights_name, perm_weights_op, &[*weights_outlet])?[0]; @@ -297,11 +264,8 @@ fn convert_matmul_to_cuda( .map(|fact| fact.clarify_dt_shape().unwrap().1.len()) .unwrap(); - let perm_out_op = tract_gpu::ops::change_axes::GpuAxisOp::new( - AxisOp::Move(rank - 2, rank - 1), - "Cuda", - crate::kernels::array::cuda_copy_nd_dispatch, - ); + let perm_out_op = + tract_gpu::ops::change_axes::GpuAxisOp::new(AxisOp::Move(rank - 2, rank - 1)); matmul_output = target.wire_node(node.name.clone() + ".perm_out", perm_out_op, &matmul_output)?; } @@ -312,10 +276,10 @@ fn convert_matmul_to_cuda( let expected_dt = model.node_output_facts(node.id)?[0].datum_type; if out_dt != expected_dt { ensure!( - ops::CudaCast::is_supported_dt(out_dt), + kernels::array::Cast::is_supported_dt(out_dt), "Matmul output type cannot be casted to expected type" ); - let cast_op = ops::CudaCast::new(model.node_output_facts(node.id)?[0].datum_type).unwrap(); + let cast_op = cuda_cast_new(model.node_output_facts(node.id)?[0].datum_type).unwrap(); matmul_output = target.wire_node(node.name.clone() + ".out_cast", cast_op, &matmul_output)? } @@ -355,11 +319,9 @@ fn convert_sdpa_to_cuda_flash_attn( suffix: &str, ) -> TractResult<()> { if have != want { - *dst = target.wire_node( - name(node_name, suffix), - ops::CudaCast::new(want).unwrap(), - &[*dst], - )?[0]; + *dst = + target.wire_node(name(node_name, suffix), cuda_cast_new(want).unwrap(), &[*dst])? + [0]; } Ok(()) } @@ -372,11 +334,7 @@ fn convert_sdpa_to_cuda_flash_attn( suffix: &str, ) -> TractResult { if fact.rank() == 3 { - let ax = tract_gpu::ops::change_axes::GpuAxisOp::new( - AxisOp::Add(1), - "Cuda", - crate::kernels::array::cuda_copy_nd_dispatch, - ); + let ax = tract_gpu::ops::change_axes::GpuAxisOp::new(AxisOp::Add(1)); *dst = target.wire_node(name(node_name, suffix), ax, &[*dst])?[0]; Ok(true) } else { @@ -407,11 +365,7 @@ fn convert_sdpa_to_cuda_flash_attn( let m = m_opt.unwrap(); mut_cast(target, &node.name, m, mf.datum_type().unwrap(), DatumType::F16, ".cast_m")?; if mf.rank() != 4 { - let ax = tract_gpu::ops::change_axes::GpuAxisOp::new( - AxisOp::Add(1), - "Cuda", - crate::kernels::array::cuda_copy_nd_dispatch, - ); + let ax = tract_gpu::ops::change_axes::GpuAxisOp::new(AxisOp::Add(1)); *m = target.wire_node(name(&node.name, ".reshape_m"), ax, &[*m])?[0]; } } @@ -429,21 +383,14 @@ fn convert_sdpa_to_cuda_flash_attn( if added_head_axis { out = target.wire_node( name(&node.name, ".reshape_out"), - tract_gpu::ops::change_axes::GpuAxisOp::new( - AxisOp::Rm(1), - "Cuda", - crate::kernels::array::cuda_copy_nd_dispatch, - ), + tract_gpu::ops::change_axes::GpuAxisOp::new(AxisOp::Rm(1)), &out, )?; } if q_dt != DatumType::F16 { - out = target.wire_node( - name(&node.name, ".cast_out"), - ops::CudaCast::new(q_dt).unwrap(), - &out, - )?; + out = + target.wire_node(name(&node.name, ".cast_out"), cuda_cast_new(q_dt).unwrap(), &out)?; } Ok(out) @@ -457,98 +404,55 @@ impl Translate, TypedFact, Box> for Cud target: &mut TypedModel, mapping: &HashMap, ) -> TractResult> { - let translatable = can_translate_to_cuda_op(source, node)?; - - if translatable { + // Special multi-node ops handled first + let input_facts = source.node_input_facts(node.id)?; + if let Some(op) = node.op_as::() { + let facts: Vec = input_facts.iter().map(|f| (*f).clone()).collect(); + if !op.transpose_c + && op.quantize_output.is_none() + && (can_convert_to_cuda_gemm(&facts) + || can_convert_to_cuda_gemm(&[facts[1].clone(), facts[0].clone()])) + { + let mut device_inputs = + sync_inputs_if_required(target, node, mapping, DeviceSyncKind::ToDevice)?; + let outlet_ids = + convert_matmul_to_cuda(source, node, target, &mut device_inputs, op)?; + return sync_model_outputs_if_required(source, node, target, outlet_ids); + } + } + if let Some(op) = node.op_as::() { let mut device_inputs = sync_inputs_if_required(target, node, mapping, DeviceSyncKind::ToDevice)?; + let outlet_ids = + convert_sdpa_to_cuda_flash_attn(source, node, target, &mut device_inputs, op)?; + return sync_model_outputs_if_required(source, node, target, outlet_ids); + } + if let Some(conv) = node.op_as::() { + if input_facts.iter().all(|f| DeviceTensor::is_supported_dt(f.datum_type)) + && matches!(input_facts[0].datum_type, F16 | F32) + { + let device_inputs = + sync_inputs_if_required(target, node, mapping, DeviceSyncKind::ToDevice)?; + let outlet_ids = wire_cuda_conv(source, node, target, &device_inputs, conv)?; + return sync_model_outputs_if_required(source, node, target, outlet_ids); + } + } + // Const: inline conversion, not a GPU op + if let Some(op) = node.op_as::() { + if DeviceTensor::is_supported_dt(op.val().datum_type()) { + let device_inputs = + sync_inputs_if_required(target, node, mapping, DeviceSyncKind::ToDevice)?; + let outlet_ids = + target.wire_node(node.name.clone(), convert_const(op)?, &device_inputs)?; + return sync_model_outputs_if_required(source, node, target, outlet_ids); + } + } - let outlet_ids: TVec = if let Some(op) = node.op_as::() { - convert_matmul_to_cuda(source, node, target, &mut device_inputs, op)? - } else if let Some(op) = node.op_as::() { - convert_sdpa_to_cuda_flash_attn(source, node, target, &mut device_inputs, op)? - } else if let Some(conv) = node.op_as::() { - wire_cuda_conv(source, node, target, &device_inputs, conv)? - } else { - let op: Box = if let Some(op) = node.op_as::() { - Box::new(convert_const(op)?) - } else if let Some(op) = node.op_as::() { - if let Some(leaky) = op.0.downcast_ref::() { - Box::new(CudaLeakyRelu { alpha: leaky.alpha }) - } else if let Some(ew) = op.0.downcast_ref::() { - Box::new(ops::CudaGeluApproximate { fast_impl: ew.fast_impl }) - } else { - Box::new(cuda_element_wise_op(op.0.clone())) - } - } else if let Some(op) = node.op_as::() { - Box::new(cuda_bin_op(op.0.clone())) - } else if let Some(op) = node.op_as::() { - Box::new(tract_gpu::ops::broadcast::GpuMultiBroadcastTo::new( - op.shape.clone(), - "Cuda", - crate::kernels::array::cuda_copy_nd_dispatch, - )) - } else if let Some(op) = node.op_as::() { - Box::new(ops::CudaCast::new(op.to).unwrap()) - } else if let Some(op) = node.op_as::() { - let in_fact = source.node_input_facts(node.id)?[0]; - Box::new(tract_gpu::ops::change_axes::GpuAxisOp::from_tract_core_with_fact( - op.clone(), - in_fact, - "Cuda", - crate::kernels::array::cuda_copy_nd_dispatch, - )) - } else if let Some(op) = node.op_as::() { - Box::new(tract_gpu::ops::slice::GpuSlice::new( - op.clone(), - "Cuda", - crate::kernels::array::cuda_copy_nd_dispatch, - )) - } else if let Some(op) = node.op_as::() { - Box::new(tract_gpu::ops::concat::GpuConcat::new( - op.axis, - "Cuda", - crate::kernels::array::cuda_copy_nd_dispatch, - )) - } else if let Some(op) = node.op_as::() { - Box::new(tract_gpu::ops::dyn_kv_cache::GpuDynKVCache::from_tract_transformers( - op, - "Cuda", - crate::kernels::array::cuda_copy_nd_dispatch, - )) - } else if let Some(op) = node.op_as::() { - Box::new(GpuReduce::from_tract_core(op, "Cuda", cuda_reduce_launch)?) - } else if let Some(op) = node.op_as::() { - Box::new(ops::CudaSoftmax::from_tract_core(op)?) - } else if let Some(op) = node.op_as::() - && !op.post_softmax_mask - { - Box::new(ops::CudaScaledMaskedSoftmax { scale: op.scale.clone() }) - } else if let Some(_op) = node.op_as::() { - Box::new(ops::CudaRotateHalf) - } else if let Some(_op) = node.op_as::() { - Box::new(ops::CudaApplyRope) - } else if let Some(op) = node.op_as::() { - Box::new(ops::CudaRmsNorm::new(op.axis, op.eps.clone())) - } else if let Some(op) = node.op_as::() { - Box::new(tract_gpu::ops::pulse::GpuDelay::new( - op, - "Cuda", - crate::kernels::array::cuda_copy_nd_dispatch, - )) - } else if let Some(op) = node.op_as::() { - Box::new(tract_gpu::ops::pulse::GpuPulsePad::new( - op, - "Cuda", - crate::kernels::array::cuda_copy_nd_dispatch, - )?) - } else if node.op_is::() { - Box::new(CudaIff) - } else { - bail!("Failed to translate a supported CUDA Op") - }; - target.wire_node(node.name.clone(), op, &device_inputs)? - }; + // Single-op translation + if let Some(gpu_op) = try_make_cuda_op(source, node)? { + let device_inputs = + sync_inputs_if_required(target, node, mapping, DeviceSyncKind::ToDevice)?; + let outlet_ids = target.wire_node(node.name.clone(), gpu_op, &device_inputs)?; sync_model_outputs_if_required(source, node, target, outlet_ids) } else { let cpu_inputs = diff --git a/gpu/src/device.rs b/gpu/src/device.rs index 44c63d7f08..0448d276d7 100644 --- a/gpu/src/device.rs +++ b/gpu/src/device.rs @@ -1,14 +1,14 @@ use std::ffi::c_void; +use std::ops::Range; use std::sync::Mutex; use anyhow::{anyhow, bail}; use downcast_rs::{Downcast, impl_downcast}; use tract_core::dyn_clone; -use tract_core::internal::ExoticFact; -use tract_core::prelude::{DatumType, TractResult}; +use tract_core::internal::*; use tract_core::value::TValue; -use crate::tensor::OwnedDeviceTensor; +use crate::tensor::{DeviceTensor, OwnedDeviceTensor}; pub trait DeviceContext: Downcast + dyn_clone::DynClone + Send + Sync { fn tensor_to_device(&self, tensor: TValue) -> TractResult>; @@ -22,6 +22,76 @@ pub trait DeviceContext: Downcast + dyn_clone::DynClone + Send + Sync { exotic_fact: Box, ) -> TractResult>; fn synchronize(&self) -> TractResult<()>; + fn copy_nd( + &self, + input: &DeviceTensor, + input_offset: usize, + input_strides: &[isize], + output: &DeviceTensor, + output_offset: usize, + output_shape: &[usize], + output_strides: &[isize], + ) -> TractResult<()>; + + /// Copy a slice along `axis` from `src[src_range]` into `dst[dst_range]`. + fn assign_slice( + &self, + dst: &DeviceTensor, + dst_range: Range, + src: &DeviceTensor, + src_range: Range, + axis: usize, + ) -> TractResult<()> { + let mut zone_shape: TVec = src.shape().into(); + zone_shape[axis] = src_range.len(); + if zone_shape.iter().product::() == 0 { + return Ok(()); + } + let src_offset = + src_range.start * src.strides()[axis] as usize * src.datum_type().size_of(); + let dst_offset = + dst_range.start * dst.strides()[axis] as usize * dst.datum_type().size_of(); + self.copy_nd(src, src_offset, src.strides(), dst, dst_offset, &zone_shape, dst.strides()) + } + + /// Copy from `src` into `dst` with given origins and strides. + fn copy_with_origins( + &self, + zone_shape: &[usize], + dst: &DeviceTensor, + dst_origin: &[usize], + dst_strides: &[isize], + src: &DeviceTensor, + src_origin: &[usize], + src_strides: &[isize], + ) -> TractResult<()> { + if zone_shape.iter().product::() == 0 { + return Ok(()); + } + let dt_size = src.datum_type().size_of(); + let src_offset: usize = + src_origin.iter().zip(src_strides).map(|(o, s)| o * *s as usize).sum::() + * dt_size; + let dst_offset: usize = + dst_origin.iter().zip(dst_strides).map(|(o, s)| o * *s as usize).sum::() + * dt_size; + self.copy_nd(src, src_offset, src_strides, dst, dst_offset, zone_shape, dst_strides) + } + + /// Flat memcpy of `byte_len` bytes. + fn flat_copy( + &self, + src: &DeviceTensor, + src_byte_offset: usize, + dst: &DeviceTensor, + dst_byte_offset: usize, + byte_len: usize, + ) -> TractResult<()> { + if byte_len == 0 { + return Ok(()); + } + self.copy_nd(src, src_byte_offset, &[1], dst, dst_byte_offset, &[byte_len], &[1]) + } } impl_downcast!(DeviceContext); diff --git a/gpu/src/ops/apply_rope.rs b/gpu/src/ops/apply_rope.rs new file mode 100644 index 0000000000..34ef599bb7 --- /dev/null +++ b/gpu/src/ops/apply_rope.rs @@ -0,0 +1,75 @@ +use crate::tensor::{DeviceTensor, DeviceTensorExt}; +use tract_core::internal::*; + +pub type DispatchApplyRopeFn = + fn(&DeviceTensor, &DeviceTensor, &DeviceTensor, &DeviceTensor) -> TractResult<()>; + +#[derive(Clone)] +pub struct GpuApplyRope { + pub backend_name: &'static str, + pub dispatch: DispatchApplyRopeFn, +} + +impl std::fmt::Debug for GpuApplyRope { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}ApplyRope", self.backend_name) + } +} + +impl PartialEq for GpuApplyRope { + fn eq(&self, other: &Self) -> bool { + self.backend_name == other.backend_name + } +} +impl Eq for GpuApplyRope {} + +impl std::hash::Hash for GpuApplyRope { + fn hash(&self, state: &mut H) { + self.backend_name.hash(state); + } +} + +impl Op for GpuApplyRope { + fn name(&self) -> StaticName { + format!("{}ApplyRope", self.backend_name).into() + } + op_as_typed_op!(); +} + +impl EvalOp for GpuApplyRope { + fn is_stateless(&self) -> bool { + true + } + + fn eval_with_session( + &self, + node_id: usize, + session: &TurnState, + inputs: TVec, + ) -> TractResult> { + let (input_val, cos_val, sin_val) = args_3!(inputs); + let input = input_val.to_device_tensor()?; + let cos = cos_val.to_device_tensor()?; + let sin = sin_val.to_device_tensor()?; + let output = crate::session_handler::make_tensor_for_node( + session, + node_id, + input.datum_type(), + input.shape(), + )?; + (self.dispatch)(input, cos, sin, &output)?; + Ok(tvec!(output.into_tensor().into_tvalue())) + } +} + +impl TypedOp for GpuApplyRope { + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + crate::utils::facts_to_device_facts(inputs, |facts| { + let dt = facts[0].datum_type; + let fact = dt.fact(facts[0].shape.clone()); + Ok(tvec!(fact)) + }) + .with_context(|| format!("Error while computing facts for {:?}", self.name())) + } + as_op!(); +} diff --git a/gpu/src/ops/broadcast.rs b/gpu/src/ops/broadcast.rs index f279ed5889..02f3a69b75 100644 --- a/gpu/src/ops/broadcast.rs +++ b/gpu/src/ops/broadcast.rs @@ -1,39 +1,21 @@ use crate::tensor::DeviceTensorExt; -use crate::utils::{DispatchCopyNdFn, compute_broadcast_strides}; -use derive_new::new; +use crate::utils::compute_broadcast_strides; use tract_core::internal::*; -#[derive(Clone, new)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct GpuMultiBroadcastTo { pub shape: ShapeFact, - pub backend_name: &'static str, - pub dispatch: DispatchCopyNdFn, } -impl std::fmt::Debug for GpuMultiBroadcastTo { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}MultiBroadcastTo({:?})", self.backend_name, self.shape) - } -} - -impl PartialEq for GpuMultiBroadcastTo { - fn eq(&self, other: &Self) -> bool { - self.backend_name == other.backend_name && self.shape == other.shape - } -} - -impl Eq for GpuMultiBroadcastTo {} - -impl std::hash::Hash for GpuMultiBroadcastTo { - fn hash(&self, state: &mut H) { - self.backend_name.hash(state); - self.shape.hash(state); +impl GpuMultiBroadcastTo { + pub fn new(shape: ShapeFact) -> Self { + Self { shape } } } impl Op for GpuMultiBroadcastTo { fn name(&self) -> StaticName { - format!("{}MultiBroadcastTo", self.backend_name).into() + "GpuMultiBroadcastTo".into() } op_as_typed_op!(); @@ -68,15 +50,8 @@ impl EvalOp for GpuMultiBroadcastTo { let broadcast_strides: TVec = compute_broadcast_strides(&input_shape, &input_strides)?; - (self.dispatch)( - input, - 0, - &broadcast_strides, - &output, - 0, - output.shape(), - output.strides(), - )?; + let ctx = crate::device::get_context()?; + ctx.copy_nd(input, 0, &broadcast_strides, &output, 0, output.shape(), output.strides())?; Ok(tvec![output.into_tensor().into_tvalue()]) } } diff --git a/gpu/src/ops/cast.rs b/gpu/src/ops/cast.rs new file mode 100644 index 0000000000..dea883fd6f --- /dev/null +++ b/gpu/src/ops/cast.rs @@ -0,0 +1,97 @@ +use crate::tensor::DeviceTensorExt; +use tract_core::internal::*; + +use crate::tensor::DeviceTensor; + +pub type DispatchCastFn = fn(&DeviceTensor, &DeviceTensor) -> TractResult<()>; + +#[derive(Clone)] +pub struct GpuCast { + pub to: DatumType, + pub backend_name: &'static str, + pub dispatch: DispatchCastFn, + pub is_supported_dt: fn(DatumType) -> bool, +} + +impl GpuCast { + pub fn new( + to: DatumType, + backend_name: &'static str, + dispatch: DispatchCastFn, + is_supported_dt: fn(DatumType) -> bool, + ) -> Option { + is_supported_dt(to).then_some(Self { to, backend_name, dispatch, is_supported_dt }) + } + + pub fn is_supported_dt(&self, dt: DatumType) -> bool { + (self.is_supported_dt)(dt) + } +} + +impl std::fmt::Debug for GpuCast { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}Cast({:?})", self.backend_name, self.to) + } +} + +impl PartialEq for GpuCast { + fn eq(&self, other: &Self) -> bool { + self.backend_name == other.backend_name && self.to == other.to + } +} + +impl Eq for GpuCast {} + +impl std::hash::Hash for GpuCast { + fn hash(&self, state: &mut H) { + self.backend_name.hash(state); + self.to.hash(state); + } +} + +impl Op for GpuCast { + fn name(&self) -> StaticName { + format!("{}Cast", self.backend_name).into() + } + + op_as_typed_op!(); +} + +impl EvalOp for GpuCast { + fn is_stateless(&self) -> bool { + true + } + + fn eval_with_session( + &self, + node_id: usize, + session: &TurnState, + inputs: TVec, + ) -> TractResult> { + let input_value = args_1!(inputs); + let input = input_value.to_device_tensor()?; + if input.datum_type() == self.to { + Ok(tvec!(input_value)) + } else { + let output = crate::session_handler::make_tensor_for_node( + session, + node_id, + self.to, + input.shape(), + )?; + (self.dispatch)(input, &output)?; + Ok(tvec![output.into_tensor().into_tvalue()]) + } + } +} + +impl TypedOp for GpuCast { + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + crate::utils::facts_to_device_facts(inputs, |facts| { + Ok(tvec!(self.to.fact(facts[0].shape.clone()))) + }) + .with_context(|| format!("Error while computing facts for {:?}", self.name())) + } + + as_op!(); +} diff --git a/gpu/src/ops/change_axes.rs b/gpu/src/ops/change_axes.rs index 601e341893..b0e926489a 100644 --- a/gpu/src/ops/change_axes.rs +++ b/gpu/src/ops/change_axes.rs @@ -1,26 +1,18 @@ use crate::tensor::DeviceTensorExt; -use crate::utils::DispatchCopyNdFn; use tract_core::internal::*; use tract_itertools::Itertools; -#[derive(Clone)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct GpuAxisOp { pub inner: AxisOp, - pub backend_name: &'static str, - pub dispatch: DispatchCopyNdFn, } impl GpuAxisOp { - pub fn new(inner: AxisOp, backend_name: &'static str, dispatch: DispatchCopyNdFn) -> Self { - Self { inner, backend_name, dispatch } + pub fn new(inner: AxisOp) -> Self { + Self { inner } } - pub fn simplify_axis_op( - op: AxisOp, - dims: &[TDim], - backend_name: &'static str, - dispatch: DispatchCopyNdFn, - ) -> Self { + pub fn simplify_axis_op(op: AxisOp, dims: &[TDim]) -> Self { let inner = match op { AxisOp::Move(from, to) if from.abs_diff(to) == 1 => { if [&dims[from], &dims[to]].contains(&&1usize.into()) { @@ -57,57 +49,18 @@ impl GpuAxisOp { } _ => op, }; - Self { inner, backend_name, dispatch } + Self { inner } } - pub fn from_tract_core_with_fact( - op: AxisOp, - fact: &TypedFact, - backend_name: &'static str, - dispatch: DispatchCopyNdFn, - ) -> Self { + pub fn from_tract_core_with_fact(op: AxisOp, fact: &TypedFact) -> Self { let dims = fact.shape.dims(); - Self::simplify_axis_op(op, dims, backend_name, dispatch) - } -} - -impl std::fmt::Debug for GpuAxisOp { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match &self.inner { - AxisOp::Add(a) => write!(f, "{}Add({a})", self.backend_name), - AxisOp::Rm(a) => write!(f, "{}Rm({a})", self.backend_name), - AxisOp::Move(from, to) => write!(f, "{}Move({from}, {to})", self.backend_name), - AxisOp::Reshape(at, from, to) => { - write!( - f, - "{}Reshape({at}, [{}], [{}])", - self.backend_name, - from.iter().join(","), - to.iter().join(",") - ) - } - } - } -} - -impl PartialEq for GpuAxisOp { - fn eq(&self, other: &Self) -> bool { - self.backend_name == other.backend_name && self.inner == other.inner - } -} - -impl Eq for GpuAxisOp {} - -impl std::hash::Hash for GpuAxisOp { - fn hash(&self, state: &mut H) { - self.backend_name.hash(state); - self.inner.hash(state); + Self::simplify_axis_op(op, dims) } } impl Op for GpuAxisOp { fn name(&self) -> StaticName { - format!("{}{}", self.backend_name, self.inner.name()).into() + format!("Gpu{}", self.inner.name()).into() } fn info(&self) -> TractResult> { @@ -135,8 +88,6 @@ impl EvalOp for GpuAxisOp { let simplified = Self::simplify_axis_op( self.inner.clone(), &shape.iter().map(|s| s.into()).collect_vec(), - self.backend_name, - self.dispatch, ); let new_shape = match &simplified.inner { @@ -155,7 +106,8 @@ impl EvalOp for GpuAxisOp { // Compute permuted input strides let permuted_strides: TVec = permutation.iter().map(|&i| input.strides()[i]).collect(); - (self.dispatch)( + let ctx = crate::device::get_context()?; + ctx.copy_nd( input, 0, &permuted_strides, @@ -188,7 +140,8 @@ impl EvalOp for GpuAxisOp { &new_shape, )?; let flat_len = input.len(); - (self.dispatch)(input, 0, &[1], &output, 0, &[flat_len], &[1])?; + let ctx = crate::device::get_context()?; + ctx.copy_nd(input, 0, &[1], &output, 0, &[flat_len], &[1])?; Ok(tvec!(output.into_tensor().into_tvalue())) } } @@ -226,7 +179,7 @@ impl TypedOp for GpuAxisOp { } else { self.inner.clone() }; - let op = GpuAxisOp { inner, backend_name: self.backend_name, dispatch: self.dispatch }; + let op = GpuAxisOp { inner }; target.wire_node(&node.name, op, &[mapping[&node.inputs[0]]]) } diff --git a/gpu/src/ops/concat.rs b/gpu/src/ops/concat.rs index d0d618659e..11bfecb2c0 100644 --- a/gpu/src/ops/concat.rs +++ b/gpu/src/ops/concat.rs @@ -1,17 +1,14 @@ use crate::tensor::DeviceTensorExt; -use crate::utils::DispatchCopyNdFn; use tract_core::internal::*; -#[derive(Clone)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct GpuConcat { pub axis: usize, - pub backend_name: &'static str, - pub dispatch: DispatchCopyNdFn, } impl GpuConcat { - pub fn new(axis: usize, backend_name: &'static str, dispatch: DispatchCopyNdFn) -> Self { - Self { axis, backend_name, dispatch } + pub fn new(axis: usize) -> Self { + Self { axis } } pub fn offsets(&self, inputs: &[&TypedFact]) -> TractResult> { @@ -25,30 +22,9 @@ impl GpuConcat { } } -impl std::fmt::Debug for GpuConcat { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}Concat(axis={})", self.backend_name, self.axis) - } -} - -impl PartialEq for GpuConcat { - fn eq(&self, other: &Self) -> bool { - self.backend_name == other.backend_name && self.axis == other.axis - } -} - -impl Eq for GpuConcat {} - -impl std::hash::Hash for GpuConcat { - fn hash(&self, state: &mut H) { - self.backend_name.hash(state); - self.axis.hash(state); - } -} - impl Op for GpuConcat { fn name(&self) -> StaticName { - format!("{}Concat", self.backend_name).into() + "GpuConcat".into() } fn info(&self) -> TractResult> { @@ -81,6 +57,7 @@ impl EvalOp for GpuConcat { &output_shape, )?; + let ctx = crate::device::get_context()?; let mut cursor = 0usize; for input in &inputs { let slice_len = input.shape()[self.axis]; @@ -93,7 +70,7 @@ impl EvalOp for GpuConcat { let dst_offset = cursor * output.strides()[self.axis] as usize * output.datum_type().size_of(); - (self.dispatch)( + ctx.copy_nd( input, 0, input.strides(), diff --git a/gpu/src/ops/copy_based.rs b/gpu/src/ops/copy_based.rs new file mode 100644 index 0000000000..92170f38a0 --- /dev/null +++ b/gpu/src/ops/copy_based.rs @@ -0,0 +1,42 @@ +//! Translators for ops that only need the generic copy_nd dispatch. +//! These are fully backend-agnostic and can be constructed without +//! any backend-specific arguments. + +use tract_core::internal::*; +use tract_core::ops::array::{MultiBroadcastTo, Slice, TypedConcat}; +use tract_pulse_opl::ops::{Delay, PulsePad}; +use tract_transformers::ops::dyn_kv_cache::DynKeyValueCache; + +/// Try to translate a node into a copy-based GPU op. +/// Returns `Some(gpu_op)` if the node is one of the 7 copy-based ops. +pub fn try_make_copy_based_op( + source: &TypedModel, + node: &TypedNode, +) -> TractResult>> { + if let Some(op) = node.op_as::() { + return Ok(Some(Box::new(super::broadcast::GpuMultiBroadcastTo::new(op.shape.clone())))); + } + if let Some(op) = node.op_as::() { + let in_fact = source.node_input_facts(node.id)?[0]; + return Ok(Some(Box::new(super::change_axes::GpuAxisOp::from_tract_core_with_fact( + op.clone(), + in_fact, + )))); + } + if let Some(op) = node.op_as::() { + return Ok(Some(Box::new(super::slice::GpuSlice::new(op.clone())))); + } + if let Some(op) = node.op_as::() { + return Ok(Some(Box::new(super::concat::GpuConcat::new(op.axis)))); + } + if let Some(op) = node.op_as::() { + return Ok(Some(Box::new(super::dyn_kv_cache::GpuDynKVCache::from_tract_transformers(op)))); + } + if let Some(op) = node.op_as::() { + return Ok(Some(Box::new(super::pulse::GpuDelay::new(op)))); + } + if let Some(op) = node.op_as::() { + return Ok(Some(Box::new(super::pulse::GpuPulsePad::new(op)?))); + } + Ok(None) +} diff --git a/gpu/src/ops/dyn_kv_cache.rs b/gpu/src/ops/dyn_kv_cache.rs index aa66e00553..cc5380f110 100644 --- a/gpu/src/ops/dyn_kv_cache.rs +++ b/gpu/src/ops/dyn_kv_cache.rs @@ -1,6 +1,5 @@ use crate::fact::DeviceTypedFactExt; use crate::tensor::{DeviceTensor, DeviceTensorExt, IntoDevice}; -use crate::utils::DispatchCopyNdFn; use derive_new::new; use tract_core::internal::*; use tract_core::ops::OpStateFreeze; @@ -12,7 +11,6 @@ pub struct GpuDynKVCacheState { name: String, axis: usize, past_sequence_fact: TypedFact, - dispatch: DispatchCopyNdFn, kv_cache: Option, } @@ -71,7 +69,6 @@ impl OpState for GpuDynKVCacheState { let gpu_op = op.downcast_ref::().ok_or_else(|| format_err!("Wrong Op type"))?; let axis = gpu_op.axis; - let dispatch = gpu_op.dispatch; let inputs = op_inputs.iter().map(|it| it.to_device_tensor()).collect::>>()?; @@ -85,6 +82,7 @@ impl OpState for GpuDynKVCacheState { )?; // Concat inputs into output + let ctx = crate::device::get_context()?; let mut cursor = 0usize; for input in &inputs { let slice_len = input.shape()[axis]; @@ -93,7 +91,7 @@ impl OpState for GpuDynKVCacheState { } let dst_offset = cursor * output.strides()[axis] as usize * output.datum_type().size_of(); - dispatch( + ctx.copy_nd( input, 0, input.strides(), @@ -128,7 +126,6 @@ pub struct FrozenGpuDynKVCacheState { name: String, axis: usize, past_sequence_fact: TypedFact, - dispatch: DispatchCopyNdFn, kv_cache: Option, } @@ -139,7 +136,6 @@ impl OpStateFreeze for GpuDynKVCacheState { name: self.name.clone(), axis: self.axis, past_sequence_fact: self.past_sequence_fact.clone(), - dispatch: self.dispatch, kv_cache: self.kv_cache.clone().map(|t| t.to_device_tensor().cloned().unwrap()), }) } @@ -152,7 +148,6 @@ impl FrozenOpState for FrozenGpuDynKVCacheState { name: self.name.clone(), axis: self.axis, past_sequence_fact: self.past_sequence_fact.clone(), - dispatch: self.dispatch, kv_cache: self.kv_cache.clone().map(|t| t.into_tensor().into_tvalue()), }) } @@ -164,37 +159,28 @@ pub struct GpuDynKVCache { pub past_sequence_fact: TypedFact, pub input_sequence_fact: TypedFact, pub axis: usize, - pub backend_name: &'static str, - pub dispatch: DispatchCopyNdFn, } impl GpuDynKVCache { - pub fn from_tract_transformers( - op: &DynKeyValueCache, - backend_name: &'static str, - dispatch: DispatchCopyNdFn, - ) -> Self { + pub fn from_tract_transformers(op: &DynKeyValueCache) -> Self { Self { name: op.name.clone(), axis: op.axis, past_sequence_fact: op.past_sequence_fact.clone(), input_sequence_fact: op.input_sequence_fact.clone(), - backend_name, - dispatch, } } } impl std::fmt::Debug for GpuDynKVCache { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}DynKVCache({}, axis={})", self.backend_name, self.name, self.axis) + write!(f, "GpuDynKVCache({}, axis={})", self.name, self.axis) } } impl PartialEq for GpuDynKVCache { fn eq(&self, other: &Self) -> bool { - self.backend_name == other.backend_name - && self.name == other.name + self.name == other.name && self.axis == other.axis && self.past_sequence_fact == other.past_sequence_fact && self.input_sequence_fact == other.input_sequence_fact @@ -205,7 +191,6 @@ impl Eq for GpuDynKVCache {} impl std::hash::Hash for GpuDynKVCache { fn hash(&self, state: &mut H) { - self.backend_name.hash(state); self.name.hash(state); self.axis.hash(state); } @@ -213,7 +198,7 @@ impl std::hash::Hash for GpuDynKVCache { impl Op for GpuDynKVCache { fn name(&self) -> StaticName { - format!("{}DynKVCache", self.backend_name).into() + "GpuDynKVCache".into() } fn info(&self) -> TractResult> { @@ -234,7 +219,6 @@ impl EvalOp for GpuDynKVCache { self.name.clone(), self.axis, self.past_sequence_fact.clone(), - self.dispatch, None, )))) } diff --git a/gpu/src/ops/gelu_approximate.rs b/gpu/src/ops/gelu_approximate.rs new file mode 100644 index 0000000000..cda3369e25 --- /dev/null +++ b/gpu/src/ops/gelu_approximate.rs @@ -0,0 +1,89 @@ +use crate::tensor::DeviceTensorExt; +use tract_core::internal::*; + +use crate::tensor::DeviceTensor; + +pub type DispatchGeluApproximateFn = fn(bool, &DeviceTensor, &DeviceTensor) -> TractResult<()>; + +#[derive(Clone)] +pub struct GpuGeluApproximate { + pub fast_impl: bool, + pub backend_name: &'static str, + pub dispatch: DispatchGeluApproximateFn, +} + +impl GpuGeluApproximate { + pub fn new( + fast_impl: bool, + backend_name: &'static str, + dispatch: DispatchGeluApproximateFn, + ) -> Self { + Self { fast_impl, backend_name, dispatch } + } +} + +impl std::fmt::Debug for GpuGeluApproximate { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}GeluApproximate(fast_impl: {})", self.backend_name, self.fast_impl) + } +} + +impl PartialEq for GpuGeluApproximate { + fn eq(&self, other: &Self) -> bool { + self.backend_name == other.backend_name && self.fast_impl == other.fast_impl + } +} + +impl Eq for GpuGeluApproximate {} + +impl std::hash::Hash for GpuGeluApproximate { + fn hash(&self, state: &mut H) { + self.backend_name.hash(state); + self.fast_impl.hash(state); + } +} + +impl Op for GpuGeluApproximate { + fn name(&self) -> StaticName { + format!("{}GeluApproximate", self.backend_name).into() + } + + op_as_typed_op!(); +} + +impl EvalOp for GpuGeluApproximate { + fn is_stateless(&self) -> bool { + true + } + + fn eval_with_session( + &self, + node_id: usize, + session: &TurnState, + inputs: TVec, + ) -> TractResult> { + let input_value = args_1!(inputs); + let input = input_value.to_device_tensor()?; + let output = crate::session_handler::make_tensor_for_node( + session, + node_id, + input.datum_type(), + input.shape(), + )?; + (self.dispatch)(self.fast_impl, input, &output)?; + Ok(tvec!(output.into_tensor().into_tvalue())) + } +} + +impl TypedOp for GpuGeluApproximate { + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + crate::utils::facts_to_device_facts(inputs, |facts| { + let dt = facts[0].datum_type; + let fact = dt.fact(facts[0].shape.clone()); + Ok(tvec!(fact)) + }) + .with_context(|| format!("Error while computing facts for {:?}", self.name())) + } + + as_op!(); +} diff --git a/gpu/src/ops/iff.rs b/gpu/src/ops/iff.rs new file mode 100644 index 0000000000..aad75dff68 --- /dev/null +++ b/gpu/src/ops/iff.rs @@ -0,0 +1,145 @@ +use crate::tensor::{DeviceTensor, DeviceTensorExt}; +use tract_core::broadcast::multi_broadcast; +use tract_core::internal::*; + +static IFF_MAX_RANK: usize = 5; + +/// Dispatch function for the iff (select) kernel. +/// Args: cond, then, else tensors with pre-computed broadcast strides, +/// output tensor, output shape and strides. All strides are padded to IFF_MAX_RANK. +pub type DispatchIffFn = fn( + cond: &DeviceTensor, + then_value: &DeviceTensor, + else_value: &DeviceTensor, + cond_strides: &[isize], + then_strides: &[isize], + else_strides: &[isize], + output: &DeviceTensor, + output_shape: &[usize], + output_strides: &[isize], +) -> TractResult<()>; + +#[derive(Clone)] +pub struct GpuIff { + pub backend_name: &'static str, + pub dispatch: DispatchIffFn, +} + +impl std::fmt::Debug for GpuIff { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}Iff", self.backend_name) + } +} + +impl PartialEq for GpuIff { + fn eq(&self, other: &Self) -> bool { + self.backend_name == other.backend_name + } +} + +impl Eq for GpuIff {} + +impl std::hash::Hash for GpuIff { + fn hash(&self, state: &mut H) { + self.backend_name.hash(state); + } +} + +impl Op for GpuIff { + fn name(&self) -> StaticName { + format!("{}Iff", self.backend_name).into() + } + + op_as_typed_op!(); +} + +impl EvalOp for GpuIff { + fn is_stateless(&self) -> bool { + true + } + + fn eval_with_session( + &self, + node_id: usize, + session: &TurnState, + inputs: TVec, + ) -> TractResult> { + let (cond_val, then_val, else_val) = args_3!(inputs); + + let cond = cond_val.to_device_tensor()?; + let then_t = then_val.to_device_tensor()?; + let else_t = else_val.to_device_tensor()?; + ensure!(cond.rank() == then_t.rank()); + ensure!(cond.rank() == else_t.rank()); + ensure!(then_t.datum_type() == else_t.datum_type()); + + let out_shape = multi_broadcast(&[cond.shape(), then_t.shape(), else_t.shape()]) + .context("No broadcasting solution found")?; + let out_dt = then_t.datum_type(); + let output = + crate::session_handler::make_tensor_for_node(session, node_id, out_dt, &out_shape)?; + + if output.len() > 0 { + let rank = cond.rank(); + ensure!(rank <= IFF_MAX_RANK); + let rank_pad = IFF_MAX_RANK - rank; + + let mut padded_cond_strides = [0isize; IFF_MAX_RANK]; + let mut padded_then_strides = [0isize; IFF_MAX_RANK]; + let mut padded_else_strides = [0isize; IFF_MAX_RANK]; + let mut padded_out_shape = [1usize; IFF_MAX_RANK]; + let mut padded_out_strides = [0isize; IFF_MAX_RANK]; + + for axis in 0..rank { + padded_out_shape[rank_pad + axis] = output.shape()[axis]; + padded_out_strides[rank_pad + axis] = output.strides()[axis]; + padded_cond_strides[rank_pad + axis] = if cond.shape()[axis] < output.shape()[axis] + { + 0 + } else { + cond.strides()[axis] + }; + padded_then_strides[rank_pad + axis] = + if then_t.shape()[axis] < output.shape()[axis] { + 0 + } else { + then_t.strides()[axis] + }; + padded_else_strides[rank_pad + axis] = + if else_t.shape()[axis] < output.shape()[axis] { + 0 + } else { + else_t.strides()[axis] + }; + } + + (self.dispatch)( + cond, + then_t, + else_t, + &padded_cond_strides, + &padded_then_strides, + &padded_else_strides, + &output, + &padded_out_shape, + &padded_out_strides, + ) + .with_context(|| "Error while dispatching eval for Iff")?; + } + Ok(tvec!(output.into_tensor().into_tvalue())) + } +} + +impl TypedOp for GpuIff { + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + crate::utils::facts_to_device_facts(inputs, |inputs| { + let out_shape = + multi_broadcast(&[&*inputs[0].shape, &*inputs[1].shape, &*inputs[2].shape]) + .context("No broadcasting solution found")?; + let out_dt = inputs[1].datum_type; + Ok(tvec!(out_dt.fact(out_shape))) + }) + } + + as_op!(); +} diff --git a/gpu/src/ops/leaky_relu.rs b/gpu/src/ops/leaky_relu.rs new file mode 100644 index 0000000000..73be60e37b --- /dev/null +++ b/gpu/src/ops/leaky_relu.rs @@ -0,0 +1,85 @@ +use crate::tensor::DeviceTensorExt; +use tract_core::internal::*; + +use crate::tensor::DeviceTensor; + +pub type DispatchLeakyReluFn = fn(f32, &DeviceTensor, &DeviceTensor) -> TractResult<()>; + +#[derive(Clone)] +pub struct GpuLeakyRelu { + pub alpha: f32, + pub backend_name: &'static str, + pub dispatch: DispatchLeakyReluFn, +} + +impl GpuLeakyRelu { + pub fn new(alpha: f32, backend_name: &'static str, dispatch: DispatchLeakyReluFn) -> Self { + Self { alpha, backend_name, dispatch } + } +} + +impl std::fmt::Debug for GpuLeakyRelu { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}LeakyRelu(alpha: {})", self.backend_name, self.alpha) + } +} + +impl PartialEq for GpuLeakyRelu { + fn eq(&self, other: &Self) -> bool { + self.backend_name == other.backend_name && self.alpha == other.alpha + } +} + +impl Eq for GpuLeakyRelu {} + +impl std::hash::Hash for GpuLeakyRelu { + fn hash(&self, state: &mut H) { + self.backend_name.hash(state); + self.alpha.to_bits().hash(state); + } +} + +impl Op for GpuLeakyRelu { + fn name(&self) -> StaticName { + format!("{}LeakyRelu", self.backend_name).into() + } + + op_as_typed_op!(); +} + +impl EvalOp for GpuLeakyRelu { + fn is_stateless(&self) -> bool { + true + } + + fn eval_with_session( + &self, + node_id: usize, + session: &TurnState, + inputs: TVec, + ) -> TractResult> { + let input_value = args_1!(inputs); + let input = input_value.to_device_tensor()?; + let output = crate::session_handler::make_tensor_for_node( + session, + node_id, + input.datum_type(), + input.shape(), + )?; + (self.dispatch)(self.alpha, input, &output)?; + Ok(tvec!(output.into_tensor().into_tvalue())) + } +} + +impl TypedOp for GpuLeakyRelu { + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + crate::utils::facts_to_device_facts(inputs, |facts| { + let dt = facts[0].datum_type; + let fact = dt.fact(facts[0].shape.clone()); + Ok(tvec!(fact)) + }) + .with_context(|| format!("Error while computing facts for {:?}", self.name())) + } + + as_op!(); +} diff --git a/gpu/src/ops/mod.rs b/gpu/src/ops/mod.rs index 668c70bd29..95e84edf89 100644 --- a/gpu/src/ops/mod.rs +++ b/gpu/src/ops/mod.rs @@ -1,9 +1,19 @@ +pub mod apply_rope; pub mod binary; pub mod broadcast; +pub mod cast; pub mod change_axes; pub mod concat; +pub mod copy_based; pub mod dyn_kv_cache; pub mod element_wise; +pub mod gelu_approximate; +pub mod iff; +pub mod leaky_relu; pub mod pulse; pub mod reduce; +pub mod rms_norm; +pub mod rotate_half; +pub mod scaled_masked_softmax; pub mod slice; +pub mod softmax; diff --git a/gpu/src/ops/pulse.rs b/gpu/src/ops/pulse.rs index 5c858eee3f..ce130a87ba 100644 --- a/gpu/src/ops/pulse.rs +++ b/gpu/src/ops/pulse.rs @@ -1,9 +1,7 @@ #![allow(unpredictable_function_pointer_comparisons)] +use crate::device::{DeviceContext, get_context}; use crate::session_handler::make_tensor_for_node; use crate::tensor::{DeviceTensor, DeviceTensorExt, IntoDevice}; -use crate::utils::{ - DispatchCopyNdFn, dispatch_assign_slice, dispatch_copy_with_origins, dispatch_flat_copy, -}; use std::ops::Range; use tract_core::internal::*; use tract_core::ops::array::PadMode; @@ -12,29 +10,20 @@ use tract_pulse_opl::ops::{Delay, PulsePad}; // ─── GpuDelay ──────────────────────────────────────────────────────────────── -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct GpuDelay { pub inner: Delay, - pub backend_name: &'static str, - pub dispatch: DispatchCopyNdFn, } impl GpuDelay { - pub fn new(inner: &Delay, backend_name: &'static str, dispatch: DispatchCopyNdFn) -> Self { - Self { inner: inner.clone(), backend_name, dispatch } - } -} - -impl std::hash::Hash for GpuDelay { - fn hash(&self, state: &mut H) { - self.backend_name.hash(state); - self.inner.hash(state); + pub fn new(inner: &Delay) -> Self { + Self { inner: inner.clone() } } } impl Op for GpuDelay { fn name(&self) -> StaticName { - format!("{}Delay", self.backend_name).into() + "GpuDelay".into() } fn info(&self) -> TractResult> { @@ -50,7 +39,7 @@ impl EvalOp for GpuDelay { } fn state(&self, _session: &TurnState, node_id: usize) -> TractResult>> { - Ok(Some(Box::new(GpuDelayState { node_id, dispatch: self.dispatch, buffer: None }))) + Ok(Some(Box::new(GpuDelayState { node_id, buffer: None }))) } } @@ -70,15 +59,14 @@ impl TypedOp for GpuDelay { #[derive(Debug, Clone)] pub struct GpuDelayState { pub node_id: usize, - pub dispatch: DispatchCopyNdFn, pub buffer: Option, } impl GpuDelayState { unsafe fn apply_delay_unchecked( &mut self, + ctx: &dyn DeviceContext, op: &Delay, - dispatch: DispatchCopyNdFn, input: &DeviceTensor, output: &mut DeviceTensor, ) -> TractResult<()> { @@ -91,21 +79,13 @@ impl GpuDelayState { let from_buffer = output_pulse.saturating_sub(from_input); // Copy from buffer to output - dispatch_assign_slice(dispatch, output, 0..from_buffer, buffer, 0..from_buffer, op.axis)?; + ctx.assign_slice(output, 0..from_buffer, buffer, 0..from_buffer, op.axis)?; // Copy from input to output - dispatch_assign_slice( - dispatch, - output, - from_buffer..output_pulse, - input, - 0..from_input, - op.axis, - )?; + ctx.assign_slice(output, from_buffer..output_pulse, input, 0..from_input, op.axis)?; // Maintain buffer if buffered < input_pulse { - dispatch_assign_slice( - dispatch, + ctx.assign_slice( buffer, 0..buffered, input, @@ -117,10 +97,9 @@ impl GpuDelayState { let dt = input.datum_type(); let shift_bytes = buffer.strides()[op.axis] as usize * dt.size_of() * input_pulse; let remaining = buffer.len() * dt.size_of() - shift_bytes; - dispatch_flat_copy(dispatch, buffer, shift_bytes, buffer, 0, remaining)?; + ctx.flat_copy(buffer, shift_bytes, buffer, 0, remaining)?; // Copy input to end of buffer - dispatch_assign_slice( - dispatch, + ctx.assign_slice( buffer, (buffered - input_pulse)..buffered, input, @@ -148,6 +127,7 @@ impl OpState for GpuDelayState { let mut output_shape: TVec = device_input.shape().into(); output_shape[op.axis] = output_pulse; let dt = device_input.datum_type(); + let ctx = get_context()?; unsafe { if self.buffer.is_none() { let mut shape = device_input.shape().to_owned(); @@ -155,7 +135,7 @@ impl OpState for GpuDelayState { self.buffer = Some(DeviceTensor::uninitialized_dt(dt, &shape)?); }; let mut output = make_tensor_for_node(state, self.node_id, dt, &output_shape)?; - self.apply_delay_unchecked(op, self.dispatch, device_input, &mut output)?; + self.apply_delay_unchecked(&*ctx, op, device_input, &mut output)?; Ok(tvec!(output.into_tensor().into())) } } @@ -169,32 +149,25 @@ trivial_op_state_freeze!(GpuDelayState); pub struct GpuPulsePad { pub op: PulsePad, pub device_cst: Option, - pub backend_name: &'static str, - pub dispatch: DispatchCopyNdFn, } impl GpuPulsePad { - pub fn new( - op: &PulsePad, - backend_name: &'static str, - dispatch: DispatchCopyNdFn, - ) -> TractResult { + pub fn new(op: &PulsePad) -> TractResult { let device_cst = if let PadMode::Constant(c) = &op.mode { Some(c.clone().into_device()?) } else { None }; - Ok(Self { op: op.clone(), device_cst, backend_name, dispatch }) + Ok(Self { op: op.clone(), device_cst }) } } impl std::hash::Hash for GpuPulsePad { fn hash(&self, state: &mut H) { - self.backend_name.hash(state); self.op.hash(state); } } impl Op for GpuPulsePad { fn name(&self) -> StaticName { - format!("{}PulsePad", self.backend_name).into() + "GpuPulsePad".into() } fn info(&self) -> TractResult> { @@ -210,12 +183,7 @@ impl EvalOp for GpuPulsePad { } fn state(&self, _session: &TurnState, node_id: usize) -> TractResult>> { - Ok(Some(Box::new(GpuPulsePadState { - node_id, - current_pos: 0, - last_valid_frame: None, - dispatch: self.dispatch, - }))) + Ok(Some(Box::new(GpuPulsePadState { node_id, current_pos: 0, last_valid_frame: None }))) } } @@ -237,11 +205,10 @@ struct GpuPulsePadState { node_id: usize, current_pos: usize, last_valid_frame: Option, - dispatch: DispatchCopyNdFn, } fn fill_slice_constant( - dispatch: DispatchCopyNdFn, + ctx: &dyn DeviceContext, dst: &mut DeviceTensor, cst: &DeviceTensor, axis: usize, @@ -251,8 +218,7 @@ fn fill_slice_constant( zone_shape[axis] = range.len(); let mut dst_origin = tvec!(0; dst.rank()); dst_origin[axis] = range.start; - dispatch_copy_with_origins( - dispatch, + ctx.copy_with_origins( &zone_shape, dst, &dst_origin, @@ -264,7 +230,7 @@ fn fill_slice_constant( } fn fill_slice_repeating_one_frame( - dispatch: DispatchCopyNdFn, + ctx: &dyn DeviceContext, dst: &mut DeviceTensor, src: &DeviceTensor, axis: usize, @@ -279,8 +245,7 @@ fn fill_slice_repeating_one_frame( src_origin[axis] = src_frame; let mut src_strides: TVec = src.strides().into(); src_strides[axis] = 0; - dispatch_copy_with_origins( - dispatch, + ctx.copy_with_origins( &zone_shape, dst, &dst_origin, @@ -294,7 +259,7 @@ fn fill_slice_repeating_one_frame( impl GpuPulsePadState { fn save_frame( &mut self, - dispatch: DispatchCopyNdFn, + ctx: &dyn DeviceContext, op: &PulsePad, input: &DeviceTensor, frame: usize, @@ -302,7 +267,7 @@ impl GpuPulsePadState { let mut frame_shape: TVec = input.shape().into(); frame_shape[op.axis] = 1; let last_valid_frame = DeviceTensor::uninitialized_dt(input.datum_type(), &frame_shape)?; - dispatch_assign_slice(dispatch, &last_valid_frame, 0..1, input, frame..frame + 1, op.axis)?; + ctx.assign_slice(&last_valid_frame, 0..1, input, frame..frame + 1, op.axis)?; self.last_valid_frame = Some(last_valid_frame); Ok(()) } @@ -313,8 +278,8 @@ impl GpuPulsePadState { gpu_op: &GpuPulsePad, input: &DeviceTensor, ) -> TractResult { + let ctx = get_context()?; let op = &gpu_op.op; - let dispatch = gpu_op.dispatch; let pulse = input.shape()[op.axis]; let pulse_begin = self.current_pos; let pulse_end = self.current_pos + pulse; @@ -328,14 +293,14 @@ impl GpuPulsePadState { && pulse_begin < end_input { let latest_valid_frame = (end_input - pulse_begin).min(pulse) - 1; - self.save_frame(dispatch, op, input, latest_valid_frame)?; + self.save_frame(&*ctx, op, input, latest_valid_frame)?; } // Start with a copy of input let mut output = make_tensor_for_node(session, self.node_id, input.datum_type(), input.shape())?; let flat_len = input.len() * input.datum_type().size_of(); - dispatch_flat_copy(dispatch, input, 0, &output, 0, flat_len)?; + ctx.flat_copy(input, 0, &output, 0, flat_len)?; // Quick return if entirely in valid or invalid range if (pulse_begin >= op.begin_input && pulse_end <= end_input) @@ -349,14 +314,14 @@ impl GpuPulsePadState { let fill_up_to = (op.begin_input - pulse_begin).min(pulse); match &op.mode { PadMode::Constant(_) => fill_slice_constant( - dispatch, + &*ctx, &mut output, gpu_op.device_cst.as_ref().unwrap(), op.axis, 0..fill_up_to, )?, PadMode::Edge => fill_slice_repeating_one_frame( - dispatch, + &*ctx, &mut output, input, op.axis, @@ -366,18 +331,19 @@ impl GpuPulsePadState { _ => unimplemented!(), } } - if pulse_end > end_input && after > 0 { + + if pulse_end > end_input { let fill_from = pulse - (pulse_end - end_input).min(pulse); match &op.mode { PadMode::Constant(_) => fill_slice_constant( - dispatch, + &*ctx, &mut output, gpu_op.device_cst.as_ref().unwrap(), op.axis, fill_from..pulse, )?, PadMode::Edge => fill_slice_repeating_one_frame( - dispatch, + &*ctx, &mut output, self.last_valid_frame.as_ref().unwrap(), op.axis, @@ -387,7 +353,6 @@ impl GpuPulsePadState { _ => unimplemented!(), } } - Ok(output) } } @@ -399,11 +364,12 @@ impl OpState for GpuPulsePadState { op: &dyn Op, inputs: TVec, ) -> TractResult> { - let input = args_1!(inputs).into_tensor(); - let op = op.downcast_ref::().ok_or_else(|| format_err!("Wrong Op type"))?; - let input = input.to_device_tensor()?; - let tensor = self.pad(session, op, input)?; - Ok(tvec!(tensor.into_tensor().into())) + let input = args_1!(inputs); + let gpu_op = + op.downcast_ref::().ok_or_else(|| format_err!("Wrong Op type"))?; + let device_input = input.as_device_tensor().context("Expected a GPU tensor")?; + let output = self.pad(session, gpu_op, device_input)?; + Ok(tvec!(output.into_tensor().into_tvalue())) } } diff --git a/gpu/src/ops/rms_norm.rs b/gpu/src/ops/rms_norm.rs new file mode 100644 index 0000000000..a7fefd51a8 --- /dev/null +++ b/gpu/src/ops/rms_norm.rs @@ -0,0 +1,97 @@ +use crate::tensor::DeviceTensorExt; +use std::sync::Arc; +use tract_core::internal::*; + +use crate::tensor::DeviceTensor; + +pub type DispatchRmsNormFn = fn(&DeviceTensor, usize, &Tensor, &DeviceTensor) -> TractResult<()>; + +#[derive(Clone)] +pub struct GpuRmsNorm { + pub axis: usize, + pub eps: Arc, + pub backend_name: &'static str, + pub dispatch: DispatchRmsNormFn, +} + +impl GpuRmsNorm { + pub fn new( + axis: usize, + eps: Arc, + backend_name: &'static str, + dispatch: DispatchRmsNormFn, + ) -> Self { + Self { axis, eps, backend_name, dispatch } + } +} + +impl std::fmt::Debug for GpuRmsNorm { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}RmsNorm(axis: {:?}, eps: {:?})", self.backend_name, self.axis, self.eps) + } +} + +impl PartialEq for GpuRmsNorm { + fn eq(&self, other: &Self) -> bool { + self.backend_name == other.backend_name && self.axis == other.axis && self.eps == other.eps + } +} + +impl Eq for GpuRmsNorm {} + +impl std::hash::Hash for GpuRmsNorm { + fn hash(&self, state: &mut H) { + self.backend_name.hash(state); + self.axis.hash(state); + self.eps.hash(state); + } +} + +impl Op for GpuRmsNorm { + fn name(&self) -> StaticName { + format!("{}RmsNorm", self.backend_name).into() + } + + fn info(&self) -> TractResult> { + Ok(vec![format!("axis: {:?}, eps: {:?}", self.axis, self.eps)]) + } + + op_as_typed_op!(); +} + +impl EvalOp for GpuRmsNorm { + fn is_stateless(&self) -> bool { + true + } + + fn eval_with_session( + &self, + node_id: usize, + session: &TurnState, + inputs: TVec, + ) -> TractResult> { + let input_value = args_1!(inputs); + let input = input_value.to_device_tensor()?; + let output = crate::session_handler::make_tensor_for_node( + session, + node_id, + input.datum_type(), + input.shape(), + )?; + (self.dispatch)(input, self.axis, &self.eps, &output)?; + Ok(tvec!(output.into_tensor().into_tvalue())) + } +} + +impl TypedOp for GpuRmsNorm { + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + crate::utils::facts_to_device_facts(inputs, |facts| { + let dt = facts[0].datum_type; + let fact = dt.fact(facts[0].shape.clone()); + Ok(tvec!(fact)) + }) + .with_context(|| format!("Error while computing facts for {:?}", self.name())) + } + + as_op!(); +} diff --git a/gpu/src/ops/rotate_half.rs b/gpu/src/ops/rotate_half.rs new file mode 100644 index 0000000000..80fff4290a --- /dev/null +++ b/gpu/src/ops/rotate_half.rs @@ -0,0 +1,83 @@ +use crate::tensor::DeviceTensorExt; +use tract_core::internal::*; + +use crate::tensor::DeviceTensor; + +pub type DispatchRotateHalfFn = fn(&DeviceTensor, &DeviceTensor) -> TractResult<()>; + +#[derive(Clone)] +pub struct GpuRotateHalf { + pub backend_name: &'static str, + pub dispatch: DispatchRotateHalfFn, +} + +impl GpuRotateHalf { + pub fn new(backend_name: &'static str, dispatch: DispatchRotateHalfFn) -> Self { + Self { backend_name, dispatch } + } +} + +impl std::fmt::Debug for GpuRotateHalf { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}RotateHalf", self.backend_name) + } +} + +impl PartialEq for GpuRotateHalf { + fn eq(&self, other: &Self) -> bool { + self.backend_name == other.backend_name + } +} + +impl Eq for GpuRotateHalf {} + +impl std::hash::Hash for GpuRotateHalf { + fn hash(&self, state: &mut H) { + self.backend_name.hash(state); + } +} + +impl Op for GpuRotateHalf { + fn name(&self) -> StaticName { + format!("{}RotateHalf", self.backend_name).into() + } + + op_as_typed_op!(); +} + +impl EvalOp for GpuRotateHalf { + fn is_stateless(&self) -> bool { + true + } + + fn eval_with_session( + &self, + node_id: usize, + session: &TurnState, + inputs: TVec, + ) -> TractResult> { + let input_value = args_1!(inputs); + let input = input_value.to_device_tensor()?; + let output = crate::session_handler::make_tensor_for_node( + session, + node_id, + input.datum_type(), + input.shape(), + )?; + (self.dispatch)(input, &output)?; + Ok(tvec!(output.into_tensor().into_tvalue())) + } +} + +impl TypedOp for GpuRotateHalf { + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + crate::utils::facts_to_device_facts(inputs, |facts| { + let dt = facts[0].datum_type; + let fact = dt.fact(facts[0].shape.clone()); + Ok(tvec!(fact)) + }) + .with_context(|| format!("Error while computing facts for {:?}", self.name())) + } + + as_op!(); +} diff --git a/gpu/src/ops/scaled_masked_softmax.rs b/gpu/src/ops/scaled_masked_softmax.rs new file mode 100644 index 0000000000..91b4c109cf --- /dev/null +++ b/gpu/src/ops/scaled_masked_softmax.rs @@ -0,0 +1,83 @@ +use crate::tensor::{DeviceTensor, DeviceTensorExt}; +use tract_core::internal::*; + +/// A = SOFTMAX(INPUT * SCALE + MASK, AXIS=2) +/// Only input of rank of 3 is supported +pub type DispatchScaledMaskedSoftmaxFn = + fn(&DeviceTensor, &Tensor, &DeviceTensor, &DeviceTensor) -> TractResult<()>; + +#[derive(Clone)] +pub struct GpuScaledMaskedSoftmax { + pub scale: Arc, + pub backend_name: &'static str, + pub dispatch: DispatchScaledMaskedSoftmaxFn, +} + +impl std::fmt::Debug for GpuScaledMaskedSoftmax { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}ScaledMaskedSoftmax", self.backend_name) + } +} + +impl PartialEq for GpuScaledMaskedSoftmax { + fn eq(&self, other: &Self) -> bool { + self.backend_name == other.backend_name && self.scale == other.scale + } +} +impl Eq for GpuScaledMaskedSoftmax {} + +impl std::hash::Hash for GpuScaledMaskedSoftmax { + fn hash(&self, state: &mut H) { + self.backend_name.hash(state); + self.scale.hash(state); + } +} + +impl Op for GpuScaledMaskedSoftmax { + fn name(&self) -> StaticName { + format!("{}ScaledMaskedSoftmax", self.backend_name).into() + } + op_as_typed_op!(); +} + +impl EvalOp for GpuScaledMaskedSoftmax { + fn is_stateless(&self) -> bool { + true + } + + fn eval_with_session( + &self, + node_id: usize, + session: &TurnState, + inputs: TVec, + ) -> TractResult> { + let (input_val, mask_val) = args_2!(inputs); + let input = input_val.to_device_tensor()?; + let mask = mask_val.to_device_tensor()?; + let output = crate::session_handler::make_tensor_for_node( + session, + node_id, + input.datum_type(), + input.shape(), + )?; + (self.dispatch)(input, &self.scale, mask, &output)?; + Ok(tvec!(output.into_tensor().into_tvalue())) + } +} + +impl TypedOp for GpuScaledMaskedSoftmax { + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + crate::utils::facts_to_device_facts(inputs, |facts| { + ensure!(facts.len() == 2); + let dt = facts[0].datum_type; + ensure!(dt == facts[1].datum_type); + ensure!(facts[0].rank() <= 5); + ensure!(facts[0].rank() >= 2); + ensure!(facts[0].rank() == facts[1].rank()); + let fact = dt.fact(facts[0].shape.clone()); + Ok(tvec!(fact)) + }) + .with_context(|| format!("Error while computing facts for {:?}", self.name())) + } + as_op!(); +} diff --git a/gpu/src/ops/slice.rs b/gpu/src/ops/slice.rs index 23fef4a73c..f0cfd76321 100644 --- a/gpu/src/ops/slice.rs +++ b/gpu/src/ops/slice.rs @@ -1,45 +1,22 @@ use crate::tensor::DeviceTensorExt; -use crate::utils::{DispatchCopyNdFn, compute_broadcast_strides}; +use crate::utils::compute_broadcast_strides; use tract_core::internal::*; use tract_core::ops::array::Slice; -#[derive(Clone)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct GpuSlice { pub inner: Slice, - pub backend_name: &'static str, - pub dispatch: DispatchCopyNdFn, } impl GpuSlice { - pub fn new(inner: Slice, backend_name: &'static str, dispatch: DispatchCopyNdFn) -> Self { - Self { inner, backend_name, dispatch } - } -} - -impl std::fmt::Debug for GpuSlice { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}Slice({:?})", self.backend_name, self.inner) - } -} - -impl PartialEq for GpuSlice { - fn eq(&self, other: &Self) -> bool { - self.backend_name == other.backend_name && self.inner == other.inner - } -} - -impl Eq for GpuSlice {} - -impl std::hash::Hash for GpuSlice { - fn hash(&self, state: &mut H) { - self.backend_name.hash(state); - self.inner.hash(state); + pub fn new(inner: Slice) -> Self { + Self { inner } } } impl Op for GpuSlice { fn name(&self) -> StaticName { - format!("{}Slice", self.backend_name).into() + "GpuSlice".into() } fn info(&self) -> TractResult> { @@ -96,7 +73,8 @@ impl EvalOp for GpuSlice { // Slice uses same strides as input (broadcast strides with matching shapes) let broadcast_strides: TVec = compute_broadcast_strides(&o_shape, input_strides)?; - (self.dispatch)( + let ctx = crate::device::get_context()?; + ctx.copy_nd( input, offset, &broadcast_strides, @@ -130,8 +108,6 @@ impl TypedOp for GpuSlice { start: self.inner.start.eval(values), end: self.inner.end.eval(values), }, - backend_name: self.backend_name, - dispatch: self.dispatch, }; let inputs = node.inputs.iter().map(|i| mapping[i]).collect::>(); target.wire_node(&node.name, op, &inputs) diff --git a/gpu/src/ops/softmax.rs b/gpu/src/ops/softmax.rs new file mode 100644 index 0000000000..e8868addca --- /dev/null +++ b/gpu/src/ops/softmax.rs @@ -0,0 +1,141 @@ +use crate::tensor::DeviceTensorExt; +use tract_core::internal::*; +use tract_core::ops::nn as core_ops_nn; + +use crate::tensor::DeviceTensor; + +pub type DispatchSoftmaxFn = fn(&DeviceTensor, usize, &DeviceTensor) -> TractResult<()>; + +#[derive(Clone)] +pub struct GpuSoftmax { + pub axes: TVec, + pub backend_name: &'static str, + pub dispatch: DispatchSoftmaxFn, +} + +impl GpuSoftmax { + pub fn new( + axes: TVec, + backend_name: &'static str, + dispatch: DispatchSoftmaxFn, + ) -> TractResult { + ensure!( + axes.len() == 1, + "Only one axis of softmax is supported by {}Softmax", + backend_name + ); + Ok(Self { axes, backend_name, dispatch }) + } + + pub fn from_tract_core( + core_softmax: &core_ops_nn::Softmax, + backend_name: &'static str, + dispatch: DispatchSoftmaxFn, + ) -> TractResult { + ensure!(core_softmax.quant_output_dt.is_none()); + Self::new(core_softmax.axes.clone(), backend_name, dispatch) + } +} + +impl std::fmt::Debug for GpuSoftmax { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}Softmax(axes: {:?})", self.backend_name, self.axes) + } +} + +impl PartialEq for GpuSoftmax { + fn eq(&self, other: &Self) -> bool { + self.backend_name == other.backend_name && self.axes == other.axes + } +} + +impl Eq for GpuSoftmax {} + +impl std::hash::Hash for GpuSoftmax { + fn hash(&self, state: &mut H) { + self.backend_name.hash(state); + self.axes.hash(state); + } +} + +impl Op for GpuSoftmax { + fn name(&self) -> StaticName { + format!("{}Softmax", self.backend_name).into() + } + + fn info(&self) -> TractResult> { + Ok(vec![format!("axes: {:?}", self.axes)]) + } + + op_as_typed_op!(); +} + +impl EvalOp for GpuSoftmax { + fn is_stateless(&self) -> bool { + true + } + + fn eval_with_session( + &self, + node_id: usize, + session: &TurnState, + inputs: TVec, + ) -> TractResult> { + let input_value = args_1!(inputs); + let input = input_value.to_device_tensor()?; + let output = crate::session_handler::make_tensor_for_node( + session, + node_id, + input.datum_type(), + input.shape(), + )?; + (self.dispatch)(input, self.axes[0], &output)?; + Ok(tvec!(output.into_tensor().into_tvalue())) + } +} + +impl TypedOp for GpuSoftmax { + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + crate::utils::facts_to_device_facts(inputs, |facts| { + let dt = facts[0].datum_type; + let fact = dt.fact(facts[0].shape.clone()); + Ok(tvec!(fact)) + }) + .with_context(|| format!("Error while computing facts for {:?}", self.name())) + } + + fn axes_mapping( + &self, + inputs: &[&TypedFact], + outputs: &[&TypedFact], + ) -> TractResult { + AxesMapping::natural(inputs, outputs) + } + + fn change_axes( + &self, + model: &TypedModel, + node: &TypedNode, + _io: InOut, + change: &AxisOp, + ) -> TractResult> { + let axes: Option> = + self.axes.iter().map(|it| change.transform_axis(*it)).collect(); + if let Some(axes) = axes { + Ok(Some(AxisChangeConsequence::new( + model, + node, + Some(Box::new(GpuSoftmax { + axes, + backend_name: self.backend_name, + dispatch: self.dispatch, + })), + change, + ))) + } else { + Ok(None) + } + } + + as_op!(); +} diff --git a/gpu/src/utils.rs b/gpu/src/utils.rs index af5008d256..c2cf758f9c 100644 --- a/gpu/src/utils.rs +++ b/gpu/src/utils.rs @@ -193,78 +193,6 @@ pub fn reshape_to_rank_3(shape: &[usize], axis: usize) -> TVec { tvec![dim_axis_0, dim_axis_1, dim_axis_2] } -/// Dispatch function for strided copy_nd kernels. All array ops (broadcast, -/// slice, concat, permute_axes) ultimately call a copy_nd kernel with this -/// signature. The backend derives the kernel name from output rank + dtype. -/// Both offsets are in bytes. -pub type DispatchCopyNdFn = fn( - input: &crate::tensor::DeviceTensor, - input_offset: usize, - input_strides: &[isize], - output: &crate::tensor::DeviceTensor, - output_offset: usize, - output_shape: &[usize], - output_strides: &[isize], -) -> TractResult<()>; - -/// Copy a slice along `axis` from `src[src_range]` into `dst[dst_range]`. -/// Both ranges are along the given axis; other dimensions are copied fully. -pub fn dispatch_assign_slice( - dispatch: DispatchCopyNdFn, - dst: &crate::tensor::DeviceTensor, - dst_range: std::ops::Range, - src: &crate::tensor::DeviceTensor, - src_range: std::ops::Range, - axis: usize, -) -> TractResult<()> { - let mut zone_shape: TVec = src.shape().into(); - zone_shape[axis] = src_range.len(); - if zone_shape.iter().product::() == 0 { - return Ok(()); - } - let src_offset = src_range.start * src.strides()[axis] as usize * src.datum_type().size_of(); - let dst_offset = dst_range.start * dst.strides()[axis] as usize * dst.datum_type().size_of(); - dispatch(src, src_offset, src.strides(), dst, dst_offset, &zone_shape, dst.strides()) -} - -/// Copy from `src` into `dst` with given origins and strides. -/// Origins are element indices per dimension, converted to byte offsets internally. -pub fn dispatch_copy_with_origins( - dispatch: DispatchCopyNdFn, - zone_shape: &[usize], - dst: &crate::tensor::DeviceTensor, - dst_origin: &[usize], - dst_strides: &[isize], - src: &crate::tensor::DeviceTensor, - src_origin: &[usize], - src_strides: &[isize], -) -> TractResult<()> { - if zone_shape.iter().product::() == 0 { - return Ok(()); - } - let dt_size = src.datum_type().size_of(); - let src_offset: usize = - src_origin.iter().zip(src_strides).map(|(o, s)| o * *s as usize).sum::() * dt_size; - let dst_offset: usize = - dst_origin.iter().zip(dst_strides).map(|(o, s)| o * *s as usize).sum::() * dt_size; - dispatch(src, src_offset, src_strides, dst, dst_offset, zone_shape, dst_strides) -} - -/// Flat memcpy of `len` bytes from `src` at `src_offset` to `dst` at `dst_offset`. -pub fn dispatch_flat_copy( - dispatch: DispatchCopyNdFn, - src: &crate::tensor::DeviceTensor, - src_byte_offset: usize, - dst: &crate::tensor::DeviceTensor, - dst_byte_offset: usize, - byte_len: usize, -) -> TractResult<()> { - if byte_len == 0 { - return Ok(()); - } - dispatch(src, src_byte_offset, &[1], dst, dst_byte_offset, &[byte_len], &[1]) -} - pub fn check_strides_validity(shape: TVec, strides: TVec) -> TractResult<()> { let mut zipped_shape_strides: Vec<_> = shape.into_iter().zip(strides).collect(); zipped_shape_strides.sort_by_key(|&(_, stride)| stride); diff --git a/metal/Cargo.toml b/metal/Cargo.toml index e0f090f59f..17ffd1b341 100644 --- a/metal/Cargo.toml +++ b/metal/Cargo.toml @@ -27,6 +27,7 @@ metal.workspace = true objc = { version = "0.2.7" } num-traits.workspace = true tract-core.workspace = true +tract-pulse-opl.workspace = true tract-transformers.workspace = true tract-gpu.workspace = true diff --git a/metal/src/context.rs b/metal/src/context.rs index 8b25ed4912..a2cbbcc98b 100644 --- a/metal/src/context.rs +++ b/metal/src/context.rs @@ -258,6 +258,27 @@ impl DeviceContext for MetalContext { bail!("Only BlockQuant Tensor allocation supported for now") } } + + fn copy_nd( + &self, + input: &DeviceTensor, + input_offset: usize, + input_strides: &[isize], + output: &DeviceTensor, + output_offset: usize, + output_shape: &[usize], + output_strides: &[isize], + ) -> TractResult<()> { + crate::kernels::array::metal_copy_nd_dispatch( + input, + input_offset, + input_strides, + output, + output_offset, + output_shape, + output_strides, + ) + } } #[derive(Debug)] diff --git a/metal/src/kernels/array/cast.rs b/metal/src/kernels/array/cast.rs index 93eef6567e..c04ccd0c16 100644 --- a/metal/src/kernels/array/cast.rs +++ b/metal/src/kernels/array/cast.rs @@ -89,3 +89,11 @@ impl Cast { Ok(()) } } + +pub fn metal_cast_dispatch(input: &DeviceTensor, output: &DeviceTensor) -> TractResult<()> { + crate::with_metal_stream(|stream| Cast.dispatch_eval(stream, input, output)) +} + +crate::register_metal_op!(tract_core::ops::cast::Cast, |_source, _node, op| { + Ok(crate::transform::metal_cast_new(op.to).map(|c| Box::new(c) as _)) +}); diff --git a/metal/src/kernels/array/mod.rs b/metal/src/kernels/array/mod.rs index 7a2f35c4df..3d2690fc5d 100644 --- a/metal/src/kernels/array/mod.rs +++ b/metal/src/kernels/array/mod.rs @@ -4,9 +4,11 @@ mod dispatch; mod rotate_half; pub use cast::Cast; +pub use cast::metal_cast_dispatch; pub use copy::Memcpy; pub use dispatch::metal_copy_nd_dispatch; pub use rotate_half::RotateHalf; +pub use rotate_half::metal_rotate_half_dispatch; pub fn all_functions() -> Vec { use std::collections::HashSet; diff --git a/metal/src/kernels/array/rotate_half.rs b/metal/src/kernels/array/rotate_half.rs index 726ec4a100..d25ea78165 100644 --- a/metal/src/kernels/array/rotate_half.rs +++ b/metal/src/kernels/array/rotate_half.rs @@ -80,6 +80,18 @@ impl RotateHalf { } } +pub fn metal_rotate_half_dispatch(input: &DeviceTensor, output: &DeviceTensor) -> TractResult<()> { + crate::with_metal_stream(|stream| RotateHalf.dispatch_eval(stream, input, output)) +} + +crate::register_metal_op!(tract_transformers::ops::apply_rope::RotateHalf, |source, node, _op| { + rule_if!(RotateHalf::is_supported_dt(source.node_input_facts(node.id)?[0].datum_type)); + Ok(Some(Box::new(tract_gpu::ops::rotate_half::GpuRotateHalf::new( + "Metal", + metal_rotate_half_dispatch, + )))) +}); + #[cfg(test)] mod tests { use crate::utils::with_borrowed_metal_stream; diff --git a/metal/src/kernels/bin_ops.metal b/metal/src/kernels/bin_ops.metal index 47cc194672..2d6ebdcfc2 100644 --- a/metal/src/kernels/bin_ops.metal +++ b/metal/src/kernels/bin_ops.metal @@ -311,3 +311,56 @@ INSTANTIATE_BIN_OP(and, And, bool, bool, bool) INSTANTIATE_BIN_OP(or, Or, bool, bool, bool) INSTANTIATE_1ROW_BIN_OP() + +// --- Iff (select) kernel --- + +template +[[kernel]] void iff_generic( + device const bool *cond [[buffer(0)]], + device const T *then_values [[buffer(1)]], + device const T *else_values [[buffer(2)]], + device T *out [[buffer(3)]], + constant const size_t *out_shape [[buffer(4)]], + constant const size_t *cond_strides [[buffer(5)]], + constant const size_t *then_strides [[buffer(6)]], + constant const size_t *else_strides [[buffer(7)]], + constant const size_t *out_strides [[buffer(8)]], + uint tpig [[thread_position_in_grid]]) +{ + size_t total = out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3] * out_shape[4]; + if (tpig >= total) return; + + size_t tmp = tpig; + size_t i4 = tmp % out_shape[4]; tmp /= out_shape[4]; + size_t i3 = tmp % out_shape[3]; tmp /= out_shape[3]; + size_t i2 = tmp % out_shape[2]; tmp /= out_shape[2]; + size_t i1 = tmp % out_shape[1]; tmp /= out_shape[1]; + size_t i0 = tmp; + + size_t icond = i0 * cond_strides[0] + i1 * cond_strides[1] + i2 * cond_strides[2] + + i3 * cond_strides[3] + i4 * cond_strides[4]; + bool pick = cond[icond]; + + size_t offset = i0 * (pick ? then_strides[0] : else_strides[0]) + + i1 * (pick ? then_strides[1] : else_strides[1]) + + i2 * (pick ? then_strides[2] : else_strides[2]) + + i3 * (pick ? then_strides[3] : else_strides[3]) + + i4 * (pick ? then_strides[4] : else_strides[4]); + + size_t io = i0 * out_strides[0] + i1 * out_strides[1] + i2 * out_strides[2] + + i3 * out_strides[3] + i4 * out_strides[4]; + + out[io] = (pick ? then_values : else_values)[offset]; +} + +#define INSTANTIATE_IFF(tname, type) \ + template [[host_name("bin_ops::iff_generic_" #tname)]] [[kernel]] \ + void iff_generic( \ + device const bool*, device const type*, device const type*, device type*, \ + constant const size_t*, constant const size_t*, constant const size_t*, \ + constant const size_t*, constant const size_t*, uint); + +INSTANTIATE_IFF(u8, uint8_t) +INSTANTIATE_IFF(u16, uint16_t) +INSTANTIATE_IFF(u32, uint32_t) +INSTANTIATE_IFF(u64, uint64_t) diff --git a/metal/src/kernels/bin_ops.rs b/metal/src/kernels/bin_ops.rs index 3395f9be38..ac73be298f 100644 --- a/metal/src/kernels/bin_ops.rs +++ b/metal/src/kernels/bin_ops.rs @@ -32,6 +32,11 @@ pub fn all_functions() -> Vec { }) }) .flatten() + .chain( + ["u8", "u16", "u32", "u64"] + .into_iter() + .map(|tname| format!("bin_ops::iff_generic_{tname}")), + ) .collect() } @@ -219,6 +224,75 @@ pub fn metal_bin_op_dispatch( crate::with_metal_stream(|stream| dispatch_eval(stream, mini_op, lhs, rhs, output)) } +pub fn metal_bin_op(mini_op: Box) -> tract_gpu::ops::binary::GpuBinOp { + tract_gpu::ops::binary::GpuBinOp { + backend_name: "Metal", + mini_op, + dispatch: metal_bin_op_dispatch, + } +} + +crate::register_metal_op!(tract_core::ops::binary::TypedBinOp, |source, node, op| { + rule_if!(is_supported(&*op.0, source.node_input_facts(node.id)?[0].datum_type)); + Ok(Some(Box::new(metal_bin_op(op.0.clone())))) +}); + +crate::register_metal_op!(tract_core::ops::logic::Iff, |_source, _node, _op| { + Ok(Some(Box::new(tract_gpu::ops::iff::GpuIff { + backend_name: "Metal", + dispatch: metal_iff_dispatch, + }))) +}); + +pub fn metal_iff_dispatch( + cond: &DeviceTensor, + then_value: &DeviceTensor, + else_value: &DeviceTensor, + cond_strides: &[isize], + then_strides: &[isize], + else_strides: &[isize], + output: &DeviceTensor, + output_shape: &[usize], + output_strides: &[isize], +) -> TractResult<()> { + crate::with_metal_stream(|stream| { + stream.retain_tensor(cond); + stream.retain_tensor(then_value); + stream.retain_tensor(else_value); + stream.retain_tensor(output); + + let tname = tract_gpu::utils::BroadcastKind::copy_tname(output.datum_type()); + let kernel_name = format!("bin_ops::iff_generic_{tname}"); + let total_elems: usize = output_shape.iter().product(); + + let pipeline = stream.load_pipeline(LibraryName::BinOps, &kernel_name)?; + let command_buffer = stream.command_buffer(); + + let cond_strides_usize: TVec = cond_strides.iter().map(|&s| s as usize).collect(); + let then_strides_usize: TVec = then_strides.iter().map(|&s| s as usize).collect(); + let else_strides_usize: TVec = else_strides.iter().map(|&s| s as usize).collect(); + let out_strides_usize: TVec = output_strides.iter().map(|&s| s as usize).collect(); + + command_buffer.encode(|encoder| { + encoder.set_compute_pipeline_state(&pipeline); + encoder.set_metal_tensor(0, cond, metal::MTLResourceUsage::Read); + encoder.set_metal_tensor(1, then_value, metal::MTLResourceUsage::Read); + encoder.set_metal_tensor(2, else_value, metal::MTLResourceUsage::Read); + encoder.set_metal_tensor(3, output, metal::MTLResourceUsage::Write); + encoder.set_slice(4, output_shape); + encoder.set_slice(5, &cond_strides_usize); + encoder.set_slice(6, &then_strides_usize); + encoder.set_slice(7, &else_strides_usize); + encoder.set_slice(8, &out_strides_usize); + + let grid_size = MTLSize { width: total_elems as NSUInteger, height: 1, depth: 1 }; + let group_size = MTLSize { width: 1, height: 1, depth: 1 }; + encoder.dispatch_thread_groups(grid_size, group_size); + }); + Ok(()) + }) +} + #[cfg(test)] mod tests { use crate::utils::with_borrowed_metal_stream; diff --git a/metal/src/kernels/element_wise.rs b/metal/src/kernels/element_wise.rs index dc945cddf4..6485467577 100644 --- a/metal/src/kernels/element_wise.rs +++ b/metal/src/kernels/element_wise.rs @@ -91,3 +91,19 @@ pub fn metal_element_wise_dispatch( ) -> TractResult<()> { crate::with_metal_stream(|stream| dispatch_eval(stream, mini_op, input, output)) } + +pub fn metal_element_wise_op( + mini_op: Box, +) -> tract_gpu::ops::element_wise::GpuElementWise { + tract_gpu::ops::element_wise::GpuElementWise { + backend_name: "Metal", + mini_op, + dispatch: metal_element_wise_dispatch, + } +} + +// Generic element-wise fallback — checked after LeakyRelu, GeluApproximate. +crate::register_metal_op!(tract_core::ops::element_wise::ElementWiseOp, |source, node, op| { + rule_if!(is_supported(&*op.0, source.node_input_facts(node.id)?[0].datum_type)); + Ok(Some(Box::new(metal_element_wise_op(op.0.clone())))) +}); diff --git a/metal/src/kernels/nn/apply_rope.rs b/metal/src/kernels/nn/apply_rope.rs index 6993b09c48..1e94112ea1 100644 --- a/metal/src/kernels/nn/apply_rope.rs +++ b/metal/src/kernels/nn/apply_rope.rs @@ -112,6 +112,23 @@ impl ApplyRope { } } +pub fn metal_apply_rope_dispatch( + input: &DeviceTensor, + cos: &DeviceTensor, + sin: &DeviceTensor, + output: &DeviceTensor, +) -> TractResult<()> { + crate::with_metal_stream(|stream| ApplyRope.dispatch_eval(stream, input, cos, sin, output)) +} + +crate::register_metal_op!(tract_transformers::ops::apply_rope::ApplyRope, |source, node, _op| { + rule_if!(ApplyRope::is_supported_dt(source.node_input_facts(node.id)?[0].datum_type)); + Ok(Some(Box::new(tract_gpu::ops::apply_rope::GpuApplyRope { + backend_name: "Metal", + dispatch: metal_apply_rope_dispatch, + }))) +}); + #[cfg(test)] mod tests { use super::*; diff --git a/metal/src/kernels/nn/gelu_approximate.rs b/metal/src/kernels/nn/gelu_approximate.rs index b32445031b..a5790e0299 100644 --- a/metal/src/kernels/nn/gelu_approximate.rs +++ b/metal/src/kernels/nn/gelu_approximate.rs @@ -67,6 +67,29 @@ impl GeluApproximate { } } +pub fn metal_gelu_approximate_dispatch( + fast_impl: bool, + input: &DeviceTensor, + output: &DeviceTensor, +) -> TractResult<()> { + crate::with_metal_stream(|stream| { + GeluApproximate { fast_impl }.dispatch_eval(stream, input, output) + }) +} + +// GeluApproximate is an ElementWiseMiniOp, so we register under ElementWiseOp's TypeId. +crate::register_metal_op!(tract_core::ops::element_wise::ElementWiseOp, |source, node, op| { + rule_if_some!( + ew = op.0.downcast_ref::() + ); + rule_if!(GeluApproximate::is_supported_dt(source.node_input_facts(node.id)?[0].datum_type)); + Ok(Some(Box::new(tract_gpu::ops::gelu_approximate::GpuGeluApproximate::new( + ew.fast_impl, + "Metal", + metal_gelu_approximate_dispatch, + )))) +}); + #[cfg(test)] mod tests { use super::*; diff --git a/metal/src/kernels/nn/leaky_relu.rs b/metal/src/kernels/nn/leaky_relu.rs index 11479b53dc..da06dd7de8 100644 --- a/metal/src/kernels/nn/leaky_relu.rs +++ b/metal/src/kernels/nn/leaky_relu.rs @@ -52,3 +52,21 @@ impl LeakyRelu { Ok(()) } } + +pub fn metal_leaky_relu_dispatch( + alpha: f32, + input: &DeviceTensor, + output: &DeviceTensor, +) -> TractResult<()> { + crate::with_metal_stream(|stream| LeakyRelu.dispatch_eval(stream, input, alpha, output)) +} + +// LeakyRelu is an ElementWiseMiniOp, so we register under ElementWiseOp's TypeId. +crate::register_metal_op!(tract_core::ops::element_wise::ElementWiseOp, |_source, _node, op| { + rule_if_some!(leaky = op.0.downcast_ref::()); + Ok(Some(Box::new(tract_gpu::ops::leaky_relu::GpuLeakyRelu::new( + leaky.alpha, + "Metal", + metal_leaky_relu_dispatch, + )))) +}); diff --git a/metal/src/kernels/nn/mod.rs b/metal/src/kernels/nn/mod.rs index 7d64b499a7..270dd7a674 100644 --- a/metal/src/kernels/nn/mod.rs +++ b/metal/src/kernels/nn/mod.rs @@ -7,14 +7,18 @@ pub mod scaled_masked_softmax; pub mod silu; pub mod softmax; -pub use apply_rope::ApplyRope; +pub use apply_rope::{ApplyRope, metal_apply_rope_dispatch}; pub use gelu_approximate::GeluApproximate; +pub use gelu_approximate::metal_gelu_approximate_dispatch; pub use leaky_relu::LeakyRelu; +pub use leaky_relu::metal_leaky_relu_dispatch; pub use reduce::{Reducer, metal_reduce_launch}; pub use rms_norm::RmsNorm; -pub use scaled_masked_softmax::ScaledMaskedSoftmax; +pub use rms_norm::metal_rms_norm_dispatch; +pub use scaled_masked_softmax::{ScaledMaskedSoftmax, metal_scaled_masked_softmax_dispatch}; pub use silu::Silu; pub use softmax::Softmax; +pub use softmax::metal_softmax_dispatch; use crate::kernels::BroadcastKind; diff --git a/metal/src/kernels/nn/reduce.rs b/metal/src/kernels/nn/reduce.rs index 3ad97b0bdd..0098cbd8cf 100644 --- a/metal/src/kernels/nn/reduce.rs +++ b/metal/src/kernels/nn/reduce.rs @@ -52,6 +52,17 @@ pub fn metal_reduce_launch( }) } +crate::register_metal_op!(tract_core::ops::nn::Reduce, |source, node, op| { + let dt = source.node_input_facts(node.id)?[0].datum_type; + if let Ok(gpu_op) = + tract_gpu::ops::reduce::GpuReduce::from_tract_core(op, "Metal", metal_reduce_launch) + { + rule_if!(gpu_op.reducer.is_supported_dt(dt)); + return Ok(Some(Box::new(gpu_op))); + } + Ok(None) +}); + #[cfg(test)] mod tests { use super::*; diff --git a/metal/src/kernels/nn/rms_norm.rs b/metal/src/kernels/nn/rms_norm.rs index dd080b223a..1180997a27 100644 --- a/metal/src/kernels/nn/rms_norm.rs +++ b/metal/src/kernels/nn/rms_norm.rs @@ -130,6 +130,25 @@ impl RmsNorm { } } +pub fn metal_rms_norm_dispatch( + input: &DeviceTensor, + axis: usize, + eps: &Tensor, + output: &DeviceTensor, +) -> TractResult<()> { + crate::with_metal_stream(|stream| RmsNorm.dispatch_eval(stream, input, axis, eps, output)) +} + +crate::register_metal_op!(tract_transformers::ops::rms_norm::RmsNorm, |source, node, op| { + rule_if!(RmsNorm::is_supported_dt(source.node_input_facts(node.id)?[0].datum_type)); + Ok(Some(Box::new(tract_gpu::ops::rms_norm::GpuRmsNorm::new( + op.axis, + op.eps.clone(), + "Metal", + metal_rms_norm_dispatch, + )))) +}); + #[cfg(test)] mod tests { use crate::utils::with_borrowed_metal_stream; diff --git a/metal/src/kernels/nn/scaled_masked_softmax.rs b/metal/src/kernels/nn/scaled_masked_softmax.rs index 805dbc4218..7874760fb7 100644 --- a/metal/src/kernels/nn/scaled_masked_softmax.rs +++ b/metal/src/kernels/nn/scaled_masked_softmax.rs @@ -107,6 +107,32 @@ fn pad(vals: &[impl AsPrimitive], neutral: isize) -> [isize; 5] { it } +pub fn metal_scaled_masked_softmax_dispatch( + input: &DeviceTensor, + scale: &Tensor, + mask: &DeviceTensor, + output: &DeviceTensor, +) -> TractResult<()> { + crate::with_metal_stream(|stream| { + ScaledMaskedSoftmax.dispatch_eval(stream, input, scale, mask, output) + }) +} + +crate::register_metal_op!( + tract_transformers::ops::scaled_masked_softmax::ScaledMaskedSoftmax, + |source, node, op| { + rule_if!(!op.post_softmax_mask); + rule_if!(ScaledMaskedSoftmax::is_supported_dt( + source.node_input_facts(node.id)?[0].datum_type + )); + Ok(Some(Box::new(tract_gpu::ops::scaled_masked_softmax::GpuScaledMaskedSoftmax { + scale: op.scale.clone(), + backend_name: "Metal", + dispatch: metal_scaled_masked_softmax_dispatch, + }))) + } +); + #[cfg(test)] mod tests { use crate::utils::with_borrowed_metal_stream; diff --git a/metal/src/kernels/nn/softmax.rs b/metal/src/kernels/nn/softmax.rs index e7b2ebc2f3..29efc7d55d 100644 --- a/metal/src/kernels/nn/softmax.rs +++ b/metal/src/kernels/nn/softmax.rs @@ -68,6 +68,23 @@ impl Softmax { } } +pub fn metal_softmax_dispatch( + input: &DeviceTensor, + axis: usize, + output: &DeviceTensor, +) -> TractResult<()> { + crate::with_metal_stream(|stream| Softmax.dispatch_eval(stream, input, axis, output)) +} + +crate::register_metal_op!(tract_core::ops::nn::Softmax, |source, node, op| { + rule_if!(Softmax::is_supported_dt(source.node_input_facts(node.id)?[0].datum_type)); + Ok(Some(Box::new(tract_gpu::ops::softmax::GpuSoftmax::from_tract_core( + op, + "Metal", + metal_softmax_dispatch, + )?))) +}); + #[cfg(test)] mod tests { use super::*; diff --git a/metal/src/ops/apply_rope.rs b/metal/src/ops/apply_rope.rs deleted file mode 100644 index 058dfd44a8..0000000000 --- a/metal/src/ops/apply_rope.rs +++ /dev/null @@ -1,57 +0,0 @@ -use crate::kernels::nn::ApplyRope; -use derive_new::new; -use tract_core::internal::*; -use tract_gpu::tensor::DeviceTensorExt; - -#[derive(Clone, Debug, new, Hash, PartialEq, Eq)] -pub struct MetalApplyRope; - -impl Op for MetalApplyRope { - fn name(&self) -> StaticName { - "MetalApplyRope".into() - } - - op_as_typed_op!(); -} - -impl EvalOp for MetalApplyRope { - fn is_stateless(&self) -> bool { - true - } - - fn eval_with_session( - &self, - node_id: usize, - session: &TurnState, - inputs: TVec, - ) -> TractResult> { - let (raw_input, raw_cos, raw_sin) = args_3!(inputs); - let input = raw_input.to_device_tensor()?; - let cos = raw_cos.to_device_tensor()?; - let sin = raw_sin.to_device_tensor()?; - let output = tract_gpu::session_handler::make_tensor_for_node( - session, - node_id, - input.datum_type(), - input.shape(), - )?; - - crate::with_metal_stream(|stream| { - ApplyRope.dispatch_eval(stream, input, cos, sin, &output) - })?; - Ok(tvec!(output.into_tensor().into_tvalue())) - } -} - -impl TypedOp for MetalApplyRope { - fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - tract_gpu::utils::facts_to_device_facts(inputs, |facts| { - let dt = facts[0].datum_type; - let fact = dt.fact(facts[0].shape.clone()); - Ok(tvec!(fact)) - }) - .with_context(|| format!("Error while computing facts for {:?}", self.name())) - } - - as_op!(); -} diff --git a/metal/src/ops/cast.rs b/metal/src/ops/cast.rs deleted file mode 100644 index 7972d56a0d..0000000000 --- a/metal/src/ops/cast.rs +++ /dev/null @@ -1,67 +0,0 @@ -use crate::kernels; -use tract_core::internal::*; -use tract_gpu::tensor::DeviceTensorExt; - -#[derive(Debug, Clone, Hash, PartialEq, Eq)] -pub struct MetalCast { - pub to: DatumType, -} - -impl MetalCast { - pub fn is_supported_dt(dt: DatumType) -> bool { - kernels::array::Cast::is_supported_dt(dt) - } - - pub fn new(to: DatumType) -> Option { - Self::is_supported_dt(to).then_some(Self { to }) - } -} - -impl Op for MetalCast { - fn name(&self) -> StaticName { - "MetalCast".into() - } - - op_as_typed_op!(); -} - -impl EvalOp for MetalCast { - fn is_stateless(&self) -> bool { - true - } - - fn eval_with_session( - &self, - node_id: usize, - session: &TurnState, - inputs: TVec, - ) -> TractResult> { - let input_value = args_1!(inputs); - let input = input_value.to_device_tensor()?; - if input.datum_type() == self.to { - Ok(tvec!(input_value)) - } else { - let output = tract_gpu::session_handler::make_tensor_for_node( - session, - node_id, - self.to, - input.shape(), - )?; - crate::with_metal_stream(|stream| { - kernels::array::Cast.dispatch_eval(stream, input, &output) - })?; - Ok(tvec![output.into_tensor().into_tvalue()]) - } - } -} - -impl TypedOp for MetalCast { - fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - tract_gpu::utils::facts_to_device_facts(inputs, |facts| { - Ok(tvec!(self.to.fact(facts[0].shape.clone()))) - }) - .with_context(|| format!("Error while computing facts for {:?}", self.name())) - } - - as_op!(); -} diff --git a/metal/src/ops/gelu_approximate.rs b/metal/src/ops/gelu_approximate.rs deleted file mode 100644 index dfb4c8b6fc..0000000000 --- a/metal/src/ops/gelu_approximate.rs +++ /dev/null @@ -1,59 +0,0 @@ -use crate::kernels::nn::GeluApproximate; -use tract_core::internal::*; -use tract_gpu::tensor::DeviceTensorExt; - -#[derive(Clone, Debug, Default, Hash, PartialEq, Eq)] -pub struct MetalGeluApproximate { - pub fast_impl: bool, -} - -impl Op for MetalGeluApproximate { - fn name(&self) -> StaticName { - "MetalGeluApproximate".into() - } - - op_as_typed_op!(); -} - -impl EvalOp for MetalGeluApproximate { - fn is_stateless(&self) -> bool { - true - } - - fn eval_with_session( - &self, - node_id: usize, - session: &TurnState, - inputs: TVec, - ) -> TractResult> { - crate::with_metal_stream(|stream| { - let input = args_1!(inputs); - let input_metal = input.to_device_tensor()?; - let output = tract_gpu::session_handler::make_tensor_for_node( - session, - node_id, - input_metal.datum_type(), - input_metal.shape(), - )?; - GeluApproximate { fast_impl: self.fast_impl }.dispatch_eval( - stream, - input_metal, - &output, - )?; - Ok(tvec!(output.into_tensor().into_tvalue())) - }) - } -} - -impl TypedOp for MetalGeluApproximate { - fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - tract_gpu::utils::facts_to_device_facts(inputs, |facts| { - let dt = facts[0].datum_type; - let fact = dt.fact(facts[0].shape.clone()); - Ok(tvec!(fact)) - }) - .with_context(|| format!("Error while computing facts for {:?}", self.name())) - } - - as_op!(); -} diff --git a/metal/src/ops/leaky_relu.rs b/metal/src/ops/leaky_relu.rs deleted file mode 100644 index b27c04fa63..0000000000 --- a/metal/src/ops/leaky_relu.rs +++ /dev/null @@ -1,57 +0,0 @@ -use tract_core::internal::*; -use tract_gpu::tensor::DeviceTensorExt; - -use crate::kernels::nn::LeakyRelu; - -#[derive(Debug, Clone, Default, PartialEq)] -pub struct MetalLeakyRelu { - pub alpha: f32, -} -impl Eq for MetalLeakyRelu {} - -impl Op for MetalLeakyRelu { - fn name(&self) -> StaticName { - "MetalLeakyRelu".into() - } - - op_as_typed_op!(); -} - -impl EvalOp for MetalLeakyRelu { - fn is_stateless(&self) -> bool { - true - } - - fn eval_with_session( - &self, - node_id: usize, - session: &TurnState, - inputs: TVec, - ) -> TractResult> { - crate::with_metal_stream(|stream| { - let input = args_1!(inputs); - let input_metal = input.to_device_tensor()?; - let output = tract_gpu::session_handler::make_tensor_for_node( - session, - node_id, - input_metal.datum_type(), - input_metal.shape(), - )?; - LeakyRelu.dispatch_eval(stream, input_metal, self.alpha, &output)?; - Ok(tvec!(output.into_tensor().into_tvalue())) - }) - } -} - -impl TypedOp for MetalLeakyRelu { - fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - tract_gpu::utils::facts_to_device_facts(inputs, |facts| { - let dt = facts[0].datum_type; - let fact = dt.fact(facts[0].shape.clone()); - Ok(tvec!(fact)) - }) - .with_context(|| format!("Error while computing facts for {:?}", self.name())) - } - - as_op!(); -} diff --git a/metal/src/ops/mod.rs b/metal/src/ops/mod.rs index a1cfa79326..4c0defe450 100644 --- a/metal/src/ops/mod.rs +++ b/metal/src/ops/mod.rs @@ -1,22 +1,6 @@ -pub mod apply_rope; -pub mod cast; pub mod conv; pub mod fused_axis_op; -pub mod gelu_approximate; pub mod gemm; -pub mod leaky_relu; -pub mod rms_norm; -pub mod rotate_half; -pub mod scaled_masked_softmax; -pub mod softmax; -pub use apply_rope::MetalApplyRope; -pub use cast::MetalCast; pub use fused_axis_op::MetalFusedAxisOp; -pub use gelu_approximate::MetalGeluApproximate; pub use gemm::MetalGemm; -pub use leaky_relu::MetalLeakyRelu; -pub use rms_norm::MetalRmsNorm; -pub use rotate_half::MetalRotateHalf; -pub use scaled_masked_softmax::MetalScaledMaskedSoftmax; -pub use softmax::MetalSoftmax; diff --git a/metal/src/ops/rms_norm.rs b/metal/src/ops/rms_norm.rs deleted file mode 100644 index 0fc4ce9961..0000000000 --- a/metal/src/ops/rms_norm.rs +++ /dev/null @@ -1,60 +0,0 @@ -use crate::kernels::nn::RmsNorm; -use derive_new::new; -use std::sync::Arc; -use tract_core::internal::*; -use tract_gpu::tensor::DeviceTensorExt; - -#[derive(Clone, Debug, new, Hash, PartialEq, Eq)] -pub struct MetalRmsNorm { - pub axis: usize, - pub eps: Arc, -} - -impl Op for MetalRmsNorm { - fn name(&self) -> StaticName { - "MetalRmsNorm".into() - } - fn info(&self) -> TractResult> { - Ok(vec![format!("axis: {:?}, eps: {:?}", self.axis, self.eps)]) - } - op_as_typed_op!(); -} - -impl EvalOp for MetalRmsNorm { - fn is_stateless(&self) -> bool { - true - } - - fn eval_with_session( - &self, - node_id: usize, - session: &TurnState, - inputs: TVec, - ) -> TractResult> { - crate::with_metal_stream(|stream| { - let input_value = args_1!(inputs); - let input = input_value.to_device_tensor()?; - let output = tract_gpu::session_handler::make_tensor_for_node( - session, - node_id, - input.datum_type(), - input.shape(), - )?; - RmsNorm.dispatch_eval(stream, input, self.axis, &self.eps, &output)?; - Ok(tvec!(output.into_tensor().into_tvalue())) - }) - } -} - -impl TypedOp for MetalRmsNorm { - fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - tract_gpu::utils::facts_to_device_facts(inputs, |facts| { - let dt = facts[0].datum_type; - let fact = dt.fact(facts[0].shape.clone()); - Ok(tvec!(fact)) - }) - .with_context(|| format!("Error while computing facts for {:?}", self.name())) - } - - as_op!(); -} diff --git a/metal/src/ops/rotate_half.rs b/metal/src/ops/rotate_half.rs deleted file mode 100644 index 562c607f79..0000000000 --- a/metal/src/ops/rotate_half.rs +++ /dev/null @@ -1,54 +0,0 @@ -use crate::kernels::array::RotateHalf; -use derive_new::new; -use tract_core::internal::*; -use tract_gpu::tensor::DeviceTensorExt; - -#[derive(Clone, Debug, new, Hash, PartialEq, Eq)] -pub struct MetalRotateHalf; - -impl Op for MetalRotateHalf { - fn name(&self) -> StaticName { - "MetalRotateHalf".into() - } - - op_as_typed_op!(); -} - -impl EvalOp for MetalRotateHalf { - fn is_stateless(&self) -> bool { - true - } - - fn eval_with_session( - &self, - node_id: usize, - session: &TurnState, - inputs: TVec, - ) -> TractResult> { - crate::with_metal_stream(|stream| { - let input_value = args_1!(inputs); - let input = input_value.to_device_tensor()?; - let output = tract_gpu::session_handler::make_tensor_for_node( - session, - node_id, - input.datum_type(), - input.shape(), - )?; - RotateHalf.dispatch_eval(stream, input, &output)?; - Ok(tvec!(output.into_tensor().into_tvalue())) - }) - } -} - -impl TypedOp for MetalRotateHalf { - fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - tract_gpu::utils::facts_to_device_facts(inputs, |facts| { - let dt = facts[0].datum_type; - let fact = dt.fact(facts[0].shape.clone()); - Ok(tvec!(fact)) - }) - .with_context(|| format!("Error while computing facts for {:?}", self.name())) - } - - as_op!(); -} diff --git a/metal/src/ops/scaled_masked_softmax.rs b/metal/src/ops/scaled_masked_softmax.rs deleted file mode 100644 index 661f816280..0000000000 --- a/metal/src/ops/scaled_masked_softmax.rs +++ /dev/null @@ -1,64 +0,0 @@ -use crate::kernels::nn::ScaledMaskedSoftmax; -use derive_new::new; -use tract_core::internal::*; -use tract_gpu::tensor::DeviceTensorExt; - -/// A = SOFTMAX(INPUT * SCALE + MASK, AXIS=2) -/// Only input of rank of 3 is supported -#[derive(Clone, Debug, new, Hash, PartialEq, Eq)] -pub struct MetalScaledMaskedSoftmax { - pub scale: Arc, -} - -impl Op for MetalScaledMaskedSoftmax { - fn name(&self) -> StaticName { - "MetalScaledMaskedSoftmax".into() - } - - op_as_typed_op!(); -} - -impl EvalOp for MetalScaledMaskedSoftmax { - fn is_stateless(&self) -> bool { - true - } - - fn eval_with_session( - &self, - node_id: usize, - session: &TurnState, - inputs: TVec, - ) -> TractResult> { - crate::with_metal_stream(|stream| { - let (raw_input, raw_mask) = args_2!(inputs); - let input = raw_input.to_device_tensor()?; - let mask = raw_mask.to_device_tensor()?; - let output = tract_gpu::session_handler::make_tensor_for_node( - session, - node_id, - input.datum_type(), - input.shape(), - )?; - ScaledMaskedSoftmax.dispatch_eval(stream, input, &self.scale, mask, &output)?; - Ok(tvec!(output.into_tensor().into_tvalue())) - }) - } -} - -impl TypedOp for MetalScaledMaskedSoftmax { - fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - tract_gpu::utils::facts_to_device_facts(inputs, |facts| { - ensure!(facts.len() == 2); - let dt = facts[0].datum_type; - ensure!(dt == facts[1].datum_type); - ensure!(facts[0].rank() <= 5); - ensure!(facts[0].rank() >= 2); - ensure!(facts[0].rank() == facts[1].rank()); - let fact = dt.fact(facts[0].shape.clone()); - Ok(tvec!(fact)) - }) - .with_context(|| format!("Error while computing facts for {:?}", self.name())) - } - - as_op!(); -} diff --git a/metal/src/ops/softmax.rs b/metal/src/ops/softmax.rs deleted file mode 100644 index 5cd4bcb59f..0000000000 --- a/metal/src/ops/softmax.rs +++ /dev/null @@ -1,103 +0,0 @@ -use crate::kernels::nn::Softmax; -use std::fmt::Debug; -use tract_core::internal::*; -use tract_core::ops::nn as core_ops_nn; -use tract_gpu::tensor::DeviceTensorExt; - -#[derive(Debug, Clone, Hash, Default, PartialEq, Eq)] -pub struct MetalSoftmax { - pub axes: TVec, -} - -impl MetalSoftmax { - pub fn new(axes: TVec) -> TractResult { - ensure!(axes.len() == 1, "Only one axis of softmax is supported by MetalSoftmax"); - Ok(Self { axes }) - } - - pub fn from_tract_core(core_softmax: &core_ops_nn::Softmax) -> TractResult { - ensure!(core_softmax.quant_output_dt.is_none()); - Self::new(core_softmax.axes.clone()) - } -} - -impl Op for MetalSoftmax { - fn name(&self) -> StaticName { - "MetalSoftmax".into() - } - - fn info(&self) -> TractResult> { - Ok(vec![format!("axes: {:?}", self.axes)]) - } - - op_as_typed_op!(); -} - -impl EvalOp for MetalSoftmax { - fn is_stateless(&self) -> bool { - true - } - - fn eval_with_session( - &self, - node_id: usize, - session: &TurnState, - inputs: TVec, - ) -> TractResult> { - crate::with_metal_stream(|stream| { - let input_value = args_1!(inputs); - let input = input_value.to_device_tensor()?; - let output = tract_gpu::session_handler::make_tensor_for_node( - session, - node_id, - input.datum_type(), - input.shape(), - )?; - Softmax.dispatch_eval(stream, input, self.axes[0], &output)?; - - Ok(tvec!(output.into_tensor().into_tvalue())) - }) - } -} - -impl TypedOp for MetalSoftmax { - fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - tract_gpu::utils::facts_to_device_facts(inputs, |facts| { - let dt = facts[0].datum_type; - let fact = dt.fact(facts[0].shape.clone()); - Ok(tvec!(fact)) - }) - .with_context(|| format!("Error while computing facts for {:?}", self.name())) - } - - fn axes_mapping( - &self, - inputs: &[&TypedFact], - outputs: &[&TypedFact], - ) -> TractResult { - AxesMapping::natural(inputs, outputs) - } - - fn change_axes( - &self, - model: &TypedModel, - node: &TypedNode, - _io: InOut, - change: &AxisOp, - ) -> TractResult> { - let axes: Option> = - self.axes.iter().map(|it| change.transform_axis(*it)).collect(); - if let Some(axes) = axes { - Ok(Some(AxisChangeConsequence::new( - model, - node, - Some(Box::new(MetalSoftmax { axes })), - change, - ))) - } else { - Ok(None) - } - } - - as_op!(); -} diff --git a/metal/src/rewrite_rules/fuse_axis_op.rs b/metal/src/rewrite_rules/fuse_axis_op.rs index 73e26d1966..465b69928d 100644 --- a/metal/src/rewrite_rules/fuse_axis_op.rs +++ b/metal/src/rewrite_rules/fuse_axis_op.rs @@ -12,8 +12,8 @@ fn is_supported_axis_op(op: &GpuAxisOp) -> bool { fn can_fuse_move(model: &TypedModel, axis_node: &TypedNode) -> bool { model.single_succ(axis_node.id).unwrap().is_some_and(|node| { node.op_is::() - || node.op_is::() - || node.op_is::() + || node.op_is::() + || node.op_is::() || node.op_is::() || node.op_is::() || node.op_is::() @@ -189,12 +189,7 @@ pub fn fuse_move_axis( } // Reshape are always fusable. Change Move by Reshape if possible - let simpl_op = GpuAxisOp::simplify_axis_op( - axis_op.inner.clone(), - in_shape.dims(), - axis_op.backend_name, - axis_op.dispatch, - ); + let simpl_op = GpuAxisOp::simplify_axis_op(axis_op.inner.clone(), in_shape.dims()); if simpl_op != *axis_op { return Ok(Some(TypedModelPatch::replace_single_op( model, @@ -221,7 +216,7 @@ pub fn fuse_move_axis( let inputs = patch.taps(model, &axis_node.inputs)?; let out = patch.wire_node( format!("{axis_node_name}.fused_move_axis"), - GpuAxisOp::new(new_axis_ops[0].clone(), axis_op.backend_name, axis_op.dispatch), + GpuAxisOp::new(new_axis_ops[0].clone()), &inputs, )?; patch.shunt_outside(model, cursor.id.into(), out[0])?; @@ -238,11 +233,8 @@ pub fn fuse_move_axis( if ax == from_1 { let mut patch = TypedModelPatch::default(); let inputs = patch.taps(model, &cursor.inputs)?; - let out = patch.wire_node( - cursor.name.clone(), - GpuAxisOp::new(AxisOp::Add(to_1), axis_op.backend_name, axis_op.dispatch), - &inputs, - )?; + let out = + patch.wire_node(cursor.name.clone(), GpuAxisOp::new(AxisOp::Add(to_1)), &inputs)?; patch.shunt_outside(model, axis_node.id.into(), out[0])?; return Ok(Some(patch)); } diff --git a/metal/src/transform.rs b/metal/src/transform.rs index ea03553773..4aea5ed9b6 100644 --- a/metal/src/transform.rs +++ b/metal/src/transform.rs @@ -1,41 +1,57 @@ +use std::any::TypeId; +use std::collections::HashMap; +use std::fmt::Debug; +use std::str::FromStr; +use std::sync::OnceLock; + use crate::context::metal_context; use crate::kernels::matmul::{GemmKernel, GgmlGemm, MetalGemmImplKind, MfaGemm, MlxGemm}; -use crate::kernels::nn::metal_reduce_launch; use crate::{kernels, ops}; -use tract_core::tract_linalg::block_quant::Q4_0; -use tract_gpu::fact::DeviceTypedFactExt; -use tract_gpu::ops::reduce::GpuReduce; -use tract_gpu::rewrite_rules::rewire_sdpa::rewire_sdpa; -use tract_gpu::rewrite_rules::rewire_syncs::rewire_syncs; -use tract_gpu::rewrite_rules::rms_norm::remove_rms_norm_cast; -use tract_gpu::sync::{DeviceSyncKind, sync_inputs_if_required, sync_model_outputs_if_required}; -use tract_transformers::ops::dyn_kv_cache::DynKeyValueCache; - -use crate::rewrite_rules; -use std::fmt::Debug; -use std::str::FromStr; use tract_core::dyn_clone::clone_box; use tract_core::internal::translator::Translate; use tract_core::internal::*; -use tract_core::ops::array::{MultiBroadcastTo, Slice, TypedConcat}; -use tract_core::ops::binary::{BinMiniOp, TypedBinOp}; -use tract_core::ops::cast::Cast; -use tract_core::ops::cnn::conv::rewrite_kernel_conv_in_oihw; -use tract_core::ops::cnn::{Conv, rewrite_conv_with_n_axis}; use tract_core::ops::einsum::prefix_matmul::{PrefixMatMul, rewrite_einsum_to_prefix_matmul}; -use tract_core::ops::element_wise::ElementWiseOp; use tract_core::ops::konst::Const; -use tract_core::ops::nn::{LeakyRelu, Reduce, Softmax as CoreSoftmax}; +use tract_core::tract_linalg::block_quant::Q4_0; use tract_core::transform::ModelTransform; -use tract_gpu::fact::DeviceFact; -use tract_gpu::tensor::DeviceTensor; -use tract_gpu::tensor::IntoDevice; +use tract_gpu::fact::{DeviceFact, DeviceTypedFactExt}; +use tract_gpu::rewrite_rules::rewire_sdpa::rewire_sdpa; +use tract_gpu::rewrite_rules::rewire_syncs::rewire_syncs; +use tract_gpu::rewrite_rules::rms_norm::remove_rms_norm_cast; +use tract_gpu::sync::{DeviceSyncKind, sync_inputs_if_required, sync_model_outputs_if_required}; +use tract_gpu::tensor::{DeviceTensor, IntoDevice}; use tract_gpu::utils::as_quant_fact; -use tract_itertools::Itertools; -use tract_transformers::ops::apply_rope::{ApplyRope, RotateHalf}; -use tract_transformers::ops::gelu_approximate::GeluApproximate; -use tract_transformers::ops::rms_norm::RmsNorm; -use tract_transformers::ops::scaled_masked_softmax::ScaledMaskedSoftmax; + +use crate::rewrite_rules; + +/// A registered translator that can convert a core op into a Metal GPU op. +/// Each kernel module submits one (or more) of these via [`register_metal_op!`]. +pub struct MetalOpTranslator { + pub type_id: TypeId, + pub try_make: fn(&TypedModel, &TypedNode) -> TractResult>>, +} + +inventory::collect!(MetalOpTranslator); + +/// Register a translator for a core op type. The closure receives `(source, node, op)` +/// where `op` is already downcast to `$op_type`. Return `Ok(Some(gpu_op))` to translate, +/// `Ok(None)` to skip. +#[macro_export] +macro_rules! register_metal_op { + ($op_type:ty, |$source:ident, $node:ident, $op:ident| $body:expr) => { + inventory::submit! { + $crate::transform::MetalOpTranslator { + type_id: std::any::TypeId::of::<$op_type>(), + try_make: |$source, $node| { + let Some($op) = $node.op_as::<$op_type>() else { + return Ok(None); + }; + $body + }, + } + } + }; +} impl MetalGemmImplKind { pub fn variants() -> Vec { @@ -105,8 +121,6 @@ impl MetalTransform { .rewrite(self, model)?; Rewriter::default() - .with_rule_for("rewrite_kernel_conv_in_oihw", rewrite_kernel_conv_in_oihw) - .with_rule_for("rewrite_conv_with_n_axis", rewrite_conv_with_n_axis) .with_rule_for("remove_rms_norm_cast", remove_rms_norm_cast) .rewrite(&(), model)?; @@ -132,67 +146,40 @@ impl MetalTransform { } } -fn can_translate_to_metal_op(source: &TypedModel, node: &TypedNode) -> TractResult { - let input_facts = source.node_input_facts(node.id)?.iter().map(|f| (*f).clone()).collect_vec(); - let input_dts = input_facts - .iter() - .map(|f| f.as_device_fact().map(|f| f.datum_type).unwrap_or(f.datum_type)) - .collect_vec(); - - let in_dts_metal_compatible = - input_facts.iter().all(|fact| DeviceTensor::is_supported_dt(fact.datum_type)); - - Ok(in_dts_metal_compatible - && (node - .op_as::() - .is_some_and(|op| crate::kernels::element_wise::is_supported(&*op.0, input_dts[0])) - || node - .op_as::() - .is_some_and(|op| crate::kernels::bin_ops::is_supported(&*op.0, input_dts[0])) - || node.op_is::() - || node.op_as::().is_some_and(|op| { - !op.transpose_c && op.quantize_output.is_none() && check_matmul_in_dts(&input_facts) - }) - || (node.op_is::() - && matches!(input_facts[0].datum_type, DatumType::F16 | DatumType::F32)) - || node - .op_as::() - .is_some_and(|op| DeviceTensor::is_supported_dt(op.val().datum_type())) - || node.op_as::().is_some_and(|op| { - ops::MetalCast::is_supported_dt(input_dts[0]) - && ops::MetalCast::new(op.to).is_some() - }) - || node.op_is::() - || node.op_is::() - || node.op_is::() - || node.op_is::() - || node.op_as::().is_some_and(|op| { - GpuReduce::from_tract_core(op, "Metal", metal_reduce_launch) - .is_ok_and(|op| op.reducer.is_supported_dt(input_dts[0])) - }) - || node.op_as::().is_some_and(|op| { - kernels::nn::Softmax::is_supported_dt(input_dts[0]) - && ops::MetalSoftmax::from_tract_core(op).is_ok() - }) - || node - .op_as::() - .is_some_and(|_| kernels::nn::ScaledMaskedSoftmax::is_supported_dt(input_dts[0])) - || node - .op_as::() - .is_some_and(|_| kernels::nn::RmsNorm::is_supported_dt(input_dts[0])) - || node - .op_as::() - .is_some_and(|_| kernels::array::RotateHalf::is_supported_dt(input_dts[0])) - || node - .op_as::() - .is_some_and(|_| kernels::nn::ApplyRope::is_supported_dt(input_dts[0])) - || node.op_as::().is_some_and(|op| { - op.0.is::() - && kernels::nn::GeluApproximate::is_supported_dt(input_dts[0]) - }) - || node.op_as::().is_some_and(|op| { - op.0.is::() && kernels::nn::LeakyRelu::is_supported_dt(input_dts[0]) - }))) +/// Looks up the node's op TypeId in the inventory of registered `MetalOpTranslator`s. +/// Returns `Some(gpu_op)` if a translator matches and succeeds, `None` otherwise. +fn try_make_metal_op( + source: &TypedModel, + node: &TypedNode, +) -> TractResult>> { + type TranslateFn = fn(&TypedModel, &TypedNode) -> TractResult>>; + static MAP: OnceLock>> = OnceLock::new(); + let map = MAP.get_or_init(|| { + let mut m: HashMap> = HashMap::new(); + for t in inventory::iter:: { + m.entry(t.type_id).or_default().push(t.try_make); + } + m + }); + + let input_facts = source.node_input_facts(node.id)?; + if !input_facts.iter().all(|f| DeviceTensor::is_supported_dt(f.datum_type)) { + return Ok(None); + } + + // Copy-based ops are fully generic (no backend-specific dispatch needed). + if let Some(op) = tract_gpu::ops::copy_based::try_make_copy_based_op(source, node)? { + return Ok(Some(op)); + } + + if let Some(fns) = map.get(&(*node.op).type_id()) { + for f in fns { + if let Some(op) = f(source, node)? { + return Ok(Some(op)); + } + } + } + Ok(None) } impl Translate, TypedFact, Box> for MetalTransform { @@ -203,89 +190,40 @@ impl Translate, TypedFact, Box> for Met target: &mut TypedModel, mapping: &HashMap, ) -> TractResult> { - let translatable = can_translate_to_metal_op(source, node)?; - - if translatable { - let mut device_inputs = - sync_inputs_if_required(target, node, mapping, DeviceSyncKind::ToDevice)?; - - let outlet_ids: TVec = if let Some(op) = node.op_as::() { - convert_matmul_to_metal( + // Special multi-node ops handled first + let input_facts = source.node_input_facts(node.id)?; + if let Some(op) = node.op_as::() { + let facts: Vec = input_facts.iter().map(|f| (*f).clone()).collect(); + if !op.transpose_c && op.quantize_output.is_none() && check_matmul_in_dts(&facts) { + let mut device_inputs = + sync_inputs_if_required(target, node, mapping, DeviceSyncKind::ToDevice)?; + let outlet_ids = convert_matmul_to_metal( source, node, target, &mut device_inputs, op, self.gemm_impl, - )? - } else if let Some(conv) = node.op_as::() { - ops::conv::wire_metal_conv(source, node, target, &device_inputs, conv)? - } else { - let op: Box = if let Some(op) = node.op_as::() { - if let Some(ew) = op.0.downcast_ref::() { - Box::new(ops::MetalGeluApproximate { fast_impl: ew.fast_impl }) - } else if let Some(leaky) = op.0.downcast_ref::() { - Box::new(ops::MetalLeakyRelu { alpha: leaky.alpha }) - } else { - Box::new(metal_element_wise_op(op.0.clone())) - } - } else if let Some(op) = node.op_as::() { - Box::new(metal_bin_op(op.0.clone())) - } else if let Some(op) = node.op_as::() { - Box::new(tract_gpu::ops::broadcast::GpuMultiBroadcastTo::new( - op.shape.clone(), - "Metal", - crate::kernels::array::metal_copy_nd_dispatch, - )) - } else if let Some(op) = node.op_as::() { - Box::new(convert_const(op)?) - } else if let Some(op) = node.op_as::() { - Box::new(ops::MetalCast::new(op.to).unwrap()) - } else if let Some(op) = node.op_as::() { - let in_fact = source.node_input_facts(node.id)?[0]; - Box::new(tract_gpu::ops::change_axes::GpuAxisOp::from_tract_core_with_fact( - op.clone(), - in_fact, - "Metal", - crate::kernels::array::metal_copy_nd_dispatch, - )) - } else if let Some(op) = node.op_as::() { - Box::new(tract_gpu::ops::slice::GpuSlice::new( - op.clone(), - "Metal", - crate::kernels::array::metal_copy_nd_dispatch, - )) - } else if let Some(op) = node.op_as::() { - Box::new(tract_gpu::ops::concat::GpuConcat::new( - op.axis, - "Metal", - crate::kernels::array::metal_copy_nd_dispatch, - )) - } else if let Some(op) = node.op_as::() { - Box::new(GpuReduce::from_tract_core(op, "Metal", metal_reduce_launch).unwrap()) - } else if let Some(op) = node.op_as::() { - Box::new(ops::MetalSoftmax::from_tract_core(op).unwrap()) - } else if let Some(op) = node.op_as::() - && !op.post_softmax_mask - { - Box::new(ops::MetalScaledMaskedSoftmax { scale: op.scale.clone() }) - } else if let Some(op) = node.op_as::() { - Box::new(ops::MetalRmsNorm::new(op.axis, op.eps.clone())) - } else if let Some(_op) = node.op_as::() { - Box::new(ops::MetalRotateHalf) - } else if let Some(_op) = node.op_as::() { - Box::new(ops::MetalApplyRope) - } else if let Some(op) = node.op_as::() { - Box::new(tract_gpu::ops::dyn_kv_cache::GpuDynKVCache::from_tract_transformers( - op, - "Metal", - crate::kernels::array::metal_copy_nd_dispatch, - )) - } else { - bail!("Failed to translate a supported Metal Op") - }; - target.wire_node(node.name.clone(), op, &device_inputs)? - }; + )?; + return sync_model_outputs_if_required(source, node, target, outlet_ids); + } + } + // Const: inline conversion, not a GPU op + if let Some(op) = node.op_as::() { + if DeviceTensor::is_supported_dt(op.val().datum_type()) { + let device_inputs = + sync_inputs_if_required(target, node, mapping, DeviceSyncKind::ToDevice)?; + let outlet_ids = + target.wire_node(node.name.clone(), convert_const(op)?, &device_inputs)?; + return sync_model_outputs_if_required(source, node, target, outlet_ids); + } + } + + // Single-op translation + if let Some(gpu_op) = try_make_metal_op(source, node)? { + let device_inputs = + sync_inputs_if_required(target, node, mapping, DeviceSyncKind::ToDevice)?; + let outlet_ids = target.wire_node(node.name.clone(), gpu_op, &device_inputs)?; sync_model_outputs_if_required(source, node, target, outlet_ids) } else { let cpu_inputs = @@ -295,25 +233,13 @@ impl Translate, TypedFact, Box> for Met } } -use tract_gpu::ops::binary::GpuBinOp; - -fn metal_bin_op(mini_op: Box) -> GpuBinOp { - GpuBinOp { - backend_name: "Metal", - mini_op, - dispatch: crate::kernels::bin_ops::metal_bin_op_dispatch, - } -} - -use tract_core::ops::element_wise::ElementWiseMiniOp; -use tract_gpu::ops::element_wise::GpuElementWise; - -fn metal_element_wise_op(mini_op: Box) -> GpuElementWise { - GpuElementWise { - backend_name: "Metal", - mini_op, - dispatch: crate::kernels::element_wise::metal_element_wise_dispatch, - } +pub(crate) fn metal_cast_new(to: DatumType) -> Option { + tract_gpu::ops::cast::GpuCast::new( + to, + "Metal", + kernels::array::metal_cast_dispatch, + kernels::array::Cast::is_supported_dt, + ) } fn check_matmul_in_dts(in_facts: &[TypedFact]) -> bool { @@ -386,7 +312,7 @@ fn convert_matmul_to_metal( }; *inp_to_cast = target.wire_node( node.name.clone() + ".cast_input", - ops::MetalCast::new(DatumType::F32).unwrap(), + metal_cast_new(DatumType::F32).unwrap(), &[*inp_to_cast], )?[0]; } @@ -419,17 +345,14 @@ fn convert_matmul_to_metal( ); let rank = input_facts[a_pos].rank(); - let perm_a_op = tract_gpu::ops::change_axes::GpuAxisOp::new( - AxisOp::Move(rank - 2, rank - 1), - "Metal", - crate::kernels::array::metal_copy_nd_dispatch, - ); + let perm_a_op = + tract_gpu::ops::change_axes::GpuAxisOp::new(AxisOp::Move(rank - 2, rank - 1)); let perm_a_name = node.name.clone() + ".perm_a"; inputs[a_pos] = target.wire_node(perm_a_name, perm_a_op, &[inputs[a_pos]])?[0]; } if input_facts[0].datum_type == DatumType::F16 { - let in_cast_op = ops::MetalCast::new(DatumType::F32).unwrap(); + let in_cast_op = metal_cast_new(DatumType::F32).unwrap(); inputs[0] = target.wire_node(node.name.clone() + ".in_cast", in_cast_op, &[inputs[0]])?[0]; } @@ -441,11 +364,8 @@ fn convert_matmul_to_metal( ); let rank = input_facts[b_pos].rank(); - let perm_b_op = tract_gpu::ops::change_axes::GpuAxisOp::new( - AxisOp::Move(rank - 2, rank - 1), - "Metal", - crate::kernels::array::metal_copy_nd_dispatch, - ); + let perm_b_op = + tract_gpu::ops::change_axes::GpuAxisOp::new(AxisOp::Move(rank - 2, rank - 1)); let perm_b_name = node.name.clone() + ".perm_b"; inputs[b_pos] = target.wire_node(perm_b_name, perm_b_op, &[inputs[b_pos]])?[0]; } @@ -460,11 +380,8 @@ fn convert_matmul_to_metal( .map(|fact| fact.clarify_dt_shape().unwrap().1.len()) .unwrap(); - let perm_out_op = tract_gpu::ops::change_axes::GpuAxisOp::new( - AxisOp::Move(rank - 2, rank - 1), - "Metal", - crate::kernels::array::metal_copy_nd_dispatch, - ); + let perm_out_op = + tract_gpu::ops::change_axes::GpuAxisOp::new(AxisOp::Move(rank - 2, rank - 1)); matmul_output = target.wire_node( node.name.clone() + ".perm_out", perm_out_op, @@ -482,10 +399,10 @@ fn convert_matmul_to_metal( if out_dt != expected_dt { ensure!( - ops::MetalCast::is_supported_dt(out_dt), + kernels::array::Cast::is_supported_dt(out_dt), "Matmul output type cannot be casted to expected type" ); - let cast_op = ops::MetalCast::new(model.node_output_facts(node.id)?[0].datum_type).unwrap(); + let cast_op = metal_cast_new(model.node_output_facts(node.id)?[0].datum_type).unwrap(); matmul_output = target.wire_node(node.name.clone() + ".out_cast", cast_op, &matmul_output)? }