Enable ROCm Triton backend for AllReduce #684
mfrancepillois wants to merge 9 commits into rocm-jaxlib-v0.9.1
Conversation
Force-pushed 08ee4a7 to 1daa0d9
Resolved review threads (most now outdated) on:
- xla/backends/gpu/codegen/triton/transforms/triton_xla_lower_extern_atomics_pass.cc (4 threads)
- xla/backends/gpu/codegen/triton/transforms/triton_xla_implement_extern_atomics_rocm_pass.cc (3 threads)
Claude Code Review Summary
This PR enables the ROCm Triton backend for AllReduce collective operations via `extern_elementwise` lowering. Key issues found (see inline comments for details):
- Missing test coverage for …
Re-review Summary
Re-reviewed the latest diff. All 8 previously flagged issues remain unaddressed — no new findings. Key open items: see existing inline comments for details.
🤖 Generated with Claude Code
Force-pushed a866b69 to e885196
Re-review Summary
Excellent progress — 7 of 8 previous findings have been addressed in this revision. One minor item remains (see inline comment). Overall this is looking solid: the two-stage lowering design is clean and the ROCm atomics implementation is well-structured.
🤖 Generated with Claude Code
Resolved review threads (now outdated) on:
- xla/backends/gpu/codegen/triton/transforms/triton_xla_lower_extern_atomics_pass.cc (2 threads)
- xla/backends/gpu/codegen/triton/transforms/triton_xla_implement_extern_atomics_rocm_pass.cc (4 threads)
- xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_lower_extern_atomics.mlir (1 thread)
Re-review Summary
Good progress since the last review — 6 of 8 previous findings have been addressed (mask handling, validation, syncscope, signed/unsigned comparison, scope parsing). The include ordering issue remains open. New findings: see inline comments.
🤖 Generated with Claude Code
Re-review Summary
Re-reviewed the latest diff. All previously flagged issues remain applicable — no new findings. Key outstanding items: see the prior review.
🤖 Generated with Claude Code
Re-review Summary (commit 90ec42c)
All previously flagged issues have been addressed in this revision; key fixes verified. No new issues found. One minor prior note (include ordering) remains.
🤖 Generated with Claude Code
Resolved review thread on:
- xla/backends/gpu/codegen/triton/transforms/triton_xla_lower_extern_get_tid_pass.cc
```cpp
#if defined(TENSORFLOW_USE_ROCM)
// ROCm: Use constant value directly as ROCDL dialect doesn't define memory
// space enum
static constexpr int32_t kGlobalAddressSpace = 1;
```

Per https://mlir.llvm.org/docs/Dialects/GPU/#gpu-address-spaces, this should be the same for both targets.
```cpp
bool is_supported = false;

// CUDA: Requires compute capability 9.0+ (Hopper or newer)
if (device_info.cuda_compute_capability().major >= 9) {
```

Where did this check exist before this change?
Resolved review thread (now outdated) on:
- xla/backends/gpu/codegen/triton/transforms/triton_xla_implement_extern_atomics_rocm_pass.cc
```cpp
// Atomic block: perform atomic exchange
builder.setInsertionPointToStart(atomic_block);
auto atomic_xchg = LLVM::AtomicRMWOp::create(
```

