diff --git a/docs/source/user_guide/subgroup.md b/docs/source/user_guide/subgroup.md index ba57733685..f232a8f5df 100644 --- a/docs/source/user_guide/subgroup.md +++ b/docs/source/user_guide/subgroup.md @@ -10,6 +10,7 @@ Subgroup ops live under `qd.simt.subgroup` and are written so the same Python so |---------------------------------------------|------|--------|-----------------|------------------------------| | `subgroup.shuffle(v, idx)` | yes | yes | yes | i32, u32, f32, f64, i64, u64 | | `subgroup.shuffle_down(v, n)` | yes | yes\* | yes | i32, u32, f32, f64, i64, u64 | +| `subgroup.ballot(predicate)` | yes | yes | yes | i32 predicate → u32 bitmask | | `subgroup.reduce_add(v, log2_size)` | yes | yes\* | yes | any type supporting `+` | | `subgroup.reduce_all_add(v, log2_size)` | yes | yes | yes | any type supporting `+` | @@ -42,6 +43,17 @@ Lane `i` returns the `value` held by lane `i + offset`. Lanes near the top of th - Ops are issued under a full active mask on CUDA (`0xFFFFFFFF`). Call them from uniform control flow; calling from divergent control flow is undefined on most backends. (this means: all threads have to execute the shuffle) - Subgroup size varies by backend (32 on NVIDIA, 32 or 64 on AMD, 32 in Vulkan compute on most GPUs). +### `ballot(predicate)` + +Each lane evaluates `predicate` (an `i32`; non-zero is true, zero is false) and the result is a `u32` bitmask where bit `i` is set if lane `i`'s predicate was non-zero. + +- Returns a `u32`. Bit 0 corresponds to lane 0, bit 1 to lane 1, etc. +- On CUDA, maps to `__ballot_sync(0xFFFFFFFF, predicate)`. On SPIR-V, maps to `OpGroupNonUniformBallot` (component 0 of the uvec4 result). On AMDGPU, maps to the `ballot.i32` intrinsic. +- The result covers the first 32 lanes. On AMDGPU CDNA with 64-wide wavefronts only the low 32 bits are returned; the upper 32 lanes are not represented. This is consistent with the 32-bit return type. 
+- Must be called from uniform control flow (all active lanes must execute the ballot). + +Ballot is a building block for warp-cooperative algorithms: population counts (`popcount(ballot(cond))` counts how many lanes satisfy `cond`), prefix masks, and lane compaction. + ### `reduce_add(value, log2_size)` Sums `value` across `2**log2_size` consecutive lanes via a `shuffle_down` tree. The result is valid **in lane 0** of each group; other lanes hold partial sums and should be considered undefined. @@ -78,6 +90,21 @@ def broadcast(a: qd.types.ndarray(dtype=qd.f32, ndim=1)): After the kernel, every lane in a subgroup holds the original value of its lane 0. +### Ballot: count how many lanes satisfy a condition + +```python +@qd.kernel +def count_positive(a: qd.types.ndarray(dtype=qd.f32, ndim=1), + counts: qd.types.ndarray(dtype=qd.u32, ndim=1)): + qd.loop_config(block_dim=32) + for i in range(a.shape[0]): + mask = subgroup.ballot(qd.i32(a[i] > 0.0)) + if subgroup.invocation_id() == 0: + counts[i // 32] = mask +``` + +After the kernel, `counts[g]` contains a bitmask of which lanes in group `g` had positive values. Use `popcount(mask)` on the host to get the count. + ### Identity shuffle (each lane reads its own id) Useful as a sanity check: @@ -182,6 +209,7 @@ Every lane in each group of 32 sees the same `total`. - Shuffles are register-to-register on CUDA (`__shfl_sync`, `__shfl_down_sync`) and on SPIR-V where the GPU has hardware support — typically a handful of cycles, no memory traffic. - AMDGPU `shuffle` and `shuffle_down` both go through `ds_permute`/`ds_bpermute` today (LDS-routed, roughly tens of cycles). +- `ballot` is cheap on every backend — a single `__ballot_sync` instruction on CUDA, the `ballot` intrinsic on AMDGPU (lowered to a short compare/mask sequence), and `OpGroupNonUniformBallot` on SPIR-V. No memory traffic. - `reduce_add` and `reduce_all_add` both issue exactly `log2_size` shuffles and `log2_size` adds per call. No barriers, no shared memory, no launch overhead (they inline). 
- Pick `reduce_all_add` over `reduce_add + broadcast` when you need the result in every lane — same cost, one fewer shuffle. - 64-bit dtypes (`i64`, `u64`, `f64`) are emulated as two 32-bit shuffles on AMDGPU. Prefer 32-bit values when you have a choice. @@ -190,5 +218,6 @@ Every lane in each group of 32 sees the same `total`. - [tile16](tile16.md) — `Tile16x16` builds on `subgroup.shuffle` to implement register-resident 16x16 matrix tiles. - `subgroup.invocation_id()` — returns this lane's subgroup-local index. -- `subgroup.size()` — returns the active subgroup size. +- `subgroup.group_size()` — returns the active subgroup size. +- `subgroup.ballot` — returns a u32 bitmask of lanes where the predicate is non-zero (see above). - `subgroup.reduce_add` / `subgroup.reduce_all_add` — portable sized sum reductions built on `shuffle_down` / `shuffle` (see above). diff --git a/python/quadrants/lang/simt/subgroup.py b/python/quadrants/lang/simt/subgroup.py index 5386563514..b02b327590 100644 --- a/python/quadrants/lang/simt/subgroup.py +++ b/python/quadrants/lang/simt/subgroup.py @@ -18,6 +18,10 @@ def elect(): return impl.call_internal("subgroupElect", with_runtime_context=False) +def ballot(predicate): + return impl.call_internal("subgroupBallot", predicate, with_runtime_context=False) + + def all_true(cond): # TODO pass @@ -173,6 +177,7 @@ def shuffle_down(value, offset): "barrier", "memory_barrier", "elect", + "ballot", "all_true", "any_true", "all_equal", diff --git a/quadrants/codegen/amdgpu/codegen_amdgpu.cpp b/quadrants/codegen/amdgpu/codegen_amdgpu.cpp index 7eb23a7a2e..c49fff259a 100644 --- a/quadrants/codegen/amdgpu/codegen_amdgpu.cpp +++ b/quadrants/codegen/amdgpu/codegen_amdgpu.cpp @@ -397,6 +397,8 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM { llvm_val[stmt] = emit_amdgpu_shuffle_down( /* value=*/llvm_val[stmt->args[0]], /* dt=*/stmt->args[0]->ret_type, offset); + } else if (stmt->func_name == "subgroupBallot") { + llvm_val[stmt] = 
call("amdgpu_ballot_i32", llvm_val[stmt->args[0]]); } else if (stmt->func_name == "subgroupInvocationId") { llvm_val[stmt] = call("amdgpu_lane_id"); } else { diff --git a/quadrants/codegen/cuda/codegen_cuda.cpp b/quadrants/codegen/cuda/codegen_cuda.cpp index 64d5b0f283..347549ea23 100644 --- a/quadrants/codegen/cuda/codegen_cuda.cpp +++ b/quadrants/codegen/cuda/codegen_cuda.cpp @@ -738,6 +738,8 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { /* value=*/llvm_val[stmt->args[0]], /* dt=*/stmt->args[0]->ret_type, /* offset=*/llvm_val[stmt->args[1]]); + } else if (stmt->func_name == "subgroupBallot") { + llvm_val[stmt] = call("cuda_ballot_i32", llvm_val[stmt->args[0]]); } else if (stmt->func_name == "subgroupInvocationId") { llvm_val[stmt] = call("cuda_lane_id"); } else { diff --git a/quadrants/codegen/spirv/spirv_codegen.cpp b/quadrants/codegen/spirv/spirv_codegen.cpp index 2d6b601e37..bfdbd8d5e8 100644 --- a/quadrants/codegen/spirv/spirv_codegen.cpp +++ b/quadrants/codegen/spirv/spirv_codegen.cpp @@ -1414,6 +1414,13 @@ void TaskCodegen::visit(InternalFuncStmt *stmt) { auto index = ir_->query_value(stmt->args[1]->raw_name()); val = ir_->make_value(spv::OpGroupNonUniformBroadcast, value.stype, ir_->int_immediate_number(ir_->i32_type(), spv::ScopeSubgroup), value, index); + } else if (stmt->func_name == "subgroupBallot") { + auto predicate = ir_->query_value(stmt->args[0]->raw_name()); + auto pred_bool = ir_->make_value(spv::OpINotEqual, ir_->bool_type(), predicate, + ir_->int_immediate_number(ir_->i32_type(), 0)); + auto ballot_vec = ir_->make_value(spv::OpGroupNonUniformBallot, ir_->t_v4_uint_, + ir_->int_immediate_number(ir_->i32_type(), spv::ScopeSubgroup), pred_bool); + val = ir_->make_value(spv::OpCompositeExtract, ir_->u32_type(), ballot_vec, 0); } else if (inclusive_scan_ops.find(stmt->func_name) != inclusive_scan_ops.end()) { auto arg = ir_->query_value(stmt->args[0]->raw_name()); auto stype = ir_->get_primitive_type(stmt->args[0]->ret_type); diff --git 
a/quadrants/codegen/spirv/spirv_ir_builder.cpp b/quadrants/codegen/spirv/spirv_ir_builder.cpp index 6cb87c0eed..83b561c24d 100644 --- a/quadrants/codegen/spirv/spirv_ir_builder.cpp +++ b/quadrants/codegen/spirv/spirv_ir_builder.cpp @@ -207,6 +207,9 @@ void IRBuilder::init_pre_defs() { t_v3_uint_.id = id_counter_++; ib_.begin(spv::OpTypeVector).add(t_v3_uint_).add_seq(t_uint32_, 3).commit(&global_); + t_v4_uint_.id = id_counter_++; + ib_.begin(spv::OpTypeVector).add(t_v4_uint_).add_seq(t_uint32_, 4).commit(&global_); + t_v4_fp32_.id = id_counter_++; ib_.begin(spv::OpTypeVector).add(t_v4_fp32_).add_seq(t_fp32_, 4).commit(&global_); diff --git a/quadrants/codegen/spirv/spirv_ir_builder.h b/quadrants/codegen/spirv/spirv_ir_builder.h index ef26fc6245..3bdc7c4887 100644 --- a/quadrants/codegen/spirv/spirv_ir_builder.h +++ b/quadrants/codegen/spirv/spirv_ir_builder.h @@ -574,6 +574,7 @@ class IRBuilder { SType t_v2_int_; SType t_v3_int_; SType t_v3_uint_; + SType t_v4_uint_; SType t_v4_fp32_; SType t_v3_fp32_; SType t_v2_fp32_; diff --git a/quadrants/inc/internal_ops.inc.h b/quadrants/inc/internal_ops.inc.h index c631501628..4afc415aff 100644 --- a/quadrants/inc/internal_ops.inc.h +++ b/quadrants/inc/internal_ops.inc.h @@ -30,6 +30,7 @@ PER_INTERNAL_OP(subgroupBroadcast) PER_INTERNAL_OP(subgroupShuffle) PER_INTERNAL_OP(subgroupShuffleDown) PER_INTERNAL_OP(subgroupShuffleUp) +PER_INTERNAL_OP(subgroupBallot) PER_INTERNAL_OP(subgroupSize) PER_INTERNAL_OP(subgroupInvocationId) // subgroupAdd / subgroupMul / subgroupMin / subgroupMax / subgroupAnd / subgroupOr / subgroupXor diff --git a/quadrants/ir/type_system.cpp b/quadrants/ir/type_system.cpp index 9f87ed016c..9a7ee088f1 100644 --- a/quadrants/ir/type_system.cpp +++ b/quadrants/ir/type_system.cpp @@ -354,6 +354,7 @@ void Operations::init_internals() { POLY_OP(subgroupShuffle, false, Signature({}, {ValueT, !u32}, ValueT)); POLY_OP(subgroupShuffleDown, false, Signature({}, {ValueT, !u32}, ValueT)); POLY_OP(subgroupShuffleUp, 
false, Signature({}, {ValueT, !u32}, ValueT)); + PLAIN_OP(subgroupBallot, u32, false, i32); PLAIN_OP(subgroupSize, i32, false); PLAIN_OP(subgroupInvocationId, i32, false); // subgroupAdd / subgroupMul / subgroupMin / subgroupMax / subgroupAnd / subgroupOr / subgroupXor diff --git a/quadrants/runtime/llvm/llvm_context.cpp b/quadrants/runtime/llvm/llvm_context.cpp index b094136d70..58ff6d7acc 100644 --- a/quadrants/runtime/llvm/llvm_context.cpp +++ b/quadrants/runtime/llvm/llvm_context.cpp @@ -517,6 +517,8 @@ std::unique_ptr QuadrantsLLVMContext::module_from_file(const std:: patch_intrinsic("amdgpu_ds_bpermute", llvm::Intrinsic::amdgcn_ds_bpermute); patch_intrinsic("amdgpu_mbcnt_lo", llvm::Intrinsic::amdgcn_mbcnt_lo); patch_intrinsic("amdgpu_mbcnt_hi", llvm::Intrinsic::amdgcn_mbcnt_hi); + patch_intrinsic("amdgpu_ballot_w32", llvm::Intrinsic::amdgcn_ballot, + true, {llvm::Type::getInt32Ty(*ctx)}); link_module_with_amdgpu_libdevice(module); patch_amdgpu_kernel_dim("block_dim", llvm::ConstantInt::get(llvm::Type::getInt32Ty(*ctx), 0)); diff --git a/quadrants/runtime/llvm/runtime_module/runtime.cpp b/quadrants/runtime/llvm/runtime_module/runtime.cpp index 88aa512542..5ea396a278 100644 --- a/quadrants/runtime/llvm/runtime_module/runtime.cpp +++ b/quadrants/runtime/llvm/runtime_module/runtime.cpp @@ -1644,6 +1644,15 @@ i32 amdgpu_mbcnt_hi(i32 mask, i32 base) { return 0; } +i32 amdgpu_ballot_w32(bool bit) { + __builtin_trap(); + return 0; +} + +i32 amdgpu_ballot_i32(i32 predicate) { + return amdgpu_ballot_w32((bool)predicate); +} + i32 amdgpu_lane_id() { return amdgpu_mbcnt_hi(-1, amdgpu_mbcnt_lo(-1, 0)); } diff --git a/tests/python/test_simt.py b/tests/python/test_simt.py index a6e490987a..a7d993a4b8 100644 --- a/tests/python/test_simt.py +++ b/tests/python/test_simt.py @@ -744,6 +744,87 @@ def foo(): assert abs(dst[i] - expected) < 1e-4 * abs(expected), f"lane {i}: got {dst[i]}, expected {expected}" +@test_utils.test(arch=qd.gpu) +def test_subgroup_ballot_all_true(): + 
"""Ballot with all lanes voting true should return a full bitmask.""" + N = 32 + result = qd.field(dtype=qd.u32, shape=N) + + @qd.kernel + def foo(): + qd.loop_config(block_dim=N) + for i in range(N): + result[i] = subgroup.ballot(1) + + foo() + + for i in range(N): + assert result[i] != 0, f"lane {i}: ballot returned 0, expected non-zero" + + +@test_utils.test(arch=qd.gpu) +def test_subgroup_ballot_all_false(): + """Ballot with all lanes voting false should return zero.""" + N = 32 + result = qd.field(dtype=qd.u32, shape=N) + + @qd.kernel + def foo(): + qd.loop_config(block_dim=N) + for i in range(N): + result[i] = subgroup.ballot(0) + + foo() + + for i in range(N): + assert result[i] == 0, f"lane {i}: ballot returned {result[i]}, expected 0" + + +@test_utils.test(arch=qd.gpu) +def test_subgroup_ballot_even_lanes(): + """Even-numbered lanes vote true; odd lanes vote false.""" + N = 32 + result = qd.field(dtype=qd.u32, shape=N) + + @qd.kernel + def foo(): + qd.loop_config(block_dim=N) + for i in range(N): + lane = subgroup.invocation_id() + result[i] = subgroup.ballot(1 - lane % 2) + + foo() + + mask = result[0] + assert mask & 0x1, "lane 0 should have voted true" + assert not (mask & 0x2), "lane 1 should have voted false" + assert mask & 0x4, "lane 2 should have voted true" + assert not (mask & 0x8), "lane 3 should have voted false" + + +@test_utils.test(arch=qd.gpu) +def test_subgroup_ballot_popcount(): + """Verify popcount of ballot(1) equals the subgroup size.""" + N = 32 + ballot_val = qd.field(dtype=qd.u32, shape=N) + sg_size = qd.field(dtype=qd.i32, shape=N) + + @qd.kernel + def foo(): + qd.loop_config(block_dim=N) + for i in range(N): + ballot_val[i] = subgroup.ballot(1) + sg_size[i] = subgroup.group_size() + + foo() + + bv = ballot_val[0] + sz = sg_size[0] + actual_popcount = bin(bv).count("1") + expected = min(sz, N) + assert actual_popcount == expected, f"popcount({bv:#x}) = {actual_popcount}, expected {expected} (subgroup size {sz})" + + 
@test_utils.test(arch=qd.gpu) def test_subgroup_invocation_id_range(): """Verify invocation IDs are non-negative."""