Factorize 6 more ops: cast, rotate_half, softmax, rms_norm, gelu_approximate, leaky_relu #2098
Open
Conversation
Force-pushed from e17c777 to fb21b50.
All 6 were byte-for-byte identical between CUDA and Metal except for the stream call. Each gets a shared GpuXxx struct in gpu/src/ops/ with a backend dispatch fn pointer:

- cast: DispatchCastFn(input, output)
- rotate_half: DispatchRotateHalfFn(input, output)
- softmax: DispatchSoftmaxFn(input, axis, output)
- rms_norm: DispatchRmsNormFn(input, axis, eps, output)
- gelu_approximate: DispatchGeluApproximateFn(fast_impl, input, output)
- leaky_relu: DispatchLeakyReluFn(alpha, input, output)

Deletes 12 backend op files and replaces them with 6 shared files.
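The shared-struct-plus-fn-pointer shape can be sketched as follows. This is a minimal illustration, not the PR's actual code: the real dispatch functions enqueue a CUDA or Metal kernel on a stream, and the real structs carry the crate's tensor types; the CPU stand-in here just applies leaky ReLU directly.

```rust
// Illustrative shape of one shared op: the struct lives in gpu/src/ops/
// and holds a plain fn pointer supplied by whichever backend is active.
type DispatchLeakyReluFn = fn(alpha: f32, input: &[f32], output: &mut [f32]);

struct GpuLeakyRelu {
    alpha: f32,
    dispatch: DispatchLeakyReluFn,
}

impl GpuLeakyRelu {
    fn eval(&self, input: &[f32], output: &mut [f32]) {
        // The only backend-specific part is behind the fn pointer.
        (self.dispatch)(self.alpha, input, output);
    }
}

// A stand-in "backend" dispatch; the real CUDA/Metal versions launch a
// kernel on their stream/queue instead of computing on the CPU.
fn cpu_leaky_relu(alpha: f32, input: &[f32], output: &mut [f32]) {
    for (o, &x) in output.iter_mut().zip(input) {
        *o = if x >= 0.0 { x } else { alpha * x };
    }
}
```

Because only the fn pointer differs, the struct, validation, and wiring can live once in the shared gpu crate.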
Both ops were identical between backends: GpuApplyRope takes 3 input tensors (input, cos, sin), while GpuScaledMaskedSoftmax takes an input plus a mask, with a scale parameter.
Replace the dual op enumeration (can_translate_to_*_op check + translate_node if/else chain) with a single try_make_*_op function that validates and constructs the GPU op in one pass. This eliminates the maintenance burden of keeping two lists in sync and fixes a latent bug where ScaledMaskedSoftmax's post_softmax_mask check was missing from can_translate in both backends.
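The single-pass pattern can be sketched like this. All names here (Node, GpuOp, the op strings) are hypothetical stand-ins for the crate's real types; the point is that validation and construction share one code path, so a precondition like the post_softmax_mask check cannot be present in one list and absent from the other.

```rust
// Minimal sketch: one function both validates and constructs,
// replacing the can_translate_to_*_op + translate_node pair.
#[derive(Debug, PartialEq)]
enum GpuOp {
    Softmax { axis: usize },
    ScaledMaskedSoftmax { scale: f32 },
}

// Hypothetical stand-in for the framework's node type.
struct Node {
    name: &'static str,
    axis: Option<usize>,
    scale: Option<f32>,
    post_softmax_mask: bool,
}

// Returns Some(op) only when every precondition holds; returning None
// rejects translation and skips construction in the same pass.
fn try_make_gpu_op(node: &Node) -> Option<GpuOp> {
    match node.name {
        "softmax" => Some(GpuOp::Softmax { axis: node.axis? }),
        // The kind of check the old can_translate list was missing:
        // a post-softmax mask cannot be fused, so bail out here.
        "scaled_masked_softmax" if !node.post_softmax_mask => {
            Some(GpuOp::ScaledMaskedSoftmax { scale: node.scale? })
        }
        _ => None,
    }
}
```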
Replace the monolithic if-chain in try_make_*_op with a TypeId→TranslateFn HashMap built from inventory at startup. Each kernel file registers its own translator via register_cuda_op!/register_metal_op! macros right next to the dispatch function. Adding a new GPU op no longer requires touching transform.rs.
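The registry shape can be sketched with the standard library alone. In the PR the map is populated from the inventory crate's distributed registration (each kernel file submits its translator via register_cuda_op!/register_metal_op!); here it is filled by hand, and the op types and return value are hypothetical simplifications.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;

// A translator takes the type-erased core op and, if it recognizes it,
// produces a GPU op (a String here, for illustration).
type TranslateFn = fn(&dyn Any) -> Option<String>;

struct Softmax { axis: usize }
struct Cast;

fn translate_softmax(op: &dyn Any) -> Option<String> {
    let s = op.downcast_ref::<Softmax>()?;
    Some(format!("GpuSoftmax(axis={})", s.axis))
}

fn translate_cast(op: &dyn Any) -> Option<String> {
    op.downcast_ref::<Cast>()?;
    Some("GpuCast".to_string())
}

fn build_registry() -> HashMap<TypeId, TranslateFn> {
    // With inventory, each kernel file would submit its own
    // (TypeId, TranslateFn) pair; this hand-built map stands in for
    // the startup collection step.
    let mut m: HashMap<TypeId, TranslateFn> = HashMap::new();
    m.insert(TypeId::of::<Softmax>(), translate_softmax);
    m.insert(TypeId::of::<Cast>(), translate_cast);
    m
}

// One HashMap lookup replaces the monolithic if-chain.
fn try_make_gpu_op(registry: &HashMap<TypeId, TranslateFn>, op: &dyn Any) -> Option<String> {
    registry.get(&op.type_id())?(op)
}
```

Adding a new op then means adding one entry (one macro invocation in the real code) next to the kernel, with no central file to edit.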
Add copy_nd as a required method on DeviceContext, with assign_slice, copy_with_origins, and flat_copy as default methods built on top of it. Remove DispatchCopyNdFn and backend_name from all 7 copy-based ops (MultiBroadcastTo, AxisOp, Slice, TypedConcat, DynKeyValueCache, Delay, PulsePad) — they are now fully generic in the gpu crate with zero backend-specific code. A shared try_make_copy_based_op constructs them.
The typedef-based approach failed because decltype(iff_generic<float>) captures float* buffer types, which don't match uint8_t/uint16_t/etc instantiations. Spell out the full signature in the macro instead.
Force-pushed from 609c83c to f09823d.