[SPIRV] Feature Parity: Atomics & Shared Array #432
Conversation
af08498 to 02fb882
I was assisted by Claude Opus in writing this PR. I have read and reviewed every line added in this PR. I take full responsibility for the lines added and removed in this PR. I won't blame any issue on Claude Opus.
c7cbb8b to 4a42abe
```cpp
auto elem_num = tensor_type->get_num_elements();
spirv::SType elem_type =
    ir_->get_primitive_type(tensor_type->get_element_type());
DataType elem_dt = tensor_type->get_element_type();
```
elem_dt and elem_type are very confusing. Could we either give more intuitive names, or at least add a comment on what is the difference between them?
It should be better now.
```cpp
// float atomics).
if (alloca->is_shared && is_real(elem_dt)) {
  elem_type =
      ir_->get_primitive_type(ir_->get_quadrants_uint_type(elem_dt));
```
It's not clear to me from the name what get_quadrants_uint_type does, specifically around the number of bits. Could we add a comment to clarify what is happening on this line?
It should be better now.
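Since the name came up twice in this review, here is a small model of what a helper like get_quadrants_uint_type plausibly does, assuming (this is an inference from context, not the PR's actual code) that it maps a real (floating-point) type to an unsigned integer type of the *same bit width*, so the shared array's storage can be declared with a uint element type while the bit pattern is preserved:

```python
import struct

# Hypothetical model: map each float type to the same-width uint type.
# The mapping table and function name are illustrative, not the PR's code.
FLOAT_TO_SAME_WIDTH_UINT = {"f16": "u16", "f32": "u32", "f64": "u64"}

def quadrants_uint_type(float_type: str) -> str:
    """Return the unsigned integer type with the same bit width as float_type."""
    return FLOAT_TO_SAME_WIDTH_UINT[float_type]

# Because the widths match, a float round-trips losslessly through its
# uint bit pattern (here: f32 <-> u32 via struct pack/unpack):
bits = struct.unpack("<I", struct.pack("<f", 1.5))[0]   # 1.5f == 0x3FC00000
back = struct.unpack("<f", struct.pack("<I", bits))[0]  # exactly 1.5 again
```

The same-width requirement is the key point: an f32 stored in a u32 slot loses nothing, so integer atomics on the slot can stand in for float atomics.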
```cpp
spirv::Value offset_val = ir_->query_value(stmt->offset->raw_name());
auto dt = stmt->element_type().ptr_removed();
// Flatten nested tensor types to scalar (e.g., vec3 to f32)
if (auto nested = dt->cast<TensorType>()) {
```
This seems very similar to what happens above. Could this be factored out into a helper function?
Same here as for the AllocaStmt: let's keep a minimal footprint in this function, and peel off the logic for uint-backed float shared arrays into a separate function in a separate file, please.
```cpp
    ir_->get_primitive_type(dt), origin_val.stype.storage_class);
auto elem_type = ir_->get_primitive_type(dt);
if (shared_float_retyped_.count(stmt->origin)) {
  elem_type = ir_->get_primitive_type(ir_->get_quadrants_uint_type(dt));
```
Ditto for the question about a helper function.
```cpp
spirv::SType ptr_type = ir_->get_pointer_type(
    ir_->get_primitive_type(dt), origin_val.stype.storage_class);
auto elem_type = ir_->get_primitive_type(dt);
if (shared_float_retyped_.count(stmt->origin)) {
```
Can we add a comment about what this if statement is checking for, intuitively?
```cpp
spirv::Value offset_bytes = ir_->mul(dt_bytes, offset_val);
ptr_val = ir_->add(origin_val, offset_bytes);
ptr_to_buffers_[stmt] = ptr_to_buffers_[stmt->origin];
} else if (origin_val.stype.flag == TypeKind::kPtr) {
```
Can we add a comment about what this new else-if block is checking for?
Lots more new code. Let's move it to the uint-backed shared array file, please.
Before this PR, shared arrays never reached the float-atomic path; now they can, so all branches guard at_buffer() with dest_is_ptr ternaries.
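The guard described here can be modeled in a few lines. This is a sketch only (the names `resolve_atomic_dest`, `dest_ptr`, and the toy `at_buffer` are illustrative; the real code is C++ SPIR-V codegen): when the destination is already a pointer, as a shared-array slot now can be, it is used directly, and only otherwise does the buffer-addressing path run.

```python
def resolve_atomic_dest(dest_is_ptr, dest_ptr, at_buffer, dest, dt):
    """Model of the dest_is_ptr ternary guarding at_buffer().

    Shared-array destinations already carry a pointer value, so routing
    them through at_buffer() would be wrong; ordinary field destinations
    still go through buffer addressing.
    """
    return dest_ptr if dest_is_ptr else at_buffer(dest, dt)

# Toy usage with a fake at_buffer that turns an index into a byte offset:
buffer_addr = resolve_atomic_dest(False, None, lambda d, dt: d * 4, 3, "f32")
shared_addr = resolve_atomic_dest(True, "ptr<Workgroup,u32>",
                                  lambda d, dt: d * 4, 3, "f32")
```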
```cpp
    stmt->op_type == AtomicOpType::add) {
  addr_ptr = at_buffer(stmt->dest, dt);
} else {
  addr_ptr = dest_is_ptr
```
This is just refactoring, right? Seems like a nice refactoring, if I've understood correctly.
Why are we refactoring, by the way? It's just extra code for me to read. Is this refactoring necessary?
```cpp
}

// Shared float arrays are retyped to uint, so native float atomics
// (which require a float pointer) cannot be used on them.
```
I'm not sure I follow. I thought the purpose of changing the backing type to uint was to enable the spirv atomics? Could you give a little more clarification (in the comments) about this point please.
Also, how do we know if we are dealing with a shared array here?
Added some comment on the PR itself to clarify this.
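For readers of this thread: the question of how a float add works on uint-retyped storage comes down to a compare-and-swap loop. Below is a minimal single-threaded Python model of that pattern (helper names are illustrative; in SPIR-V the retry would use OpAtomicCompareExchange, and the bitcasts would be OpBitcast): the word is read, reinterpreted as float, added to, reinterpreted back, and committed only if the word has not changed in the meantime.

```python
import struct

def f2u(x: float) -> int:
    """uint32 bit pattern of a float32 value."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def u2f(u: int) -> float:
    """float32 value of a uint32 bit pattern."""
    return struct.unpack("<f", struct.pack("<I", u))[0]

def atomic_add_f32_via_cas(mem: list, idx: int, value: float) -> float:
    """Model of float atomic add on uint-backed storage via a CAS loop.

    mem[idx] holds a uint32 bit pattern. The 'CAS' here is a plain
    compare, since a single-threaded model cannot actually race.
    """
    while True:
        old_bits = mem[idx]
        new_bits = f2u(u2f(old_bits) + value)
        if mem[idx] == old_bits:      # CAS: commit only if unchanged
            mem[idx] = new_bits
            return u2f(old_bits)      # atomics return the old value

mem = [f2u(0.0)]
atomic_add_f32_via_cas(mem, 0, 1.5)
atomic_add_f32_via_cas(mem, 0, 2.25)
# u2f(mem[0]) is now exactly 3.75 (both addends are binary-representable)
```

This also explains the comment under discussion: the retyping enables *integer* compare-exchange atomics on the storage, while native *float* atomics (which require a float-typed pointer) become unusable on it, so the two statements are consistent.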
```python
def test_shared_array_float_atomics(op):
    N = 256
    block_dim = 32
    total = block_dim * (block_dim - 1) / 2.0
```
total what? total_threads? Why are we dividing by 2.0? Oh, perhaps we are doing some kind of arithmetic progression or similar, and this is the expected_sum of that progression? Could we update the name to make the meaning more intuitive please. By the way, this calculation could be done using ints. Could we make this something that needs actual floats? Like, e.g. multiply each term in the progression by 0.333, which is pretty incompatible with binary representation.
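A sketch of that suggestion, assuming a pytest-style test (the names `scale`, `expected_sum`, and `closed_form` are illustrative, not the PR's code): scaling each term of the 0..block_dim-1 progression by 0.333, which has no exact binary representation, means the test only passes if genuine float arithmetic happened, rather than an integer computation that coincidentally matches.

```python
import math

block_dim = 32
scale = 0.333  # deliberately not representable exactly in binary floating point

# Expected per-block sum if each thread i atomically adds i * scale:
expected_sum = sum(i * scale for i in range(block_dim))

# Closed form of the same arithmetic progression, for comparison:
closed_form = scale * block_dim * (block_dim - 1) / 2.0

assert math.isclose(expected_sum, closed_form)
```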
Opus review: Thoughts What's good: Things I'd flag:
Overall, this is a well-structured branch. The Metal/Vulkan shared memory work is the kind of backend plumbing that's easy to get wrong, but the approach here is principled and
From the AI review, please could we address at least:
(so the AI and I concur about the ambiguity over what get_quadrants_uint_type does)
4a42abe to 4b0ea62
f06029b to 74b91b3
I was assisted by Claude Opus in writing this PR. I have read and reviewed every change in this PR. I take full responsibility for the lines added and removed in this PR. I won't blame any issue on Claude Opus.
74b91b3 to 4bc162b
…atch, at_buffer guards.
4bc162b to 5aa97b0
5aa97b0 to d48c667
I have reviewed this PR, and I approve it.
hughperkins
left a comment
Approving, as a codeowner, on the basis that Alexis has reviewed and approved this PR.
Brief Summary