Add hipCUB support for scan (prefix sum) op#668

Open
magaonka-amd wants to merge 1 commit into ROCm:main from magaonka-amd:feature/hipcub-scan-support-downstream

Conversation

@magaonka-amd

  • Add hipCUB DeviceScan::InclusiveSum kernel implementation mirroring the CUDA cub_scan_kernel from commit 75f8001
  • Register ROCm FFI handler for xla.gpu.ext.cub_scan with runtime type dispatch for all 12 supported types
  • Add parameterized kernel test covering all types, row/column sizes, and scan configurations

Submission Checklist

@magaonka-amd magaonka-amd marked this pull request as ready for review March 16, 2026 22:10
Comment on lines +53 to +58
for (int64_t col = 0; col < column_length; ++col) {
TF_RETURN_IF_ERROR(
stream_executor::gpu::ToStatus(hipcub::DeviceScan::InclusiveSum(
d_temp_storage, temp_bytes, d_in + col * row_length,
d_out + col * row_length, row_length, stream)));
}
Performance: per-column host loop issues N separate kernel launches

This loop launches one hipcub::DeviceScan::InclusiveSum per column, which means column_length separate GPU kernel launches. For large column_length values, the launch overhead will dominate.

The CUDA counterpart avoids this by using a custom BlockScanKernel with gridDim = column_length — all columns are processed in a single kernel launch, one block per column (see cub_scan_kernel_cuda_impl.cu.cc:122).

Consider implementing a similar custom hipCUB BlockScan-based kernel for the ROCm path to achieve comparable performance, or at minimum documenting this known performance gap with a TODO.
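For reference, a minimal host-side sketch of the semantics any fused single-launch kernel must reproduce (assuming the layout used by the loop above: `column_length` independent contiguous segments of `row_length` elements each). The function name `BatchedInclusiveSum` is illustrative, not from the PR; the real kernel would map one GPU block per segment instead of the outer host loop.

```cpp
#include <cstdint>
#include <vector>

// Reference model for the batched inclusive scan: column_length independent
// segments, each row_length contiguous elements. Segment `col` starts at
// offset col * row_length, matching `d_in + col * row_length` in the PR.
std::vector<int64_t> BatchedInclusiveSum(const std::vector<int64_t>& in,
                                         int64_t row_length,
                                         int64_t column_length) {
  std::vector<int64_t> out(in.size());
  for (int64_t col = 0; col < column_length; ++col) {
    int64_t running = 0;  // prefix sum restarts at every segment boundary
    for (int64_t i = 0; i < row_length; ++i) {
      running += in[col * row_length + i];
      out[col * row_length + i] = running;
    }
  }
  return out;
}
```

A block-per-segment kernel produces exactly this output in one launch, since segments are independent and never carry a prefix across the `col * row_length` boundary.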

Collaborator

@magaonka-amd have you checked whether hipCUB provides the BlockLoad we would need?


Yes, I agree. I see they recently changed from the host-side InclusiveSum to a device-side call made from a custom BlockScanKernel: openxla#39238

Author

Thanks for the feedback, I'll explore this today and get back with the update.

Comment on lines +48 to +51
if (d_in == nullptr) {
return stream_executor::gpu::ToStatus(hipcub::DeviceScan::InclusiveSum(
d_temp_storage, temp_bytes, d_in, d_out, row_length, stream));
}

Nit: query-mode sentinel checks d_in instead of d_temp_storage

The hipCUB convention for "query scratch size" is d_temp_storage == nullptr. This code uses d_in == nullptr as the sentinel instead. It works today because CubScanGetScratchSize (line 123) passes both as nullptr, but checking d_temp_storage == nullptr would be more idiomatic and robust against future callers that might pass a valid d_temp_storage with null d_in.

Suggested change
- if (d_in == nullptr) {
+ if (d_temp_storage == nullptr) {
    return stream_executor::gpu::ToStatus(hipcub::DeviceScan::InclusiveSum(
        d_temp_storage, temp_bytes, d_in, d_out, row_length, stream));
  }

Comment on lines +91 to +93
GTEST_SKIP() << "BF16 for row length > 128 has precision issues.",
absl::OkStatus();
}

Nit: GTEST_SKIP() comma-expression in non-void function

GTEST_SKIP() expands to the start of a return statement in Google Test, so the comma expression GTEST_SKIP() << "...", absl::OkStatus() compiles to return (<record skip>, absl::OkStatus()); — the skip is recorded and OkStatus() is the value returned. This works, but it relies on the macro beginning with return (GTEST_SKIP() on its own is only usable in void-returning functions, which is why the comma with absl::OkStatus() is needed here). A brief comment would make the intent explicit:

if (type == xla::PrimitiveType::BF16 && row_length > 128) {
  // GTEST_SKIP() opens a `return` statement; the comma operator makes
  // absl::OkStatus() the returned value after the skip is recorded.
  GTEST_SKIP() << "BF16 for row length > 128 has precision issues.",
      absl::OkStatus();
}

This is inherited from the CUDA test, so feel free to leave as-is for consistency if preferred.

@claude

claude bot commented Mar 16, 2026

Review Summary

Good implementation of hipCUB-based InclusiveSum for the ROCm scan op, with clean FFI registration and thorough parameterized testing across 12 data types.

Main concern: The per-column host loop in the impl issues one hipcub::DeviceScan::InclusiveSum launch per column, whereas the CUDA counterpart uses a custom BlockScanKernel that processes all columns in a single launch (gridDim = column_length). This may cause a significant performance gap for multi-column scans.

Two minor nits posted inline (sentinel check, test skip pattern).

🤖 Generated with Claude Code

@magaonka-amd magaonka-amd force-pushed the feature/hipcub-scan-support-downstream branch from 316f257 to 015864a Compare March 18, 2026 16:09
- Add custom BlockScanKernel using rocPRIM block-level primitives
  (block_scan, block_load, block_store) for efficient batched scan
- Single kernel launch with gridDim = column_length, one block per column
- Register ROCm FFI handler for xla.gpu.ext.cub_scan with runtime type
  dispatch for all 12 supported types
- Tuning via rocPRIM default_scan_config_base (architecture-aware)
- Add parameterized kernel test and performance benchmarks
@magaonka-amd magaonka-amd force-pushed the feature/hipcub-scan-support-downstream branch from 015864a to 85b82bb Compare March 18, 2026 23:28

3 participants