Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion docs/source/user_guide/subgroup.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Subgroup ops live under `qd.simt.subgroup` and are written so the same Python so
|---------------------------------------------|------|--------|-----------------|------------------------------|
| `subgroup.shuffle(v, idx)` | yes | yes | yes | i32, u32, f32, f64, i64, u64 |
| `subgroup.shuffle_down(v, n)` | yes | yes\* | yes | i32, u32, f32, f64, i64, u64 |
| `subgroup.ballot(predicate)` | yes | yes | yes | i32 predicate → u32 bitmask |
| `subgroup.reduce_add(v, log2_size)` | yes | yes\* | yes | any type supporting `+` |
| `subgroup.reduce_all_add(v, log2_size)` | yes | yes | yes | any type supporting `+` |

Expand Down Expand Up @@ -42,6 +43,17 @@ Lane `i` returns the `value` held by lane `i + offset`. Lanes near the top of th
- Ops are issued under a full active mask on CUDA (`0xFFFFFFFF`). Call them from uniform control flow; calling from divergent control flow is undefined on most backends. (this means: all threads have to execute the shuffle)
- Subgroup size varies by backend (32 on NVIDIA, 32 or 64 on AMD, 32 in Vulkan compute on most GPUs).

### `ballot(predicate)`

Each lane evaluates `predicate` (an `i32`; non-zero is true, zero is false) and the result is a `u32` bitmask where bit `i` is set if lane `i`'s predicate was non-zero.

- Returns a `u32`. Bit 0 corresponds to lane 0, bit 1 to lane 1, etc.
- On CUDA, maps to `__ballot_sync(0xFFFFFFFF, predicate)`. On SPIR-V, maps to `OpGroupNonUniformBallot` (component 0 of the uvec4 result). On AMDGPU, maps to the `llvm.amdgcn.ballot` intrinsic with a 32-bit result.
- The result covers the first 32 lanes. On AMDGPU CDNA with 64-wide wavefronts only the low 32 bits are returned; the upper 32 lanes are not represented. This is consistent with the 32-bit return type.
- Must be called from uniform control flow (all active lanes must execute the ballot).

Ballot is a building block for warp-cooperative algorithms: population counts (`popcount(ballot(cond))` counts how many lanes satisfy `cond`), prefix masks, and lane compaction.

### `reduce_add(value, log2_size)`

Sums `value` across `2**log2_size` consecutive lanes via a `shuffle_down` tree. The result is valid **in lane 0** of each group; other lanes hold partial sums and should be considered undefined.
Expand Down Expand Up @@ -78,6 +90,21 @@ def broadcast(a: qd.types.ndarray(dtype=qd.f32, ndim=1)):

After the kernel, every lane in a subgroup holds the original value of its lane 0.

### Ballot: count how many lanes satisfy a condition

```python
@qd.kernel
def count_positive(a: qd.types.ndarray(dtype=qd.f32, ndim=1),
counts: qd.types.ndarray(dtype=qd.u32, ndim=1)):
qd.loop_config(block_dim=32)
for i in range(a.shape[0]):
mask = subgroup.ballot(qd.i32(a[i] > 0.0))
if subgroup.invocation_id() == 0:
counts[i // 32] = mask
```

After the kernel, `counts[g]` contains a bitmask of which lanes in group `g` had positive values. Use `popcount(mask)` on the host to get the count.

### Identity shuffle (each lane reads its own id)

Useful as a sanity check:
Expand Down Expand Up @@ -182,6 +209,7 @@ Every lane in each group of 32 sees the same `total`.

- Shuffles are register-to-register on CUDA (`__shfl_sync`, `__shfl_down_sync`) and on SPIR-V where the GPU has hardware support — typically a handful of cycles, no memory traffic.
- AMDGPU `shuffle` and `shuffle_down` both go through `ds_permute`/`ds_bpermute` today (LDS-routed, roughly tens of cycles).
- `ballot` lowers to a single op on every backend — `__ballot_sync` on CUDA, the `llvm.amdgcn.ballot` intrinsic on AMDGPU, and `OpGroupNonUniformBallot` on SPIR-V — so it is register-only with no memory traffic, typically just a few cycles.
- `reduce_add` and `reduce_all_add` both issue exactly `log2_size` shuffles and `log2_size` adds per call. No barriers, no shared memory, no launch overhead (they inline).
- Pick `reduce_all_add` over `reduce_add + broadcast` when you need the result in every lane — same cost, one fewer shuffle.
- 64-bit dtypes (`i64`, `u64`, `f64`) are emulated as two 32-bit shuffles on AMDGPU. Prefer 32-bit values when you have a choice.
Expand All @@ -190,5 +218,6 @@ Every lane in each group of 32 sees the same `total`.