Use an atomic store instead (https://mlir.llvm.org/docs/Dialects/LLVM/#llvmstore-llvmstoreop); an atomicrmw xchg is not optimal here. Use https://mlir.llvm.org/docs/Dialects/LLVM/#llvmmlirpoison-llvmpoisonop for the result: we expect the result not to be used, and if it is used, that is a bug.
```cpp
// Loop block: spin wait
builder.setInsertionPointToStart(loop_block);
auto loaded = LLVM::LoadOp::create(
    builder, loc, i32_type, addr, 4, false, false, false, false,
```

Is this an atomic load? It is hard to follow; maybe inline-comment (`/* arg_name */ false`) each of the bool args.

(Never mind: I can see from the test that it is, but a comment would still help.)
```cpp
    mlir::ValueRange{exit_block->getArgument(0)});
call_op.erase();
} else {
  // Unmasked spin wait: direct loop
```

Can you unify these two paths? Maybe via a lambda that you can call in both places.
```cpp
// Clean up unused extern function declarations
llvm::SmallVector<LLVM::LLVMFuncOp> to_erase;
module.walk([&](LLVM::LLVMFuncOp func) {
```
```cpp
// Function names follow pattern: xla_atomic_*_<semantic>_<scope>[_<comparator>]
std::string ParseSyncScope(const std::string& func_name) {
  // Per AMDGPU memory model (Table 31):
  //  - "" (empty) = system scope (cross-device visibility)
```
```mlir
// CHECK-NOT: llvm.call @xla_get_thread_id
// CHECK: [[TID:%.*]] = llvm.call_intrinsic "llvm.amdgcn.workitem.id.x"() : () -> i32
// CHECK: llvm.return [[TID]]
%tid = llvm.call @xla_get_thread_id() : () -> i32
```

Maybe prefix them with `__triton_xla`.
```cpp
auto value = operands[1];
mlir::Value mask = operands.size() > 2 ? operands[2] : mlir::Value{};

std::string syncscope = ParseSyncScope(callee_name);
```
```cpp
// Exit block: phi node to select result
exit_block->addArgument(i32_type, loc);
call_op.replaceAllUsesWith(
    mlir::ValueRange{exit_block->getArgument(0)});
```

Is the result expected to be used? If not, poison it.
Resolved review threads on:
- xla/backends/gpu/codegen/triton/transforms/triton_xla_lower_extern_atomics_pass.cc (outdated)
- xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_implement_extern_atomics_rocm.mlir (outdated)
- xla/backends/gpu/codegen/triton/transforms/triton_xla_implement_extern_atomics_rocm_pass.cc
Re-review Summary (Claude)
All 15 previously flagged issues have been addressed — nice work on the iteration. New findings this round: see inline comments. The critical item (#1) should be resolved before merge — either fix the implementation to use … or adjust the test.
```cpp
@@ -0,0 +1,276 @@
/* Copyright 2025 The OpenXLA Authors.
```

OK, this is a reasonably conservative approach. Have you checked whether CUDA actually needs it? Does the generated PTX differ that much when going through LLVM IR?

No, I haven't. I struggled a bit to test on a CUDA target. But is it really up to us to modify the CUDA-specific compilation passes? I assume they probably had a reason for preferring inline PTX over LLVM intrinsics when they implemented this pass in the first place.

Yes, because it is easy for them to use PTX, and every time a CUDA/ROCm split is made, one side bitrots: ours. Try asking Chao for an NVIDIA machine to check on.

In my view, it's simpler to maintain one pass dedicated exclusively to ROCm and another dedicated to CUDA, rather than a single pass (especially since the Triton pipeline changes from one target to another anyway). But anyway, I updated the support and compared the generated PTX, which is similar:
1. Thread ID retrieval (line ~67 in both files)

Old (inline PTX assembly):
```
// begin inline asm
mov.u32 %r8, %tid.x;
// end inline asm
```
New (intrinsics):
```
mov.u32 %r2, %tid.x;
```

2. Atomic store (lines ~75-80)

Old (inline PTX assembly):
```
// begin inline asm
st.global.sys.release.u32 [%rd25], %r12;
// end inline asm
```
New (intrinsics):
```
st.release.sys.global.b32 [%rd30], %r4;
```

3. Atomic spin-wait loop (lines ~85-95)

Old (inline PTX assembly):
```
// begin inline asm
{
  .reg .pred %p<1>;
  .reg .b32 %r<1>;
wait:
  ld.global.sys.acquire.u32 %r0, [%rd28];
  setp.lt.u32 %p0, %r0, %r12;
  @%p0 bra wait;
}
// end inline asm
```
New (intrinsics):
```
$L__BB0_2:
  ld.acquire.sys.global.b32 %r14, [%rd7];
  setp.lt.u32 %p2, %r14, %r4;
  @%p2 bra $L__BB0_2;
```
Force-pushed 37396a3 to 0cc93da
```cpp
if (func_name.contains("_system")) {
  return "";  // System scope for cross-GPU visibility
} else if (func_name.contains("_gpu")) {
  return "gpu";
```

Are you sure these are correct? I would expect "device" and "block". But you are right, this is cumbersome. Sorry for driving you around. Maybe stick to a per-arch pass, but make this common logic that both passes can call, passing in the scope names and the threadIdx.x intrinsic name?

The scope names have been changed ("gpu" and "cta" are indeed aliases of "device" and "block"). But on the second point, I'm not sure I follow what you would like to have: two different passes using intrinsics?

@draganmladjenovic Maybe we could open a PR upstream with this support (a common pass) and see what the Google/NVIDIA folks say about it?
Force-pushed 0cc93da to e5310e5
📝 Summary of Changes
This PR enables the ROCm Triton backend for AllReduce (collective emitter). To this end, it:
- Adds passes lowering `triton_xla.atomic_write`, `triton_xla.atomic_spin_wait`, and `triton_xla.get_tid`. These passes rely on `extern_elementwise` Triton operations, thereby avoiding the use of target-specific inline assembly. The `extern_elementwise` ops are then caught later in the compiler pipeline and replaced by LLVM intrinsics.
- Implements `RocmExecutor::CanEnablePeerAccessTo(int other_device_ordinal)` (this API is required to enable the collective_emitter thunk).

🎯 Justification
Prior to this PR, the `triton_xla.get_tid`, `triton_xla.atomic_write`, and `triton_xla.atomic_spin_wait` operations were lowered using PTX assembly; therefore, the AllReduce Triton backend was only available for the CUDA target. This PR adds a new way to lower these operations using only Triton `extern_elementwise` operations. Thanks to that, the Triton backend for AllReduce is now available for the ROCm target.

🚀 Kind of Contribution
✨ New Feature

🧪 Unit Tests
This PR includes a LIT test checking the lowering of atomic operations.