
Pemeliya/v0.9.0 split grid dim x #653

Open
pemeliya wants to merge 5 commits into rocm-jaxlib-v0.9.0 from pemeliya/v0.9.0_split_grid_dim_x

Conversation


@pemeliya pemeliya commented Mar 4, 2026

Cherry-picked from openxla#35343

@@ -737,6 +756,17 @@ SmallColumnReductionFusion::SmallColumnReductionFusion(
num_blocks_ = {input_shape_[kColMajorKept]};

Bug risk: SmallColumnReductionFusion::ComputeReductionOutputIndexing and ComputeReductionInputIndexing (below) use bare getAffineDimExpr(3, mlir_context) as block_id without reconstructing the linear block ID from the split dimensions (i.e., bx + gpu_blocks_[0] * by).

This is inconsistent with ColumnReductionFusion, RowReductionFusion, and MultiRowReductionFusion, which all reconstruct the linear block ID. If SmallColumnReductionFusion ever hits a grid that needs splitting (i.e., gpu_blocks_[1] > 1), the kernel will compute incorrect indices because dim 4 (block y) is completely ignored in the index expression.

Should this follow the same bx + gpu_blocks_[0] * by pattern as the other reduction types?
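To make the inconsistency concrete, here is a minimal standalone sketch (plain integer arithmetic with hypothetical helper names, standing in for the MLIR affine expressions) of why the low-order dimension alone cannot address a split grid:

```cpp
#include <array>
#include <cstdint>

// Hypothetical sketch: once the X grid dimension is split into
// gpu_blocks = {dim_x, dim_y}, the original linear block ID must be
// reconstructed as bx + dim_x * by before delinearizing into the
// virtual num_blocks_ shape.
uint64_t LinearBlockId(uint64_t bx, uint64_t by, uint64_t dim_x) {
  return bx + dim_x * by;
}

// Using bx alone (as SmallColumnReductionFusion's indexing currently does)
// only reaches block IDs in [0, dim_x); every block with by > 0 is ignored.
bool CoversAllBlocks(std::array<uint64_t, 2> gpu_blocks, uint64_t num_blocks) {
  uint64_t max_id =
      LinearBlockId(gpu_blocks[0] - 1, gpu_blocks[1] - 1, gpu_blocks[0]);
  return max_id + 1 >= num_blocks;
}
```

With an unsplit grid ({N, 1}) the two formulations agree, which is why the bug stays invisible until gpu_blocks_[1] > 1.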


TEST_F(GpuKernelTilingTest, LargeRowReduction) {
const char *kHlo = R"(
HloModule Test

This test uses Run() which actually executes the kernel on GPU with tensors of size 262144 x 262144 (~256 GB for s32). This will OOM on most hardware.

Other tests in this PR (MultiRowLargeReduceNonMultipleOf2, LargeLoopFusion) use CompileToExecutable() to avoid the memory issue. Should LargeRowReduction, LargeRowReductionNonMultipleOf2, and MultiRowLargeReduce also use CompileToExecutable instead of Run?
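As a sanity check on the numbers, a compile-time sketch of the buffer size (assuming 4 bytes per s32 element; names are illustrative):

```cpp
#include <cstdint>

// Bytes needed to materialize an s32 matrix of the given shape.
constexpr uint64_t S32Bytes(uint64_t rows, uint64_t cols) {
  return rows * cols * 4;  // s32 is 4 bytes per element
}

// s32[262144,262144] works out to exactly 256 GiB -- far beyond the
// memory of any current GPU, hence the OOM risk when Run() executes.
static_assert(S32Bytes(262144, 262144) == (uint64_t{256} << 30),
              "262144 x 262144 s32 elements = 256 GiB");
```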

tests fixes

test fix

cleanup

test update

small improvement int64_t -> uint64_t

prevent taking reference to temporary

more robust fix for buffer_compare test

update

fixing failing test

Revert "Splitting gridDim.x into x and y if necessary. Adding grid-stride loops for buffer comparator and redzone checker kernels"

This reverts commit 3fb75fe.

removed unnecessary dep

fixing the unit test

adding const references in LaunchDimensions

fixing switch op after rebase

readded changes that have been lost due to numerous rebases

split up kernel_tiling test

updating the test according to the comments

disabled tuning

skipping the failing subtest on cuda side

fixing reduction_groups hlo test: gridDim.z is now used to distinguish between different reductions

reduced dimensions to make test pass on B200 GPU

removed config without autotuning

fixing clang formatting

updated the test

claude bot commented Mar 4, 2026

Review Summary: This PR centralizes grid dimension splitting into MaybeSplitGridDimensionX, replacing the ROCm do-while workaround. See inline comments for findings: potential SmallColumnReductionFusion bug, missing Z-dim validation, behavioral change in kMinorReducedElementsPerThread, test OOM risk, and minor IWYU issues.

@pemeliya pemeliya force-pushed the pemeliya/v0.9.0_split_grid_dim_x branch from a9bc41d to 79d9658 on March 4, 2026 at 13:06

Copilot AI left a comment


Pull request overview

This PR cherry-picks upstream changes to support splitting large GPU grid dimension X across X/Y (and moving some logical dimensions to Z) to stay within hardware/ROCm launch constraints, while improving diagnostics and adding regression coverage.

Changes:

  • Add EmitterBase::MaybeSplitGridDimensionX and apply it to loop and reduction fusion launch dimensions.
  • Update reduction grouping logic to use blockIdx.z and adjust indexing maps accordingly.
  • Improve ROCm kernel launch failure messages and extend GPU tiling tests for very large reductions/fusions.

Reviewed changes

Copilot reviewed 14 out of 14 changed files in this pull request and generated 10 comments.

File Description
xla/stream_executor/rocm/rocm_stream.cc Enriches ROCm kernel launch error message with grid/block/shared-mem details.
xla/service/gpu/tests/gpu_kernel_tiling_test.cc Adds large-shape tiling/regression tests and updates an error-message regex.
xla/service/gpu/launch_dimensions.h Changes LaunchDimensions accessors to return const references.
xla/service/gpu/launch_dimensions.cc Removes now-unused includes.
xla/service/gpu/BUILD Drops an unused dependency.
xla/backends/gpu/codegen/fusion_emitter.cc Updates launch-dimension limit checks and metadata emission to use cached dims.
xla/backends/gpu/codegen/emitters/tests/reduce_row/reduction_groups.hlo Updates FileCheck expectations to reflect block_id z usage.
xla/backends/gpu/codegen/emitters/reduction_base.cc Moves group-id constraint from blockIdx.y to blockIdx.z.
xla/backends/gpu/codegen/emitters/reduction.h Tracks “virtual” vs “real” block dimensions via gpu_blocks_.
xla/backends/gpu/codegen/emitters/reduction.cc Applies X-dimension splitting and updates indexing/launch dims for reduction emitters.
xla/backends/gpu/codegen/emitters/loop.cc Applies X-dimension splitting to loop fusion launch dimensions.
xla/backends/gpu/codegen/emitters/emitter_base.h / .cc Introduces MaybeSplitGridDimensionX helper and adjusts ID emission to avoid lifetime issues.
xla/backends/gpu/codegen/emitters/BUILD Drops an unused dependency.


Comment on lines +438 to +458
TEST_F(GpuKernelTilingTest, MultiRowLargeReduce) {
const char *kHlo = R"(
HloModule Test
reduceOp {
X = s32[] parameter(1)
Y = s32[] parameter(0)
ROOT Z = s32[] maximum(X, Y)
}
ENTRY MultiRowLargeReduce {
A = s32[262144]{0} parameter(0)
B = s32[262144,262144]{1,0} broadcast(A), dimensions={0}
I = s32[262144,262144]{1,0} iota(), iota_dimension=1
Z = s32[262144,262144]{1,0} add(B, I)
BB = s32[262144,4096,64]{2,1,0} reshape(Z)
CC = s32[] constant(0)
R = s32[262144,4096]{1,0} reduce(BB, CC), dimensions={2}, to_apply=reduceOp
ROOT O = s16[262144,4096]{1,0} convert(R)
})";
ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
EXPECT_TRUE(Run(std::move(hlo_module), /*run_hlo_passes*/true));
}

Copilot AI Mar 4, 2026


This test calls Run(...) with shapes that imply enormous intermediate buffers (e.g. s32[262144,262144]), which will allocate huge fake inputs and is likely to OOM/time out in CI. Consider making this a compile-only test or reducing dimensions while preserving the intended launch-dimension behavior.

Comment on lines +649 to +656
VLOG(3) << absl::StrFormat(
"ColumnReductionFusion selected parameters: num_threads "
"= [%s], tile_sizes_per_thread = [%s], "
"num_blocks = [%s] real_blocks_ = [%d, %d] ",
absl::StrJoin(num_threads_, ","),
absl::StrJoin(tile_sizes_per_thread_, ","),
absl::StrJoin(num_blocks_, ","),
gpu_blocks_[0], gpu_blocks_[1]);

Copilot AI Mar 4, 2026


absl::StrFormat uses printf-style format strings, but this message uses %d for gpu_blocks_[0]/gpu_blocks_[1] which are uint64_t. This can lead to incorrect formatting or even a compile-time error with Abseil's type checking. Use an appropriate unsigned/64-bit specifier (or cast to a matching type) for these values.

Comment on lines +416 to +436
TEST_F(GpuKernelTilingTest, LargeRowReductionNonMultipleOf2) {
const char *kHlo = R"(
HloModule Test
reduceOp {
X = s32[] parameter(1)
Y = s32[] parameter(0)
ROOT Z = s32[] maximum(X, Y)
}
ENTRY RowLargeReduce {
A = s32[762145]{0} parameter(0)
B = s32[762145,776223]{1,0} broadcast(A), dimensions={0}
I = s32[762145,776223]{1,0} iota(), iota_dimension=1
Z = s32[762145,776223]{1,0} add(B, I)
BB = s32[762145,999,777]{2,1,0} reshape(B)
CC = s32[] constant(0)
R = s32[762145,999]{1,0} reduce(BB, CC), dimensions={2}, to_apply=reduceOp
ROOT O = s16[762145,999] convert(R)
})";
ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
EXPECT_TRUE(Run(std::move(hlo_module), /*run_hlo_passes*/true));
}

Copilot AI Mar 4, 2026


This test calls Run(...) with very large shapes (e.g. s32[762145,776223]), which will allocate massive fake input buffers and is likely to OOM/time out. Prefer a compile-only check or smaller shapes that still trigger the grid-dimension split path.

Comment on lines +765 to +769
"num_blocks = [%s] real_blocks_ = [%d, %d] ",
absl::StrJoin(num_threads_, ","),
absl::StrJoin(tile_sizes_per_thread_, ","),
absl::StrJoin(num_blocks_, ","),
gpu_blocks_[0], gpu_blocks_[1]);

Copilot AI Mar 4, 2026


absl::StrFormat here uses %d for gpu_blocks_[0]/gpu_blocks_[1] (type uint64_t). Please use a format specifier compatible with 64-bit unsigned integers (or cast to the expected type) to avoid incorrect output / format-string type-check failures.

Suggested change (before):
"num_blocks = [%s] real_blocks_ = [%d, %d] ",
absl::StrJoin(num_threads_, ","),
absl::StrJoin(tile_sizes_per_thread_, ","),
absl::StrJoin(num_blocks_, ","),
gpu_blocks_[0], gpu_blocks_[1]);

Suggested change (after):
"num_blocks = [%s] real_blocks_ = [%llu, %llu] ",
absl::StrJoin(num_threads_, ","),
absl::StrJoin(tile_sizes_per_thread_, ","),
absl::StrJoin(num_blocks_, ","),
static_cast<unsigned long long>(gpu_blocks_[0]),
static_cast<unsigned long long>(gpu_blocks_[1]));

Comment on lines 903 to +911
VLOG(3) << absl::StrFormat(
"RowReductionFusion::RowReductionFusion selected parameters: num_threads "
"RowReductionFusion selected parameters: num_threads "
"= [%s], tile_sizes_per_thread = [%s], tile_sizes_per_block = [%s], "
"num_blocks = [%s]",
"num_blocks = [%s] real_blocks_ = [%d, %d]",
absl::StrJoin(num_threads_, ","),
absl::StrJoin(tile_sizes_per_thread_, ","),
absl::StrJoin(tile_sizes_per_block_, ","),
absl::StrJoin(num_blocks_, ","));
absl::StrJoin(num_blocks_, ","),
gpu_blocks_[0], gpu_blocks_[1]);

Copilot AI Mar 4, 2026


absl::StrFormat uses %d for gpu_blocks_[0]/gpu_blocks_[1], but these are uint64_t. This may print incorrectly or fail Abseil's format checks. Use the correct unsigned/64-bit format conversion (or cast) for these values.

Comment on lines +1017 to +1021
"num_blocks = [%s] real_blocks_ = [%d, %d]",
absl::StrJoin(num_threads_, ","),
absl::StrJoin(tile_sizes_per_thread_, ","),
absl::StrJoin(num_blocks_, ","),
gpu_blocks_[0], gpu_blocks_[1]);

Copilot AI Mar 4, 2026


absl::StrFormat format string uses %d for gpu_blocks_[0]/gpu_blocks_[1] (both uint64_t). Please switch to a format specifier that matches 64-bit unsigned values (or cast) to avoid incorrect formatting or build errors from format checking.

Suggested change (before):
"num_blocks = [%s] real_blocks_ = [%d, %d]",
absl::StrJoin(num_threads_, ","),
absl::StrJoin(tile_sizes_per_thread_, ","),
absl::StrJoin(num_blocks_, ","),
gpu_blocks_[0], gpu_blocks_[1]);

Suggested change (after):
"num_blocks = [%s] real_blocks_ = [%llu, %llu]",
absl::StrJoin(num_threads_, ","),
absl::StrJoin(tile_sizes_per_thread_, ","),
absl::StrJoin(num_blocks_, ","),
static_cast<unsigned long long>(gpu_blocks_[0]),
static_cast<unsigned long long>(gpu_blocks_[1]));

Comment on lines +237 to +238
if (num_blocks_x <= limit.x &&
(!is_rocm || num_blocks_x * num_threads_x <= rocm_limit)) {

nit: potential issue — num_blocks_x * num_threads_x can silently overflow uint64_t if both values are large (e.g. each near 2^32). Since rocm_limit is UINT32_MAX, a safer overflow-proof check would be:

Suggested change (before):
if (num_blocks_x <= limit.x &&
(!is_rocm || num_blocks_x * num_threads_x <= rocm_limit)) {

Suggested change (after):
if (num_blocks_x <= limit.x &&
(!is_rocm || num_blocks_x <= rocm_limit / num_threads_x)) {

This avoids the multiplication entirely and handles the edge case correctly.
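To illustrate the failure mode, here is a standalone comparison of the two checks (hypothetical function names, not the XLA code; assumes num_threads_x > 0):

```cpp
#include <cstdint>

// Multiplication form: wraps modulo 2^64 when both operands are large,
// silently passing a check that should fail.
bool FitsMultiply(uint64_t blocks, uint64_t threads, uint64_t limit) {
  return blocks * threads <= limit;
}

// Division form: equivalent for in-range values (blocks * threads <= limit
// iff blocks <= limit / threads for integer blocks), but cannot overflow.
bool FitsDivide(uint64_t blocks, uint64_t threads, uint64_t limit) {
  return blocks <= limit / threads;
}
```

With blocks = threads = 2^33 and limit = UINT32_MAX, the product wraps to 0 and FitsMultiply wrongly returns true, while FitsDivide correctly returns false.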

Comment on lines +258 to +259
}
VLOG(1) << num_blocks_x << " splitting as: " << dimx << "x" << dimy

Minor robustness: the expression dimx*dimy - num_blocks_x is an unsigned subtraction. While dimx * dimy >= num_blocks_x should always hold by construction (ceiling division), this invariant isn't asserted. If a future refactor breaks it, the subtraction silently wraps. Consider adding a DCHECK_GE(dimx * dimy, num_blocks_x) before this line, or using a signed cast.
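The wraparound risk can be shown in a two-line sketch (illustrative helper, not the real code):

```cpp
#include <cstdint>

// Mirrors the VLOG expression: fine while dim_x * dim_y >= num_blocks holds
// (guaranteed today by the CeilOfRatio construction), but if that invariant
// ever breaks, the unsigned subtraction wraps to a huge value instead of
// going negative.
uint64_t WastedBlocks(uint64_t dim_x, uint64_t dim_y, uint64_t num_blocks) {
  return dim_x * dim_y - num_blocks;
}
```

WastedBlocks(4, 4, 17) quietly yields 2^64 - 1, which is why a DCHECK_GE before the subtraction is worth having.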


claude bot commented Mar 4, 2026

Re-review Summary: Previous 10 inline findings remain open. 3 new findings posted inline: (1) potential uint64_t overflow in ROCm grid limit check, (2) hardcoded block dimensions in ROCm test path, (3) unsigned subtraction without guard in VLOG diagnostic.


claude bot commented Mar 4, 2026

Re-review Summary

Several previous findings have been addressed -- replied "Resolved" on 6 threads (include array added, CeilOfRatio used, VLOG(3), CHECK_EQ, kMinorReducedElementsPerThread confirmed intentional).

6 prior findings remain open with no developer response (SmallColumnReductionFusion indexing, <= vs < limit check, large-tensor Run() tests, overflow in multiplication, ROCm error regex, unsigned subtraction).

2 new findings posted inline (include ordering in emitter_base.h and reduction.h).


claude bot commented Mar 4, 2026

Re-review Summary (revision 2) - 4 previously raised issues resolved (alphabetical ordering, hardcoded regex values, missing Z-dimension validation). 4 prior findings still open (original inline comments remain): SmallColumnReductionFusion missing linear block ID reconstruction, tests using Run() with huge tensors, potential uint64_t overflow, missing DCHECK_GE guard. 1 new nit posted inline: double closing brace formatting.

Comment on lines 756 to +769
num_blocks_ = {input_shape_[kColMajorKept]};
loop_size_ = CeilOfRatio(input_shape_[1] * input_shape_[2],
vector_size_ * num_threads_[0]);
gpu_blocks_ = MaybeSplitGridDimensionX(Product(num_threads_),
Product(num_blocks_), analysis_.device_info());

VLOG(3) << absl::StrFormat(
"SmallColumnReductionFusion selected parameters: num_threads "
"= [%s], tile_sizes_per_thread = [%s], "
"num_blocks = [%s] real_blocks_ = [%d, %d] ",
absl::StrJoin(num_threads_, ","),
absl::StrJoin(tile_sizes_per_thread_, ","),
absl::StrJoin(num_blocks_, ","),
gpu_blocks_[0], gpu_blocks_[1]);

SmallColumnReductionFusion calls MaybeSplitGridDimensionX here, but its ComputeReductionOutputIndexing and ComputeReductionInputIndexing methods (lines 772-807) still use raw getAffineDimExpr(3, mlir_context) as the block ID without reconstructing the linear block ID via bx + gpu_blocks_[0] * by.

Compare with ColumnReductionFusion, RowReductionFusion, and MultiRowReductionFusion, which all reconstruct the linear block ID:

auto bx = mlir::getAffineDimExpr(3, mlir_context),
     by = mlir::getAffineDimExpr(4, mlir_context);
auto block_id = DelinearizeInBoundsIndex(bx + gpu_blocks_[0] * by, num_blocks_);

When gpu_blocks_[1] > 1 (i.e., the grid was actually split), SmallColumnReductionFusion will compute wrong indices because d3 (blockIdx.x) only holds the low-order part of the block ID. This is a correctness bug for the scenario this PR is designed to handle.

Comment on lines +236 to +238
if (num_blocks_x <= limit.x &&
(!is_rocm || num_blocks_x * num_threads_x <= rocm_limit)) {
return {num_blocks_x, 1};

The multiplication num_blocks_x * num_threads_x can overflow uint64_t for extreme inputs. In practice, current GPU shapes keep these values well within range, but there is no guard. Consider adding an overflow check or at least a comment documenting the assumption that the product fits in 64 bits.

Comment on lines +245 to +260
if (dimy > limit.y) {
// We could not find the proper power-of-two dim Y => use max gridY
dimy = limit.y;
dimx = CeilOfRatio(num_blocks_x, dimy);
break;
}
// num_blocks_x might not be divided evenly by dimy, so we round up.
dimx = CeilOfRatio(num_blocks_x, dimy);
if (dimx <= limit.x) {
// We have an extra requirement on ROCM to check
if (!is_rocm || dimx * num_threads_x <= rocm_limit) break;
}
}
VLOG(3) << num_blocks_x << " splitting as: " << dimx << "x" << dimy
<< " wasted blocks: " << (dimx*dimy - num_blocks_x);
return {dimx, dimy};

When the fallback path fires (line 247: dimy = limit.y, dimx = CeilOfRatio(num_blocks_x, dimy)), the function may return dimx > limit.x -- an invalid grid that will be rejected downstream by AnnotateKernelLaunchDimensions. This is technically correct (the error surfaces), but failing silently here and relying on a distant downstream check is fragile. Consider adding a LOG(WARNING) or a DCHECK when the fallback produces dimx > limit.x to make debugging easier.

Also, the VLOG on line 258 computes dimx*dimy - num_blocks_x using unsigned arithmetic. Since dimx = CeilOfRatio(num_blocks_x, dimy) guarantees dimx*dimy >= num_blocks_x, this is safe. But adding a DCHECK_GE(dimx * dimy, num_blocks_x) before the subtraction would make the invariant explicit.
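For readers following this thread, the splitting strategy under discussion can be sketched in isolation (simplified: the ROCm blocks-times-threads constraint is omitted and names are illustrative):

```cpp
#include <array>
#include <cstdint>

constexpr uint64_t CeilOfRatio(uint64_t a, uint64_t b) {
  return (a + b - 1) / b;
}

// Grow a power-of-two Y dimension until the X dimension fits under
// limit_x. If no power-of-two Y works, clamp Y to limit_y and rely on a
// downstream launch-dimension check to reject an oversized X.
std::array<uint64_t, 2> SplitGridX(uint64_t num_blocks_x, uint64_t limit_x,
                                   uint64_t limit_y) {
  if (num_blocks_x <= limit_x) return {num_blocks_x, 1};
  for (uint64_t dim_y = 2;; dim_y *= 2) {
    if (dim_y > limit_y) {  // fallback path: may still return dim_x > limit_x
      return {CeilOfRatio(num_blocks_x, limit_y), limit_y};
    }
    // num_blocks_x may not divide evenly, so round up (some blocks wasted).
    uint64_t dim_x = CeilOfRatio(num_blocks_x, dim_y);
    if (dim_x <= limit_x) return {dim_x, dim_y};
  }
}
```

By construction dim_x * dim_y >= num_blocks_x, which is exactly the invariant the proposed DCHECK_GE would pin down.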

Comment on lines +227 to +230
/* static */ std::array<uint64_t, 2> EmitterBase::MaybeSplitGridDimensionX(
uint64_t num_threads_x, uint64_t num_blocks_x,
const se::DeviceDescription& info)
{

Minor style: Google C++ style places the opening brace on the same line as the function signature, not on a separate line.

Suggested change (before):
/* static */ std::array<uint64_t, 2> EmitterBase::MaybeSplitGridDimensionX(
uint64_t num_threads_x, uint64_t num_blocks_x,
const se::DeviceDescription& info)
{

Suggested change (after):
/* static */ std::array<uint64_t, 2> EmitterBase::MaybeSplitGridDimensionX(
uint64_t num_threads_x, uint64_t num_blocks_x,
const se::DeviceDescription& info) {

Comment on lines +395 to +414
TEST_F(GpuKernelTilingTest, LargeRowReduction) {
const char *kHlo = R"(
HloModule Test
reduceOp {
X = s32[] parameter(1)
Y = s32[] parameter(0)
ROOT Z = s32[] maximum(X, Y)
}
ENTRY RowLargeReduce {
A = s32[262144]{0} parameter(0)
B = s32[262144,262144]{1,0} broadcast(A), dimensions={0}
I = s32[262144,262144]{1,0} iota(), iota_dimension=1
Z = s32[262144,262144]{1,0} add(B, I)
BB = s32[262144,512,512]{2,1,0} reshape(Z)
CC = s32[] constant(0)
ROOT R = s32[262144,512]{1,0} reduce(BB, CC), dimensions={2}, to_apply=reduceOp
})";
ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
EXPECT_TRUE(Run(std::move(hlo_module), /*run_hlo_passes*/ true));
}

LargeRowReduction uses Run() which actually executes the kernel, materializing tensors like s32[262144,262144] (~256 GB). This will OOM on any existing GPU. The same issue applies to LargeRowReductionNonMultipleOf2 (line 435) and MultiRowLargeReduce (line 457).

Consider using CompileToExecutable instead of Run(), as was done for MultiRowLargeReduceNonMultipleOf2 and LargeLoopFusion. This verifies the compilation succeeds (including grid splitting) without needing to allocate the full tensor.

} else {
num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced;
}
int64_t kMinorReducedElementsPerThread = 16;

The initial value of kMinorReducedElementsPerThread changed from 8 (doubled to 16 on first loop iteration) to a fixed 16, and the do-while retry loop that increased this value for ROCm was removed. The net starting value is the same (16), but the old code could retry with 32, 64, etc. up to 65536 to shrink the grid on ROCm.

Now that MaybeSplitGridDimensionX handles oversized grids, the retry loop is correctly unnecessary. However, this is a behavioral change: the old code reduced parallelism (larger tiles per thread) to fit the grid; the new code preserves parallelism and splits the grid instead. This should generally be better for performance but changes tile configurations for previously-affected workloads. Worth a note in the PR description.
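The retired retry behavior described above can be sketched as follows (illustrative names and shapes, not the actual XLA code):

```cpp
#include <cstdint>

constexpr uint64_t CeilOfRatio(uint64_t a, uint64_t b) {
  return (a + b - 1) / b;
}

// Old-style fallback: keep doubling the per-thread tile (shrinking the
// grid) until the implied block count fits under the ROCm limit. The
// first iteration already yields 16, matching the new fixed constant.
uint64_t PickElementsPerThread(uint64_t reduced_elements, uint64_t threads,
                               uint64_t block_limit) {
  uint64_t elems = 8;
  do {
    elems *= 2;  // trade parallelism for fewer blocks
  } while (elems < 65536 &&
           CeilOfRatio(reduced_elements, elems * threads) > block_limit);
  return elems;
}
```

MaybeSplitGridDimensionX makes iterations beyond the first unnecessary: oversized grids are now split across X/Y instead of being shrunk.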

Comment on lines +146 to +149
// virtual grid dimension: used by LLVM internally
absl::InlinedVector<int64_t, 4> num_blocks_;
// real block dimensions used to launch a fusion kernel
std::array<uint64_t, 2> gpu_blocks_ = {0, 0};

The default initializer {0, 0} means gpu_blocks_ starts at zero. All current subclass constructors set it, so this is safe today. But as a defensive measure, consider adding DCHECK_GT(gpu_blocks_[0], 0) in launch_dimensions() or GetIndexingMap() to catch any future subclass that forgets to initialize it.

Comment on lines +759 to +760
gpu_blocks_ = MaybeSplitGridDimensionX(Product(num_threads_),
Product(num_blocks_), analysis_.device_info());

SmallColumnReductionFusion sets gpu_blocks_ here via MaybeSplitGridDimensionX, but the corresponding indexing functions (ComputeReductionOutputIndexing at line 772 and ComputeReductionInputIndexing at line 786) still use getAffineDimExpr(3, mlir_context) directly as the block ID, without reconstructing the linear block ID via bx + gpu_blocks_[0] * by.

Every other reduction subclass (ColumnReductionFusion, RowReductionFusion, MultiRowReductionFusion) was updated to reconstruct the linear block ID from the split grid dimensions, but SmallColumnReductionFusion was not. When gpu_blocks_[1] > 1 (i.e., the grid X dimension was split), GetIndexingMap sets dim 3 to range [0, gpu_blocks_[0]) which is smaller than Product(num_blocks_), so the indexing maps would use only a subset of block indices, producing incorrect results.

The fix should apply the same bx + gpu_blocks_[0] * by pattern used in the other subclasses to both ComputeReductionOutputIndexing and ComputeReductionInputIndexing.

ROOT R = s32[262144,512]{1,0} reduce(BB, CC), dimensions={2}, to_apply=reduceOp
})";
ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnVerifiedModule(kHlo));
EXPECT_TRUE(Run(std::move(hlo_module), /*run_hlo_passes*/ true));

The LargeRowReduction, LargeRowReductionNonMultipleOf2, and MultiRowLargeReduce tests use Run(), which actually executes the kernel on the GPU. The intermediate tensors (e.g., s32[262144,262144] = 256 GB logical size, s32[762145,776223] = ~2.2 TB) are enormous. Even though fusion avoids fully materializing broadcast/iota intermediates, the output tensors themselves can be large (e.g., s32[262144,512] = 512 MB, s32[762145,999] ~3 GB).

These tests risk OOM on CI machines with limited GPU memory. Consider using CompileToExecutable() instead (as done in MultiRowLargeReduceNonMultipleOf2 and LargeLoopFusion) to verify compilation succeeds without executing. Alternatively, reduce the tensor sizes to something that fits in typical CI GPU memory while still triggering the grid split logic.

bool is_rocm = info.gpu_compute_capability().IsRocm();
// Add an extra condition for ROCM backend
if (num_blocks_x <= limit.x &&
(!is_rocm || num_blocks_x * num_threads_x <= rocm_limit)) {

The multiplication num_blocks_x * num_threads_x is performed on two uint64_t values. In practice, num_blocks_x is bounded by the grid limit (~2^31) and num_threads_x is at most 1024, so this won't overflow. However, there is no static guarantee of this -- the function accepts arbitrary uint64_t values. Consider adding a brief comment noting the safe range assumption, or using a checked multiplication.

Same concern applies to dimx * num_threads_x on line 255.

Comment on lines +227 to +230
/* static */ std::array<uint64_t, 2> EmitterBase::MaybeSplitGridDimensionX(
uint64_t num_threads_x, uint64_t num_blocks_x,
const se::DeviceDescription& info)
{

The opening brace on a new line ({ alone on line 230) deviates from Google C++ style, which places the opening brace at the end of the function declaration line. This should be:

Suggested change (before):
/* static */ std::array<uint64_t, 2> EmitterBase::MaybeSplitGridDimensionX(
uint64_t num_threads_x, uint64_t num_blocks_x,
const se::DeviceDescription& info)
{

Suggested change (after):
/* static */ std::array<uint64_t, 2> EmitterBase::MaybeSplitGridDimensionX(
uint64_t num_threads_x, uint64_t num_blocks_x,
const se::DeviceDescription& info) {

Comment on lines +529 to 531
"Kernel '.*' launch needs more blocks [(][0-9]+, 65536, [0-9]+[)] "
"than allowed by hardware [(]2147483647, 65536, [0-9]+[)]"));
} else {

The regex [0-9]+ for the block dimension values makes this test weaker than the original -- it no longer verifies the specific grid dimensions being produced by the splitting algorithm. Consider capturing the Y-dimension limit value explicitly (e.g., 65536 for ROCm, 65535 for CUDA) in the "allowed by hardware" portion, and verifying at least one concrete dimension in the "launch needs more blocks" portion, to validate that the splitting algorithm actually hit the expected ceiling.

Also, the parameter shape was changed from f32[16,1048576,1024,1024] to f32[1048576,1048576,1024,1024] (2^60 elements, ~4 EiB). A brief comment explaining why this specific shape was chosen would help future readers understand the intent.

num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced;
}
int64_t kMinorReducedElementsPerThread = 16;


The initial value of kMinorReducedElementsPerThread changed from 8 (doubled to 16 by the first iteration of the old do-while loop) to a direct 16. This is semantically equivalent to the first iteration of the old loop. However, the old code had a do-while loop that could increase this value further on ROCm when the grid was too large, effectively trading parallelism for fewer blocks. That fallback behavior has been removed -- the grid splitting in MaybeSplitGridDimensionX replaces it. Worth confirming in testing that this doesn't regress performance for cases where the old loop would have iterated multiple times (i.e., cases where even after one doubling, the ROCm uint32_t limit was still exceeded).


claude bot commented Mar 5, 2026

Review Summary: 7 inline findings posted. The most critical is that SmallColumnReductionFusion indexing methods are missing the linear block ID reconstruction (bx + gpu_blocks_[0] * by) that was added to every other reduction subclass -- this will produce wrong results when the grid is split. Three tests use Run() with tensors too large to fit in GPU memory. Additional findings cover overflow risk, missing validation in the fallback path, a style issue, and defensive initialization.


claude bot commented Mar 5, 2026

Re-review Summary (Claude): All findings already covered by previous inline comments. No new issues. Key open items: SmallColumnReductionFusion indexing bug, OOM-risk tests, uint64 overflow, brace style.


claude bot commented Mar 5, 2026

Re-review: one prior finding resolved (closing braces fixed). No new issues. Key outstanding items: SmallColumnReductionFusion indexing bug, tests materializing huge tensors via Run(), and style nits.

