
[SPIRV] Feature Parity Atomics & Shared Array#432

Merged
duburcqa merged 5 commits into main from duburcqa/fix_metal_atomics_shared_mem
Apr 9, 2026

Conversation

@duburcqa
Contributor

@duburcqa duburcqa commented Mar 29, 2026

Brief Summary

  1. Fix float atomics on shared memory for Metal and Vulkan:
    • allocate shared float arrays as uint, bitcast at load/store
    • only retype arrays targeted by atomics (pre-scan via scan_shared_atomic_allocs)
  2. Add f16 shared memory float atomics support:
    • back f16 arrays with u32 (Metal/Vulkan lack 16-bit atomics), with width conversion at load/store/CAS boundaries
  3. Add support of shared memory of arbitrary dtype on Metal and Vulkan:
    • flatten nested tensor types (vec3 etc.), handle OpPtrAccessChain for component access
  4. Add official support of multiple shared arrays on Metal:
    • enable existing test on Metal
  5. Fix shared memory float atomics on GPUs with native float atomic support:
    • handle dest_is_ptr before at_buffer, disable native float atomics for uint-retyped shared arrays
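The CAS-based emulation described in points 1 and 4 above can be sketched as follows. This is a hedged illustration only: `UIntBackedCell`, `compare_exchange`, and `atomic_add_f32` are hypothetical Python stand-ins for the generated SPIR-V, with a lock modeling the hardware's atomic compare-exchange on the u32 backing store.

```python
import struct
import threading

def f32_to_u32(x: float) -> int:
    """Bit pattern of a float32 as a uint32 (the 'bitcast' at store time)."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def u32_to_f32(u: int) -> float:
    """Bit pattern of a uint32 reinterpreted as float32 (the 'bitcast' at load time)."""
    return struct.unpack("<f", struct.pack("<I", u))[0]

class UIntBackedCell:
    """Hypothetical stand-in for one slot of a float shared array retyped to uint."""
    def __init__(self, value: float = 0.0):
        self._bits = f32_to_u32(value)
        self._lock = threading.Lock()

    def compare_exchange(self, expected: int, desired: int) -> int:
        # Models an atomic compare-exchange on the u32 backing store:
        # returns the old bits; the swap happens only if they equal `expected`.
        with self._lock:
            old = self._bits
            if old == expected:
                self._bits = desired
            return old

    def atomic_add_f32(self, operand: float) -> float:
        # CAS loop: load bits, compute in float, write back as bits,
        # and retry if another thread raced us in between.
        old_bits = self._bits
        while True:
            new_bits = f32_to_u32(u32_to_f32(old_bits) + operand)
            got = self.compare_exchange(old_bits, new_bits)
            if got == old_bits:
                return u32_to_f32(old_bits)
            old_bits = got

cell = UIntBackedCell(1.5)
cell.atomic_add_f32(2.25)
print(u32_to_f32(cell._bits))  # 3.75
```

The same loop generalizes to min/max/sub by changing the float computation; only the integer compare-exchange needs hardware support.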
┌─────┬────────────────────────────────────────────────────────────────────┬──────────┐
│  #  │                               Commit                               │  Delta   │
├─────┼────────────────────────────────────────────────────────────────────┼──────────┤
│ 1   │ [Test] Expand unit test coverage.                                  │ +96/-6   │
├─────┼────────────────────────────────────────────────────────────────────┼──────────┤
│ 2   │ [SPIRV] Fix float atomic bugs: f16/f64 min/max, CAS type mismatch. │ +85/-7   │
├─────┼────────────────────────────────────────────────────────────────────┼──────────┤
│ 3   │ [SPIRV] Add Int64Atomics and shared float atomic cap detection.    │ +32/-0   │
├─────┼────────────────────────────────────────────────────────────────────┼──────────┤
│ 4   │ [SPIRV] Add CAS-based float atomic emulation for shared memory.    │ +444/-10 │
├─────┼────────────────────────────────────────────────────────────────────┼──────────┤
│ 5   │ [SPIRV] Add nested tensor type flattening.                         │ +40/-25  │
└─────┴────────────────────────────────────────────────────────────────────┴──────────┘

@duburcqa duburcqa force-pushed the duburcqa/fix_metal_atomics_shared_mem branch 4 times, most recently from af08498 to 02fb882 Compare March 30, 2026 07:05
@duburcqa duburcqa changed the title Fix support of float atomics on shared memory for Apple Metal. Fix support of float atomics and generic dtypes on shared memory for Apple Metal. Mar 30, 2026
@duburcqa duburcqa changed the title Fix support of float atomics and generic dtypes on shared memory for Apple Metal. Add support of float atomics and generic dtypes to shared memory for Apple Metal. Mar 30, 2026
@duburcqa duburcqa changed the title Add support of float atomics and generic dtypes to shared memory for Apple Metal. Add support of float atomics and generic dtypes to shared memory on Vulkan and Apple Metal. Mar 30, 2026
@duburcqa
Contributor Author

I was assisted by Claude Opus in writing this PR. I have read and reviewed every line added in this PR. I take full responsibility for the lines added and removed in this PR. I won't blame any issue on Claude Opus.

@duburcqa duburcqa force-pushed the duburcqa/fix_metal_atomics_shared_mem branch 4 times, most recently from c7cbb8b to 4a42abe Compare March 30, 2026 08:43
Comment thread quadrants/codegen/spirv/spirv_codegen.cpp Outdated
auto elem_num = tensor_type->get_num_elements();
spirv::SType elem_type =
ir_->get_primitive_type(tensor_type->get_element_type());
DataType elem_dt = tensor_type->get_element_type();
Collaborator

elem_dt and elem_type are very confusing. Could we either give more intuitive names, or at least add a comment on what is the difference between them?

Contributor Author

It should be better now.

// float atomics).
if (alloca->is_shared && is_real(elem_dt)) {
elem_type =
ir_->get_primitive_type(ir_->get_quadrants_uint_type(elem_dt));
Collaborator

It's not clear to me from the name what get_quadrants_uint_type does, specifically around the number of bits. Could we add a comment to clarify what is happening in this line?

Contributor Author

It should be better now.
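The width-mapping question above can be illustrated with a sketch. The mapping table and function names here are assumptions for illustration, not the actual get_quadrants_uint_type implementation; the f16-to-u32 widening follows the PR summary (Metal/Vulkan lack 16-bit atomics).

```python
import struct

# Hypothetical sketch: each float type is backed by an unsigned integer type,
# with f16 widened into a u32 slot because 16-bit atomics are unavailable.
FLOAT_TO_BACKING_UINT_BITS = {"f16": 32, "f32": 32, "f64": 64}

def bitcast_store(value: float, dtype: str) -> int:
    """Encode a float into the bit pattern held by its uint backing slot."""
    if dtype == "f16":
        # Width conversion at the store boundary: the 16 payload bits are
        # zero-extended into the 32-bit backing slot.
        return struct.unpack("<H", struct.pack("<e", value))[0]
    fmt, ufmt = {"f32": ("<f", "<I"), "f64": ("<d", "<Q")}[dtype]
    return struct.unpack(ufmt, struct.pack(fmt, value))[0]

for dt in ("f16", "f32", "f64"):
    bits = bitcast_store(1.0, dt)
    assert bits < 2 ** FLOAT_TO_BACKING_UINT_BITS[dt]
    print(f"{dt}: 0x{bits:X}")
```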

spirv::Value offset_val = ir_->query_value(stmt->offset->raw_name());
auto dt = stmt->element_type().ptr_removed();
// Flatten nested tensor types to scalar (e.g., vec3 to f32)
if (auto nested = dt->cast<TensorType>()) {
Collaborator

this seems very similar to what happens above. Could this be factorized into a helper function?

Contributor Author

done.

Collaborator

Same here as for the AllocaStmt: let's keep a minimal footprint in this function, and peel off the reasoning about uint-backed float shared arrays into a separate function in a separate file, please.

Contributor Author

done.

ir_->get_primitive_type(dt), origin_val.stype.storage_class);
auto elem_type = ir_->get_primitive_type(dt);
if (shared_float_retyped_.count(stmt->origin)) {
elem_type = ir_->get_primitive_type(ir_->get_quadrants_uint_type(dt));
Collaborator

Ditto for the question about a helper function.

Contributor Author

done.

Comment thread quadrants/codegen/spirv/detail/spirv_codegen.h Outdated
spirv::SType ptr_type = ir_->get_pointer_type(
ir_->get_primitive_type(dt), origin_val.stype.storage_class);
auto elem_type = ir_->get_primitive_type(dt);
if (shared_float_retyped_.count(stmt->origin)) {
Collaborator

Can we add a comment about what this if statement is checking for, intuitively?

Contributor Author

done.

spirv::Value offset_bytes = ir_->mul(dt_bytes, offset_val);
ptr_val = ir_->add(origin_val, offset_bytes);
ptr_to_buffers_[stmt] = ptr_to_buffers_[stmt->origin];
} else if (origin_val.stype.flag == TypeKind::kPtr) {
Collaborator

Can we add a comment about what this new else-if block is checking for?

Contributor Author

done.

Collaborator

Lots more new code. Let's move it to the uint-backed shared array file, please.

Contributor Author

Before this PR, shared arrays never reached the float-atomic path; now they can, so all branches guard at_buffer() with dest_is_ptr ternaries.

stmt->op_type == AtomicOpType::add) {
addr_ptr = at_buffer(stmt->dest, dt);
} else {
addr_ptr = dest_is_ptr
Collaborator

This is just refactoring, right? Seems like a nice refactoring, if I've understood correctly.

Contributor Author

done.

Collaborator

Why are we refactoring, by the way? It's just extra code for me to read. Is this refactoring necessary?

Contributor Author

updated.

}

// Shared float arrays are retyped to uint, so native float atomics
// (which require a float pointer) cannot be used on them.
Collaborator

I'm not sure I follow. I thought the purpose of changing the backing type to uint was to enable the SPIR-V atomics? Could you give a little more clarification (in the comments) about this point, please.

Also, how do we know if we are dealing with a shared array here?

Contributor Author

Added some comment on the PR itself to clarify this.

Comment thread tests/python/test_shared_array.py Outdated
def test_shared_array_float_atomics(op):
N = 256
block_dim = 32
total = block_dim * (block_dim - 1) / 2.0
Collaborator

total what? total_threads? Why are we dividing by 2.0? Oh, perhaps we are summing some kind of arithmetic progression, and this is the expected_sum of that progression? Could we update the name to make the meaning more intuitive, please. By the way, this calculation could be done using ints. Could we make this something that needs actual floats? E.g. multiply each term in the progression by 0.333, which is pretty incompatible with binary representation.

Contributor Author

fixed.
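The reviewer's suggestion above can be sketched as follows. Names (`block_dim`, `scale`, `expected_sum`) are illustrative, not the actual test code: each thread tid in 0..block_dim-1 atomically adds `scale * tid` into one shared slot, so the expected result is the scaled arithmetic-series sum, and the 0.333 factor makes the result depend on real float behavior.

```python
block_dim = 32
scale = 0.333  # not exactly representable in binary, so ints can't fake this test
expected_sum = scale * block_dim * (block_dim - 1) / 2.0

# Reference accumulation in the order a serialized atomic add would apply:
acc = 0.0
for tid in range(block_dim):
    acc += scale * tid

print(abs(acc - expected_sum) < 1e-9)  # True
```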

Comment thread tests/python/test_shared_array.py Outdated
@hughperkins
Collaborator

Opus review:

Thoughts

What's good:
• The approach of retyping float shared arrays to uint and bitcasting at the boundaries is a well-known pattern for working around missing float atomics. It's clean and
localized.
• The taint-tracking via shared_float_retyped_ is a reasonable way to propagate the retyping through the IR without modifying the IR itself.
• Test coverage looks solid — the new test_shared_array_float_atomics is parametrized over all four atomic ops.

Things I'd flag:

  1. The shared_float_retyped_ set is a std::unordered_set<const Stmt*> tracking raw pointers. This is fine as long as stmts outlive the codegen pass, but it's worth
    confirming there's no stmt lifetime issue. Given this is a visitor pass over existing IR, it should be safe.
  2. The retyping applies to all real types on shared memory, not just f32. If someone uses f16 or f64 shared arrays, the bitcast to the corresponding uint type should work,
    but there's no explicit test for those. The get_quadrants_uint_type utility presumably handles the width mapping, but worth verifying for f16 especially.
  3. The early dest_is_ptr branch in AtomicOpStmt unconditionally uses dest_val for shared memory, which means it never goes through at_buffer. This is correct (shared
    memory isn't in a buffer), but the original code had some interleaving of the dest_is_ptr check inside each float type branch. The refactor is cleaner, but it's a behavior
    change — previously a shared f32 array on a GPU that does report spirv_has_atomic_float_add would have used dest_val directly (correct), and now it still does (correct),
    but the native-atomics flag is also explicitly suppressed. The comment explains why (shared arrays are uint-retyped so native float atomics can't work), which is good.
  4. The test_shared_array_float_atomics test initializes sharr[tid] = qd.f32(tid) then atomically accumulates into sharr[0]. The initial value of sharr[0] is 0.0 (from
    tid=0), then all threads do atomic_op(sharr[0], tid). For sub, the expected result is -total which assumes the initial value 0.0 stays as-is and only the atomic subtracts
    are applied. That's correct since atomic_sub(sharr[0], 0) is a no-op for tid=0.
  5. Minor: test_shared_array_matrix dropped debug=True — intentional? Could be that debug mode was causing issues on Vulkan/Metal or just wasn't needed.

Overall, this is a well-structured branch. The Metal/Vulkan shared memory work is the kind of backend plumbing that's easy to get wrong, but the approach here is principled and
the test coverage backs it up.

@hughperkins
Collaborator

From the AI review, please could we address at least:
2. The retyping applies to all real types on shared memory, not just f32. If someone uses f16 or f64 shared arrays, the bitcast to the corresponding uint type should work,
but there's no explicit test for those. The get_quadrants_uint_type utility presumably handles the width mapping, but worth verifying for f16 especially.
5. Minor: test_shared_array_matrix dropped debug=True — intentional? Could be that debug mode was causing issues on Vulkan/Metal or just wasn't needed.

@hughperkins
Collaborator

(so the AI and I concur about the ambiguity over what get_quadrants_uint_type does)

@duburcqa duburcqa force-pushed the duburcqa/fix_metal_atomics_shared_mem branch from 4a42abe to 4b0ea62 Compare March 30, 2026 15:42
@duburcqa duburcqa changed the title Feature Parity Atomics & Shared Array [SPIRV] Feature Parity Atomics & Shared Array Apr 9, 2026
@duburcqa duburcqa force-pushed the duburcqa/fix_metal_atomics_shared_mem branch 16 times, most recently from f06029b to 74b91b3 Compare April 9, 2026 12:09
@duburcqa
Contributor Author

duburcqa commented Apr 9, 2026

I was assisted by Claude Opus in writing this PR. I have read and reviewed every change in this PR. I take full responsibility for the lines added and removed in this PR. I won't blame any issue on Claude Opus.

@duburcqa duburcqa force-pushed the duburcqa/fix_metal_atomics_shared_mem branch from 74b91b3 to 4bc162b Compare April 9, 2026 12:58
@duburcqa duburcqa force-pushed the duburcqa/fix_metal_atomics_shared_mem branch from 4bc162b to 5aa97b0 Compare April 9, 2026 13:08
@duburcqa duburcqa force-pushed the duburcqa/fix_metal_atomics_shared_mem branch from 5aa97b0 to d48c667 Compare April 9, 2026 14:17
@duburcqa
Contributor Author

duburcqa commented Apr 9, 2026

I have reviewed this PR, and approve it.

Collaborator

@hughperkins hughperkins left a comment

Approving, as a codeowner, on the basis that Alexis has reviewed and approved this PR.

@duburcqa duburcqa merged commit 79ec049 into main Apr 9, 2026
47 checks passed
@duburcqa duburcqa deleted the duburcqa/fix_metal_atomics_shared_mem branch April 9, 2026 15:39