Add hipCUB support for scan (prefix sum) op #668
magaonka-amd wants to merge 1 commit into ROCm:main
Conversation
```cpp
for (int64_t col = 0; col < column_length; ++col) {
  TF_RETURN_IF_ERROR(
      stream_executor::gpu::ToStatus(hipcub::DeviceScan::InclusiveSum(
          d_temp_storage, temp_bytes, d_in + col * row_length,
          d_out + col * row_length, row_length, stream)));
}
```
Performance: per-column host loop issues N separate kernel launches
This loop launches one hipcub::DeviceScan::InclusiveSum per column, which means column_length separate GPU kernel launches. For large column_length values, the launch overhead will dominate.
The CUDA counterpart avoids this by using a custom BlockScanKernel with gridDim = column_length — all columns are processed in a single kernel launch, one block per column (see cub_scan_kernel_cuda_impl.cu.cc:122).
Consider implementing a similar custom hipCUB BlockScan-based kernel for the ROCm path to achieve comparable performance, or at minimum documenting this known performance gap with a TODO.
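As a hypothetical sketch (not the PR's actual kernel; names, template parameters, and the assumption that `row_length <= BlockThreads * ItemsPerThread` are all illustrative), a hipCUB BlockScan-based kernel with one block per row might look like:

```cpp
// Hypothetical sketch of a batched scan kernel for the ROCm path -- not the
// PR's actual code. Assumes each row fits in a single block's tile
// (row_length <= BlockThreads * ItemsPerThread); a production kernel would
// tile longer rows and carry the running total between tiles.
#include <hipcub/hipcub.hpp>

template <typename T, int BlockThreads, int ItemsPerThread>
__global__ void BlockScanKernel(const T* d_in, T* d_out, int row_length) {
  using BlockLoad = hipcub::BlockLoad<T, BlockThreads, ItemsPerThread>;
  using BlockScan = hipcub::BlockScan<T, BlockThreads>;
  using BlockStore = hipcub::BlockStore<T, BlockThreads, ItemsPerThread>;
  __shared__ union {
    typename BlockLoad::TempStorage load;
    typename BlockScan::TempStorage scan;
    typename BlockStore::TempStorage store;
  } temp;

  // One block per row: blockIdx.x indexes the row being scanned.
  const T* row_in = d_in + static_cast<long long>(blockIdx.x) * row_length;
  T* row_out = d_out + static_cast<long long>(blockIdx.x) * row_length;

  T items[ItemsPerThread];
  // Out-of-bounds elements load as 0 so they do not perturb the prefix sum.
  BlockLoad(temp.load).Load(row_in, items, row_length, T(0));
  __syncthreads();
  BlockScan(temp.scan).InclusiveSum(items, items);
  __syncthreads();
  BlockStore(temp.store).Store(row_out, items, row_length);
}

// Launch with gridDim = column_length, e.g.:
//   BlockScanKernel<float, 256, 4>
//       <<<column_length, 256, 0, stream>>>(d_in, d_out, row_length);
```

All columns are then processed in a single launch, mirroring the CUDA-side structure.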
@magaonka-amd have you checked whether hipCUB provides the BlockLoad we need?
Yes, I agree. I see they have recently changed from a host-side InclusiveSum to a device-side function called from a custom BlockScanKernel: openxla#39238
Thanks for the feedback, I'll explore this today and get back with the update.
```cpp
if (d_in == nullptr) {
  return stream_executor::gpu::ToStatus(hipcub::DeviceScan::InclusiveSum(
      d_temp_storage, temp_bytes, d_in, d_out, row_length, stream));
}
```
Nit: query-mode sentinel checks d_in instead of d_temp_storage
The hipCUB convention for "query scratch size" is d_temp_storage == nullptr. This code uses d_in == nullptr as the sentinel instead. It works today because CubScanGetScratchSize (line 123) passes both as nullptr, but checking d_temp_storage == nullptr would be more idiomatic and robust against future callers that might pass a valid d_temp_storage with null d_in.
Suggested change:

```diff
-  if (d_in == nullptr) {
+  if (d_temp_storage == nullptr) {
     return stream_executor::gpu::ToStatus(hipcub::DeviceScan::InclusiveSum(
         d_temp_storage, temp_bytes, d_in, d_out, row_length, stream));
   }
```
```cpp
  GTEST_SKIP() << "BF16 for row length > 128 has precision issues.",
      absl::OkStatus();
}
```
Nit: GTEST_SKIP() comma-expression in non-void function
GTEST_SKIP() expands to a return from the current scope in Google Test. The comma-expression `GTEST_SKIP() << "...", absl::OkStatus()` works because GTEST_SKIP() returns before the comma operand is evaluated, but this relies on an implementation detail of the macro. A more explicit pattern would be:

```cpp
if (type == xla::PrimitiveType::BF16 && row_length > 128) {
  GTEST_SKIP() << "BF16 for row length > 128 has precision issues.";
  return absl::OkStatus();  // unreachable, satisfies return type
}
```

This is inherited from the CUDA test, so feel free to leave as-is for consistency if preferred.
Review Summary

Good implementation of hipCUB-based scan for the ROCm path.

Main concern: the per-column host loop in the impl issues one kernel launch per column.

Two minor nits posted inline (sentinel check, test skip pattern).

🤖 Generated with Claude Code
Force-pushed 316f257 to 015864a
- Add custom BlockScanKernel using rocPRIM block-level primitives (block_scan, block_load, block_store) for efficient batched scan
- Single kernel launch with gridDim = column_length, one block per row
- Register ROCm FFI handler for xla.gpu.ext.cub_scan with runtime type dispatch for all 12 supported types
- Tuning via rocPRIM default_scan_config_base (architecture-aware)
- Add parameterized kernel test and performance benchmarks
Force-pushed 015864a to 85b82bb