Enable ROCm Triton backend for AllReduce#684

Open
mfrancepillois wants to merge 9 commits into rocm-jaxlib-v0.9.1 from ci_maxime_allreduce_triton_rocm_elementwise_rocm

Conversation

@mfrancepillois

📝 Summary of Changes
This PR enables the ROCm Triton backend for AllReduce (collective emitter).
To this end:

  • Add new passes to lower the atomic operations triton_xla.atomic_write, triton_xla.atomic_spin_wait, and triton_xla.get_tid. These passes rely on Triton extern_elementwise operations, thereby avoiding target-specific inline assembly. The extern_elementwise ops are caught later in the compiler pipeline and replaced by LLVM intrinsics.
  • Add the missing RocmExecutor::CanEnablePeerAccessTo(int other_device_ordinal) API to rocm_executor (this API is required to enable the collective_emitter thunk).

🎯 Justification
Prior to this PR, the triton_xla.get_tid, triton_xla.atomic_write, and triton_xla.atomic_spin_wait operations were lowered using PTX assembly, so the AllReduce Triton backend was only available for the CUDA target.
This PR adds a new way to lower these operations using only Triton operations (extern_elementwise).
As a result, the Triton backend for AllReduce is now available for the ROCm target.

🚀 Kind of Contribution
✨ New Feature

🧪 Unit Tests:
This PR includes a LIT test checking the lowering of atomic operations.

@mfrancepillois mfrancepillois force-pushed the ci_maxime_allreduce_triton_rocm_elementwise_rocm branch from 08ee4a7 to 1daa0d9 Compare March 18, 2026 14:51
@claude

claude bot commented Mar 18, 2026

Claude Code Review Summary

This PR enables the ROCm Triton backend for AllReduce collective operations via tt.extern_elementwise ops and a ROCm-specific LLVM IR implementation pass. The overall architecture — two-stage lowering (high-level ops → extern calls → platform-specific LLVM atomics) — is clean and extensible.

Key issues found (see inline comments for details):

  • Masks silently ignored in both LowerAtomicWriteOp and LowerAtomicSpinWaitOp — the CUDA pass handles masks via predicated stores/waits, but the extern lowering discards them entirely. The spin wait case can cause GPU hangs.
  • Signed vs unsigned comparison mismatch — the ROCm spin wait uses slt (signed) while the CUDA PTX uses setp.lt.u32 (unsigned), which can produce incorrect results for large signal values.
  • Syncscope "one-as" for system scope may not provide cross-device visibility guarantees on AMDGPU — needs verification against the AMDGPU memory model docs.
  • No validation of memory semantics — the CUDA pass rejects unsupported semantics (e.g., ACQUIRE for writes); this pass silently accepts them.
  • default: return "unknown" in enum-to-string switches will generate unmatchable function names instead of failing early.
  • Minor: duplicated scope-parsing logic, fragile substring matching, misplaced #include.

Missing test coverage for TritonXLAImplementExternAtomicsROCmPass and TritonXLALowerExternGetTidPass.
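The signed/unsigned mismatch flagged above is easy to demonstrate outside MLIR. A minimal C++ sketch (function names and values are illustrative, not taken from the PR):

```cpp
#include <cstdint>

// What the CUDA PTX `setp.lt.u32` computes: unsigned less-than.
bool KeepSpinningUnsigned(uint32_t loaded, uint32_t target) {
  return loaded < target;
}

// What a signed `icmp slt` on the same 32 bits computes.
bool KeepSpinningSigned(uint32_t loaded, uint32_t target) {
  return static_cast<int32_t>(loaded) < static_cast<int32_t>(target);
}
```

With loaded = 0x80000001 (a signal counter past INT32_MAX) and target = 0x10, the unsigned compare stops spinning, while the signed compare sees a negative loaded value and keeps spinning.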

@claude

claude bot commented Mar 18, 2026

Re-review Summary

Re-reviewed the latest diff. All 8 previously flagged issues remain unaddressed — no new findings. Key open items:

  • Masks silently discarded in both LowerAtomicWriteOp and LowerAtomicSpinWaitOp (risk of GPU hang / memory corruption)
  • Signed vs unsigned comparison mismatch with CUDA path in spin-wait
  • Incorrect AMDGPU syncscopes ("one-as" alone likely insufficient for system scope)
  • No semantic validation and hardcoded atomic orderings ignoring function-name-encoded semantics
  • Fragile substring matching and duplicated scope-parsing logic
  • Missing test coverage for TritonXLAImplementExternAtomicsROCmPass and TritonXLALowerExternGetTidPass
  • Include ordering issue in rocm_executor.cc

See existing inline comments for details.

🤖 Generated with Claude Code

@mfrancepillois mfrancepillois force-pushed the ci_maxime_allreduce_triton_rocm_elementwise_rocm branch from a866b69 to e885196 Compare March 18, 2026 16:00
@claude

claude bot commented Mar 18, 2026

Re-review Summary

Excellent progress — 7 of 8 previous findings have been addressed in this revision:

  • Mask handling: Both atomic_write and atomic_spin_wait now properly pass and handle masks
  • Exhaustive switches: default cases replaced with exhaustive switch + LOG(FATAL)
  • Semantic validation: Unsupported semantics now rejected via notifyMatchFailure
  • Sync scope: System scope corrected to empty string per AMDGPU convention
  • Unsigned comparison: Spin-wait comparator now uses ICmpPredicate::ult matching CUDA
  • Helper extraction: ParseSyncScope extracted; EndsWithComparator used for suffix matching

1 minor item remaining: #include <unistd.h> placement in rocm_executor.cc (between ROCm and XLA headers — should be in the system headers group).

Overall this is looking solid. The two-stage lowering design is clean and the ROCm atomics implementation is well-structured.

🤖 Generated with Claude Code

@claude

claude bot commented Mar 18, 2026

Re-review Summary

Good progress since the last review — 6 of 8 previous findings have been addressed (mask handling, validation, syncscope, signed/unsigned comparison, scope parsing). The include ordering issue in rocm_executor.cc is still present.

New findings (see inline comments):

  • Bug (high): SSA domination issue in masked and unmasked spin-wait — replaceAllUsesWith references a value not defined on all control-flow paths (triton_xla_implement_extern_atomics_rocm_pass.cc:213)
  • Bug (high): Scalar atomic write path silently discards the mask operand (triton_xla_lower_extern_atomics_pass.cc:162-169)
  • Nit: Memory ordering hardcoded to release/acquire even when relaxed was requested — safe but suboptimal
  • Nit: Substring matching in ParseSyncScope is fragile
  • Nit: Mask detection by operand count without documentation
  • Testing gap: No tests for masked operations, ROCm implementation pass, or GetTid pass

🤖 Generated with Claude Code

@claude

claude bot commented Mar 18, 2026

Re-review Summary

Re-reviewed the latest diff. All previously flagged issues remain applicable — no new findings. Key outstanding items from prior review:

  • SSA dominance bug in spin-wait loop (triton_xla_implement_extern_atomics_rocm_pass.cc:213,248) — loaded value defined inside loop block used after exit
  • Hardcoded release ordering ignoring _relaxed_ semantic in function names
  • Missing test coverage for the ROCm-specific atomics pass and the get_tid pass

🤖 Generated with Claude Code

@claude

claude bot commented Mar 18, 2026

Re-review Summary (commit 90ec42c)

All previously flagged issues have been addressed in this revision. Key fixes verified:

  • Mask handling: Scalar and vectorized mask paths now handled correctly with proper operand count checks
  • Memory ordering: Parsed from function names instead of hardcoded
  • SSA domination: Fixed via block arguments
  • Sync scope parsing: Uses robust field extraction
  • Unsigned comparison: Spin-wait loop now uses unsigned icmp as required
  • Test coverage: Comprehensive tests added for masked/unmasked operations in both lowering and ROCm implementation passes

No new issues found. One minor prior note (include ordering of <unistd.h> in rocm_executor.cc) remains outstanding from the initial review.

🤖 Generated with Claude Code

#if defined(TENSORFLOW_USE_ROCM)
// ROCm: Use constant value directly as ROCDL dialect doesn't define memory
// space enum
static constexpr int32_t kGlobalAddressSpace = 1;


bool is_supported = false;

// CUDA: Requires compute capability 9.0+ (Hopper or newer)
if (device_info.cuda_compute_capability().major >= 9) {


Where did this check exist before this change?


// Atomic block: perform atomic exchange
builder.setInsertionPointToStart(atomic_block);
auto atomic_xchg = LLVM::AtomicRMWOp::create(


We expect the result not to be used; if it is, that is a bug.

// Loop block: spin wait
builder.setInsertionPointToStart(loop_block);
auto loaded = LLVM::LoadOp::create(
builder, loc, i32_type, addr, 4, false, false, false, false,


Is it an atomic load? It is hard to follow. Maybe inline-comment (/* arg_name */ false) each of the bool args.


NVM, I see it from the test; still, please add the comments.
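For readers following the builder calls quoted above: the generated loop is morally an acquire spin-wait paired with a release write. A C++ sketch of the intended semantics (names are illustrative, assuming the unsigned compare and release/acquire pairing discussed in this review):

```cpp
#include <atomic>
#include <cstdint>

// Spin until the signal slot reaches `target`. The acquire load guarantees
// that writes made before the matching release store are visible once the
// wait completes; the comparison is unsigned, like the PTX `setp.lt.u32`.
void SpinWait(std::atomic<uint32_t>& slot, uint32_t target) {
  while (slot.load(std::memory_order_acquire) < target) {
    // busy-wait until another thread/device bumps the slot
  }
}

// The counterpart of the emitted atomic write: a release store.
void SignalWrite(std::atomic<uint32_t>& slot, uint32_t value) {
  slot.store(value, std::memory_order_release);
}
```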

mlir::ValueRange{exit_block->getArgument(0)});
call_op.erase();
} else {
// Unmasked spin wait: direct loop


Can you unify these two paths? Maybe via a lambda that you can call in both places.


// Clean up unused extern function declarations
llvm::SmallVector<LLVM::LLVMFuncOp> to_erase;
module.walk([&](LLVM::LLVMFuncOp func) {


Nice touch!

// Function names follow pattern: xla_atomic_*_<semantic>_<scope>[_<comparator>]
std::string ParseSyncScope(const std::string& func_name) {
// Per AMDGPU memory model (Table 31):
// - "" (empty) = system scope (cross-device visibility)


Same for nvptx backend.
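To make the suggestion concrete: matching whole '_'-separated fields avoids the fragile substring matching flagged earlier (e.g. "_system" matching inside another token). A hedged C++ sketch; the gpu/cta to agent/workgroup mapping here is an assumption for illustration, not taken verbatim from the pass:

```cpp
#include <string>
#include <vector>

// Split a function name like "xla_atomic_write_release_system" into its
// '_'-separated fields.
std::vector<std::string> SplitFields(const std::string& name) {
  std::vector<std::string> fields;
  size_t start = 0;
  while (start <= name.size()) {
    size_t end = name.find('_', start);
    if (end == std::string::npos) end = name.size();
    fields.push_back(name.substr(start, end - start));
    start = end + 1;
  }
  return fields;
}

// Map the scope token to an AMDGPU syncscope string. Per the AMDGPU memory
// model, the empty syncscope means system scope (cross-device visibility).
// The "gpu" -> "agent" and "cta" -> "workgroup" mappings are illustrative.
std::string ParseSyncScope(const std::string& func_name) {
  for (const std::string& field : SplitFields(func_name)) {
    if (field == "system") return "";
    if (field == "gpu") return "agent";
    if (field == "cta") return "workgroup";
  }
  return "";  // default: system scope
}
```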

// CHECK-NOT: llvm.call @xla_get_thread_id
// CHECK: [[TID:%.*]] = llvm.call_intrinsic "llvm.amdgcn.workitem.id.x"() : () -> i32
// CHECK: llvm.return [[TID]]
%tid = llvm.call @xla_get_thread_id() : () -> i32


maybe prefix them with __triton_xla

auto value = operands[1];
mlir::Value mask = operands.size() > 2 ? operands[2] : mlir::Value{};

std::string syncscope = ParseSyncScope(callee_name);


StringRefs (consider llvm::StringRef instead of std::string here).

// Exit block: phi node to select result
exit_block->addArgument(i32_type, loc);
call_op.replaceAllUsesWith(
mlir::ValueRange{exit_block->getArgument(0)});


Is the result expected to be used? If not, poison it.

@i-chaochen i-chaochen added the claude-review Request a Claude AI code review for this PR label Mar 24, 2026
@claude

claude bot commented Mar 24, 2026

Re-review Summary (Claude)

All 15 previously flagged issues have been addressed — nice work on the iteration.

New findings this round:

  1. [Critical] ROCm atomic write implementation uses LLVM::StoreOp but tests expect llvm.atomicrmw xchg — these are different instructions and the masked path also lacks block arguments for PHI merging. The LIT tests will fail as written. (inline comment)

  2. [Suggestion] Stale comment in triton_xla_lower_extern_atomics_pass.cc still says "ROCm-specific" but the lowering is now target-agnostic. (inline comment)

  3. [Suggestion] Missing LIT test for the CUDA implementation pass (triton_xla_implement_extern_atomics_cuda_pass.cc). (inline comment)

The critical item (#1) should be resolved before merge — either fix the implementation to use AtomicRMWOp or update the tests.
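On item 1, the behavioral difference is visible even in plain C++ with std::atomic: an atomic store writes a value and returns nothing, while an exchange also returns the previous contents, which is what llvm.atomicrmw xchg produces. A sketch with illustrative names:

```cpp
#include <atomic>
#include <cstdint>

// Writes `value` atomically; there is no result to consume. This is the
// shape of an LLVM atomic `store ... release`.
void AtomicWrite(std::atomic<uint32_t>& slot, uint32_t value) {
  slot.store(value, std::memory_order_release);
}

// Writes `value` atomically AND returns the previous contents. This is the
// shape of an LLVM `atomicrmw xchg`.
uint32_t AtomicExchange(std::atomic<uint32_t>& slot, uint32_t value) {
  return slot.exchange(value, std::memory_order_release);
}
```

Both write the same value into memory, so if the result is never consumed the runtime behavior coincides; LIT tests, however, match the textual IR and will fail if the implementation emits one form while the CHECK lines expect the other.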

@github-actions github-actions bot removed the claude-review Request a Claude AI code review for this PR label Mar 24, 2026
@@ -0,0 +1,276 @@
/* Copyright 2025 The OpenXLA Authors.


OK, this is a reasonably conservative approach. Have you checked whether CUDA actually needs it? Does the generated PTX differ that much when going through LLVM IR?

Author

No, I haven't. I struggled a bit to test on the CUDA target. But is it really up to us to modify the CUDA-specific compilation passes? I assume they probably had a reason for preferring inline PTX over LLVM intrinsics when they implemented this pass in the first place.


Yes, because it is easy for them to use PTX. And every time one does a CUDA/ROCm split, one side bitrots: ours. Try asking Chao for an NV machine to check on.

Author

In my view, it's simpler to maintain one pass dedicated exclusively to ROCm and another dedicated to CUDA rather than a single pass (especially since the Triton pipeline changes from one target to another anyway). Anyway, I updated the support and compared the generated PTX, which is similar:

1. Thread ID Retrieval (Line ~67 in both files)

Old (PTX Assembly):

// begin inline asm

    mov.u32 %r8, %tid.x;
  
// end inline asm

New (Intrinsics):

mov.u32 	%r2, %tid.x;

2. Atomic Store (Line ~75-80)

Old (PTX Assembly):

// begin inline asm

    st.global.sys.release.u32 [%rd25], %r12;
  
// end inline asm

New (Intrinsics):

st.release.sys.global.b32 	[%rd30], %r4;

3. Atomic Spin-Wait Loop (Line ~85-95)

Old (PTX Assembly):

// begin inline asm

    {
    .reg .pred %p<1>;
    .reg .b32 %r<1>;
    wait:
      ld.global.sys.acquire.u32 %r0, [%rd28];
      setp.lt.u32 %p0, %r0, %r12;
      @%p0 bra wait;
    }
  
// end inline asm

New (Intrinsics):

$L__BB0_2:
	ld.acquire.sys.global.b32 	%r14, [%rd7];
	setp.lt.u32 	%p2, %r14, %r4;
	@%p2 bra 	$L__BB0_2;

@mfrancepillois mfrancepillois marked this pull request as ready for review March 24, 2026 12:22
@mfrancepillois mfrancepillois force-pushed the ci_maxime_allreduce_triton_rocm_elementwise_rocm branch 3 times, most recently from 37396a3 to 0cc93da Compare March 25, 2026 17:44
if (func_name.contains("_system")) {
return ""; // System scope for cross-GPU visibility
} else if (func_name.contains("_gpu")) {
return "gpu";


Are you sure these are correct? I would expect "device" and "block". But you are right, this is cumbersome. Sorry for driving you around. Maybe stick to a per-arch pass, but have this be common logic that both can call, passing in the scope names and the threadIdx.x intrinsic name?

Author

Scope names have been changed (gpu and cta are indeed aliases for device and block).
But for the second point, I'm not sure I follow what you would like to have: two different passes using intrinsics?

Author

@draganmladjenovic Maybe we could open a PR upstream with this support (common pass) and see what google/nvidia guys say about it?

@mfrancepillois mfrancepillois force-pushed the ci_maxime_allreduce_triton_rocm_elementwise_rocm branch from 0cc93da to e5310e5 Compare March 26, 2026 11:44