- [tile16](tile16.md) — `Tile16x16` builds on `subgroup.shuffle` to implement register-resident 16x16 matrix tiles.
- `subgroup.invocation_id()` — returns this lane's subgroup-local index.
- `subgroup.group_size()` (also listed as `subgroup.size()`) — returns the active subgroup size.
- `subgroup.ballot` — returns a u32 bitmask of lanes where the predicate is non-zero (see above).
- `subgroup.reduce_add` / `subgroup.reduce_all_add` — portable sized sum reductions built on `shuffle_down` / `shuffle` (see above).
5 changes: 5 additions & 0 deletions python/quadrants/lang/simt/subgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ def elect():
return impl.call_internal("subgroupElect", with_runtime_context=False)


def ballot(predicate):
    """Return a u32 bitmask where bit ``i`` is set iff lane ``i``'s predicate is non-zero.

    ``predicate`` is an i32 (non-zero means true). The op must be called from
    uniform control flow; see the subgroup user guide for backend details.
    """
    mask = impl.call_internal("subgroupBallot", predicate, with_runtime_context=False)
    return mask


def all_true(cond):
    """Whether ``cond`` is true on all lanes — not yet implemented; currently returns None."""
    # TODO: implement (presumably via a "subgroupAll"-style internal op once
    # backends support it — confirm intended lowering before wiring up).
    pass
