
Factorize 6 more ops: cast, rotate_half, softmax, rms_norm, gelu_approximate, leaky_relu #2098

Open
kali wants to merge 9 commits into main from gpu-more-ops-factorize
Conversation

@kali (Collaborator) commented Apr 2, 2026

Factorize 6 more ops: cast, rotate_half, softmax, rms_norm, gelu_approximate, leaky_relu

All 6 were byte-for-byte identical between CUDA and Metal except for the stream call. Each gets a shared GpuXxx struct in gpu/src/ops/ with a backend dispatch fn pointer.

  • cast: DispatchCastFn(input, output)
  • rotate_half: DispatchRotateHalfFn(input, output)
  • softmax: DispatchSoftmaxFn(input, axis, output)
  • rms_norm: DispatchRmsNormFn(input, axis, eps, output)
  • gelu_approximate: DispatchGeluApproximateFn(fast_impl, input, output)
  • leaky_relu: DispatchLeakyReluFn(alpha, input, output)

Deletes 12 backend op files, replaces with 6 shared files.
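The shared-struct-plus-dispatch-pointer pattern can be sketched as follows. This is a minimal CPU illustration with hypothetical names (`Tensor`, `GpuSoftmax`, `reference_softmax` are stand-ins, not tract's actual types); the real `DispatchSoftmaxFn` would launch a CUDA or Metal kernel on the backend's stream.

```rust
// Sketch: one shared op struct, generic over the backend via a plain
// function pointer matching the DispatchSoftmaxFn(input, axis, output) shape.

#[derive(Debug, Clone, PartialEq)]
struct Tensor(Vec<f32>); // stand-in for the real tensor type

type DispatchSoftmaxFn = fn(input: &Tensor, axis: usize, output: &mut Tensor);

#[derive(Clone)]
struct GpuSoftmax {
    axis: usize,
    dispatch: DispatchSoftmaxFn, // CUDA and Metal each supply their own
}

impl GpuSoftmax {
    fn eval(&self, input: &Tensor) -> Tensor {
        let mut output = Tensor(vec![0.0; input.0.len()]);
        (self.dispatch)(input, self.axis, &mut output);
        output
    }
}

// CPU stand-in for what a backend dispatch function computes.
fn reference_softmax(input: &Tensor, _axis: usize, output: &mut Tensor) {
    let max = input.0.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = input.0.iter().map(|x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    output.0 = exps.iter().map(|e| e / sum).collect();
}

fn main() {
    let op = GpuSoftmax { axis: 0, dispatch: reference_softmax };
    let out = op.eval(&Tensor(vec![1.0, 2.0, 3.0]));
    let sum: f32 = out.0.iter().sum();
    assert!((sum - 1.0).abs() < 1e-6);
}
```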

@kali force-pushed the gpu-more-ops-factorize branch from e17c777 to fb21b50 on April 3, 2026 at 06:27
kali added 9 commits April 3, 2026 16:05
Factorize 6 more ops: cast, rotate_half, softmax, rms_norm, gelu_approximate, leaky_relu

All 6 were byte-for-byte identical between CUDA and Metal except for
the stream call. Each gets a shared GpuXxx struct in gpu/src/ops/ with
a backend dispatch fn pointer.

- cast: DispatchCastFn(input, output)
- rotate_half: DispatchRotateHalfFn(input, output)
- softmax: DispatchSoftmaxFn(input, axis, output)
- rms_norm: DispatchRmsNormFn(input, axis, eps, output)
- gelu_approximate: DispatchGeluApproximateFn(fast_impl, input, output)
- leaky_relu: DispatchLeakyReluFn(alpha, input, output)

Deletes 12 backend op files, replaces with 6 shared files.
Both ops were identical between backends. GpuApplyRope takes 3 input
tensors (input, cos, sin). GpuScaledMaskedSoftmax takes input plus a mask,
with a scale parameter.
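A minimal CPU sketch of the rotate-half RoPE formulation the shared GpuApplyRope wraps (the three-input shape is from the commit; the exact kernel math is an assumption based on the standard `out = input * cos + rotate_half(input) * sin` convention, simplified here to a flat 1-D slice):

```rust
// rotate_half: negate the second half and swap it in front of the first.
fn rotate_half(x: &[f32]) -> Vec<f32> {
    let h = x.len() / 2;
    let mut out = Vec::with_capacity(x.len());
    out.extend(x[h..].iter().map(|v| -v)); // -x[h..], then
    out.extend_from_slice(&x[..h]);        //  x[..h]
    out
}

// The 3-input apply-rope computation: input, cos, sin -> output.
fn apply_rope(input: &[f32], cos: &[f32], sin: &[f32]) -> Vec<f32> {
    let rot = rotate_half(input);
    input
        .iter()
        .zip(&rot)
        .zip(cos.iter().zip(sin))
        .map(|((x, r), (c, s))| x * c + r * s)
        .collect()
}

fn main() {
    // with cos = 1 and sin = 0, rope is the identity
    let out = apply_rope(&[1.0, 2.0], &[1.0, 1.0], &[0.0, 0.0]);
    assert_eq!(out, vec![1.0, 2.0]);
}
```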
Replace the dual op enumeration (can_translate_to_*_op check + translate_node
if/else chain) with a single try_make_*_op function that validates and constructs
the GPU op in one pass. This eliminates the maintenance burden of keeping two
lists in sync and fixes a latent bug where ScaledMaskedSoftmax's post_softmax_mask
check was missing from can_translate in both backends.
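The shape of that refactor, sketched with hypothetical enum names (`Node`, `GpuOp`, and the field layout are illustrative, not tract's types): validation and construction live in one function returning `Option`, so a check like the `post_softmax_mask` rejection cannot drift out of a separate `can_translate` list.

```rust
#[derive(Debug, PartialEq)]
enum Node {
    Softmax { axis: usize },
    ScaledMaskedSoftmax { scale: f32, post_softmax_mask: bool },
}

#[derive(Debug, PartialEq)]
enum GpuOp {
    Softmax { axis: usize },
    ScaledMaskedSoftmax { scale: f32 },
}

// Single pass: None means "this node cannot be translated to a GPU op".
fn try_make_gpu_op(node: &Node) -> Option<GpuOp> {
    match node {
        Node::Softmax { axis } => Some(GpuOp::Softmax { axis: *axis }),
        Node::ScaledMaskedSoftmax { scale, post_softmax_mask } => {
            // the validity check sits next to the construction it guards
            if *post_softmax_mask {
                return None;
            }
            Some(GpuOp::ScaledMaskedSoftmax { scale: *scale })
        }
    }
}

fn main() {
    assert_eq!(
        try_make_gpu_op(&Node::Softmax { axis: 1 }),
        Some(GpuOp::Softmax { axis: 1 })
    );
    assert_eq!(
        try_make_gpu_op(&Node::ScaledMaskedSoftmax { scale: 0.5, post_softmax_mask: true }),
        None
    );
}
```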
Replace the monolithic if-chain in try_make_*_op with a TypeId→TranslateFn
HashMap built from inventory at startup. Each kernel file registers its own
translator via register_cuda_op!/register_metal_op! macros right next to the
dispatch function. Adding a new GPU op no longer requires touching transform.rs.
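The lookup side of that registry can be sketched with std types only. In the PR the map is populated via the `inventory` crate and `register_cuda_op!`/`register_metal_op!` macros placed next to each kernel; here it is built by hand for illustration, with hypothetical `Softmax`/`GpuOp` stand-ins.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;

#[derive(Debug, PartialEq)]
enum GpuOp {
    Softmax { axis: usize },
}

struct Softmax {
    axis: usize,
}

// One translator per concrete op type, keyed by its TypeId.
type TranslateFn = fn(&dyn Any) -> Option<GpuOp>;

fn translate_softmax(op: &dyn Any) -> Option<GpuOp> {
    let sm = op.downcast_ref::<Softmax>()?;
    Some(GpuOp::Softmax { axis: sm.axis })
}

// In the PR this is assembled from inventory at startup; built manually here.
fn build_registry() -> HashMap<TypeId, TranslateFn> {
    let mut map: HashMap<TypeId, TranslateFn> = HashMap::new();
    map.insert(TypeId::of::<Softmax>(), translate_softmax);
    map
}

fn main() {
    let registry = build_registry();
    let node = Softmax { axis: 1 };
    let translate = registry[&node.type_id()];
    assert_eq!(translate(&node), Some(GpuOp::Softmax { axis: 1 }));
}
```

The `TypeId` key replaces the monolithic if-chain: dispatch is a single hash lookup, and registering a new op is a one-line addition next to its kernel.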
Add copy_nd as a required method on DeviceContext, with assign_slice,
copy_with_origins, and flat_copy as default methods built on top of it.
Remove DispatchCopyNdFn and backend_name from all 7 copy-based ops
(MultiBroadcastTo, AxisOp, Slice, TypedConcat, DynKeyValueCache, Delay,
PulsePad) — they are now fully generic in the gpu crate with zero
backend-specific code. A shared try_make_copy_based_op constructs them.
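The required-method-plus-defaults pattern looks roughly like this. Method names follow the commit; the signatures and the 1-D simplification are illustrative assumptions, not tract's actual `DeviceContext` API.

```rust
trait DeviceContext {
    /// Required primitive: copy `len` elements from src at src_origin
    /// to dst at dst_origin (reduced to 1-D here for brevity; the real
    /// copy_nd is a strided N-d copy).
    fn copy_nd(&self, src: &[f32], src_origin: usize,
               dst: &mut [f32], dst_origin: usize, len: usize);

    /// Default helper built on the required primitive.
    fn flat_copy(&self, src: &[f32], dst: &mut [f32]) {
        self.copy_nd(src, 0, dst, 0, src.len());
    }

    /// Another default: copy into dst at an offset (a stand-in for
    /// assign_slice / copy_with_origins in the PR).
    fn assign_slice(&self, src: &[f32], dst: &mut [f32], offset: usize) {
        self.copy_nd(src, 0, dst, offset, src.len());
    }
}

struct CpuContext;

impl DeviceContext for CpuContext {
    // A backend only implements the one primitive; the helpers come free.
    fn copy_nd(&self, src: &[f32], src_origin: usize,
               dst: &mut [f32], dst_origin: usize, len: usize) {
        dst[dst_origin..dst_origin + len]
            .copy_from_slice(&src[src_origin..src_origin + len]);
    }
}

fn main() {
    let ctx = CpuContext;
    let mut dst = vec![0.0; 4];
    ctx.assign_slice(&[1.0, 2.0], &mut dst, 2);
    assert_eq!(dst, vec![0.0, 0.0, 1.0, 2.0]);
}
```

Because the copy-based ops only ever call these methods, they can live once in the gpu crate and stay byte-identical across CUDA and Metal.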
The typedef-based approach failed because decltype(iff_generic<float>)
captures float* buffer types, which don't match uint8_t/uint16_t/etc
instantiations. Spell out the full signature in the macro instead.
@kali force-pushed the gpu-more-ops-factorize branch from 609c83c to f09823d on April 3, 2026 at 16:09