Expand Down Expand Up @@ -173,6 +177,7 @@ def shuffle_down(value, offset):
"barrier",
"memory_barrier",
"elect",
"ballot",
"all_true",
"any_true",
"all_equal",
Expand Down
2 changes: 2 additions & 0 deletions quadrants/codegen/amdgpu/codegen_amdgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,8 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM {
llvm_val[stmt] = emit_amdgpu_shuffle_down(
/* value=*/llvm_val[stmt->args[0]],
/* dt=*/stmt->args[0]->ret_type, offset);
} else if (stmt->func_name == "subgroupBallot") {
llvm_val[stmt] = call("amdgpu_ballot_i32", llvm_val[stmt->args[0]]);
} else if (stmt->func_name == "subgroupInvocationId") {
llvm_val[stmt] = call("amdgpu_lane_id");
} else {
Expand Down
2 changes: 2 additions & 0 deletions quadrants/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,8 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
/* value=*/llvm_val[stmt->args[0]],
/* dt=*/stmt->args[0]->ret_type,
/* offset=*/llvm_val[stmt->args[1]]);
} else if (stmt->func_name == "subgroupBallot") {
llvm_val[stmt] = call("cuda_ballot_i32", llvm_val[stmt->args[0]]);
} else if (stmt->func_name == "subgroupInvocationId") {
llvm_val[stmt] = call("cuda_lane_id");
} else {
Expand Down
7 changes: 7 additions & 0 deletions quadrants/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1414,6 +1414,13 @@ void TaskCodegen::visit(InternalFuncStmt *stmt) {
auto index = ir_->query_value(stmt->args[1]->raw_name());
val = ir_->make_value(spv::OpGroupNonUniformBroadcast, value.stype,
ir_->int_immediate_number(ir_->i32_type(), spv::ScopeSubgroup), value, index);
} else if (stmt->func_name == "subgroupBallot") {
auto predicate = ir_->query_value(stmt->args[0]->raw_name());
auto pred_bool = ir_->make_value(spv::OpINotEqual, ir_->bool_type(), predicate,
ir_->int_immediate_number(ir_->i32_type(), 0));
auto ballot_vec = ir_->make_value(spv::OpGroupNonUniformBallot, ir_->t_v4_uint_,
ir_->int_immediate_number(ir_->i32_type(), spv::ScopeSubgroup), pred_bool);
val = ir_->make_value(spv::OpCompositeExtract, ir_->u32_type(), ballot_vec, 0);
} else if (inclusive_scan_ops.find(stmt->func_name) != inclusive_scan_ops.end()) {
auto arg = ir_->query_value(stmt->args[0]->raw_name());
auto stype = ir_->get_primitive_type(stmt->args[0]->ret_type);
Expand Down
3 changes: 3 additions & 0 deletions quadrants/codegen/spirv/spirv_ir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ void IRBuilder::init_pre_defs() {
t_v3_uint_.id = id_counter_++;
ib_.begin(spv::OpTypeVector).add(t_v3_uint_).add_seq(t_uint32_, 3).commit(&global_);

t_v4_uint_.id = id_counter_++;
ib_.begin(spv::OpTypeVector).add(t_v4_uint_).add_seq(t_uint32_, 4).commit(&global_);

t_v4_fp32_.id = id_counter_++;
ib_.begin(spv::OpTypeVector).add(t_v4_fp32_).add_seq(t_fp32_, 4).commit(&global_);

Expand Down
1 change: 1 addition & 0 deletions quadrants/codegen/spirv/spirv_ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,7 @@ class IRBuilder {
SType t_v2_int_;
SType t_v3_int_;
SType t_v3_uint_;
SType t_v4_uint_;
SType t_v4_fp32_;
SType t_v3_fp32_;
SType t_v2_fp32_;
Expand Down
1 change: 1 addition & 0 deletions quadrants/inc/internal_ops.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ PER_INTERNAL_OP(subgroupBroadcast)
PER_INTERNAL_OP(subgroupShuffle)
PER_INTERNAL_OP(subgroupShuffleDown)
PER_INTERNAL_OP(subgroupShuffleUp)
PER_INTERNAL_OP(subgroupBallot)
PER_INTERNAL_OP(subgroupSize)
PER_INTERNAL_OP(subgroupInvocationId)
// subgroupAdd / subgroupMul / subgroupMin / subgroupMax / subgroupAnd / subgroupOr / subgroupXor
Expand Down
1 change: 1 addition & 0 deletions quadrants/ir/type_system.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ void Operations::init_internals() {
POLY_OP(subgroupShuffle, false, Signature({}, {ValueT, !u32}, ValueT));
POLY_OP(subgroupShuffleDown, false, Signature({}, {ValueT, !u32}, ValueT));
POLY_OP(subgroupShuffleUp, false, Signature({}, {ValueT, !u32}, ValueT));
PLAIN_OP(subgroupBallot, u32, false, i32);
PLAIN_OP(subgroupSize, i32, false);
PLAIN_OP(subgroupInvocationId, i32, false);
// subgroupAdd / subgroupMul / subgroupMin / subgroupMax / subgroupAnd / subgroupOr / subgroupXor
Expand Down
2 changes: 2 additions & 0 deletions quadrants/runtime/llvm/llvm_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,8 @@ std::unique_ptr<llvm::Module> QuadrantsLLVMContext::module_from_file(const std::
patch_intrinsic("amdgpu_ds_bpermute", llvm::Intrinsic::amdgcn_ds_bpermute);
patch_intrinsic("amdgpu_mbcnt_lo", llvm::Intrinsic::amdgcn_mbcnt_lo);
patch_intrinsic("amdgpu_mbcnt_hi", llvm::Intrinsic::amdgcn_mbcnt_hi);
patch_intrinsic("amdgpu_ballot_w32", llvm::Intrinsic::amdgcn_ballot,
true, {llvm::Type::getInt32Ty(*ctx)});

link_module_with_amdgpu_libdevice(module);
patch_amdgpu_kernel_dim("block_dim", llvm::ConstantInt::get(llvm::Type::getInt32Ty(*ctx), 0));
Expand Down
9 changes: 9 additions & 0 deletions quadrants/runtime/llvm/runtime_module/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1644,6 +1644,15 @@ i32 amdgpu_mbcnt_hi(i32 mask, i32 base) {
return 0;
}

// Stub for the 32-bit AMDGPU ballot. The body is never meant to execute: at
// module load, this symbol is patched to the llvm.amdgcn.ballot intrinsic
// (see the patch_intrinsic("amdgpu_ballot_w32", ...) call in
// QuadrantsLLVMContext::module_from_file). Hitting the trap means the
// intrinsic patching did not run.
i32 amdgpu_ballot_w32(bool bit) {
  __builtin_trap();
  return 0;
}

// Adapter from the i32 predicate convention used by codegen to the bool
// argument expected by the (intrinsic-patched) amdgpu_ballot_w32.
i32 amdgpu_ballot_i32(i32 predicate) {
  const bool vote = (predicate != 0);
  return amdgpu_ballot_w32(vote);
}

// Returns this thread's lane index within its wavefront, computed from the
// mbcnt_lo/mbcnt_hi pair with a full (-1) mask — the usual AMDGPU lane-id
// idiom (counts all lanes below the current one).
i32 amdgpu_lane_id() {
  return amdgpu_mbcnt_hi(-1, amdgpu_mbcnt_lo(-1, 0));
}
Expand Down
81 changes: 81 additions & 0 deletions tests/python/test_simt.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,87 @@ def foo():
assert abs(dst[i] - expected) < 1e-4 * abs(expected), f"lane {i}: got {dst[i]}, expected {expected}"


@test_utils.test(arch=qd.gpu)
def test_subgroup_ballot_all_true():
    """Ballot with all lanes voting true should return a full bitmask.

    Every lane votes 1, so in addition to being non-zero, the mask each
    thread receives must have that thread's own lane bit set
    (bit ``i % 32`` for thread ``i`` with a 32-wide block).
    """
    N = 32
    result = qd.field(dtype=qd.u32, shape=N)

    @qd.kernel
    def foo():
        qd.loop_config(block_dim=N)
        for i in range(N):
            result[i] = subgroup.ballot(1)

    foo()

    for i in range(N):
        mask = result[i]
        assert mask != 0, f"lane {i}: ballot returned 0, expected non-zero"
        # Thread i voted true, so its own lane bit must be present in the
        # mask it observes — a strictly stronger check than "non-zero".
        assert mask & (1 << (i % 32)), (
            f"lane {i}: own bit not set in ballot mask {mask:#x}"
        )


@test_utils.test(arch=qd.gpu)
def test_subgroup_ballot_all_false():
    """Ballot with all lanes voting false should return zero."""
    lane_count = 32
    masks = qd.field(dtype=qd.u32, shape=lane_count)

    @qd.kernel
    def vote_false():
        qd.loop_config(block_dim=lane_count)
        for i in range(lane_count):
            masks[i] = subgroup.ballot(0)

    vote_false()

    # No lane voted true, so every observed bitmask must be exactly zero.
    for lane in range(lane_count):
        assert masks[lane] == 0, f"lane {lane}: ballot returned {masks[lane]}, expected 0"


@test_utils.test(arch=qd.gpu)
def test_subgroup_ballot_even_lanes():
    """Even-numbered lanes vote true; odd lanes vote false."""
    lane_count = 32
    masks = qd.field(dtype=qd.u32, shape=lane_count)

    @qd.kernel
    def vote_even():
        qd.loop_config(block_dim=lane_count)
        for i in range(lane_count):
            lane = subgroup.invocation_id()
            # 1 for even lanes, 0 for odd lanes.
            masks[i] = subgroup.ballot(1 - lane % 2)

    vote_even()

    # Spot-check the low four bits of the mask seen by lane 0.
    mask = masks[0]
    assert mask & 0x1, "lane 0 should have voted true"
    assert not (mask & 0x2), "lane 1 should have voted false"
    assert mask & 0x4, "lane 2 should have voted true"
    assert not (mask & 0x8), "lane 3 should have voted false"


@test_utils.test(arch=qd.gpu)
def test_subgroup_ballot_popcount():
    """Verify popcount of ballot(1) equals the subgroup size."""
    lane_count = 32
    ballots = qd.field(dtype=qd.u32, shape=lane_count)
    sizes = qd.field(dtype=qd.i32, shape=lane_count)

    @qd.kernel
    def gather():
        qd.loop_config(block_dim=lane_count)
        for i in range(lane_count):
            ballots[i] = subgroup.ballot(1)
            sizes[i] = subgroup.group_size()

    gather()

    bv = ballots[0]
    sz = sizes[0]
    # Every active lane votes true, so the number of set bits should match
    # the active lane count (capped at the block width we launched).
    actual_popcount = bin(bv).count("1")
    expected = min(sz, lane_count)
    assert actual_popcount == expected, f"popcount({bv:#x}) = {actual_popcount}, expected {expected} (subgroup size {sz})"


@test_utils.test(arch=qd.gpu)
def test_subgroup_invocation_id_range():
"""Verify invocation IDs are non-negative."""
Expand Down
Loading