-
Notifications
You must be signed in to change notification settings - Fork 19
[SPIRV] Feature Parity Atomics & Shared Array #432
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d028124
6a6f4e1
29c1bd4
37d3eb1
d48c667
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -200,6 +200,16 @@ class TaskCodegen : public IRVisitor { | |
| std::unordered_map<int, GetRootStmt *> | ||
| root_stmts_; // maps root id to get root stmt | ||
| std::unordered_map<const Stmt *, BufferInfo> ptr_to_buffers_; | ||
| // Shared float AllocaStmts targeted by atomics, populated by | ||
| // scan_shared_atomic_allocs() before codegen. Value = true means the alloca | ||
| // has non-add ops (CAS unconditionally needed); false = add-only (native | ||
| // shared float atomics can be used if the device supports them). | ||
| std::unordered_map<const Stmt *, bool> shared_float_allocas_with_atomic_rmw_; | ||
| // Propagated from shared_float_allocas_with_atomic_rmw_ to derived | ||
| // MatrixPtrStmt nodes during codegen, so that load/store/atomic visitors | ||
| // know to bitcast. E.g. if `sharr` (AllocaStmt) is retyped, then | ||
| // `sharr[0]` (MatrixPtrStmt) is added here during visit(MatrixPtrStmt). | ||
|
hughperkins marked this conversation as resolved.
|
||
| std::unordered_set<const Stmt *> uint_backed_shared_float_ptr_stmts_; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ✅ |
||
| std::unordered_map<std::vector<int>, Value, hashing::Hasher<std::vector<int>>> | ||
| argid_to_tex_value_; | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,7 @@ | |
| #include "quadrants/codegen/spirv/kernel_utils.h" | ||
| #include "quadrants/codegen/spirv/spirv_ir_builder.h" | ||
| #include "quadrants/codegen/spirv/detail/spirv_codegen.h" | ||
| #include "quadrants/codegen/spirv/spirv_shared_array_retyping.h" | ||
| #include "quadrants/ir/transforms.h" | ||
| #include "quadrants/math/arithmetic.h" | ||
| #include "quadrants/codegen/ir_dump.h" | ||
|
|
@@ -119,6 +120,9 @@ TaskCodegen::Result TaskCodegen::run() { | |
| kernel_function_ = ir_->new_function(); // void main(); | ||
| ir_->debug_name(spv::OpName, kernel_function_, "main"); | ||
|
|
||
| scan_shared_atomic_allocs(task_ir_->body.get(), | ||
| shared_float_allocas_with_atomic_rmw_); | ||
|
|
||
| if (task_ir_->task_type == OffloadedTaskType::serial) { | ||
| generate_serial_kernel(task_ir_); | ||
| } else if (task_ir_->task_type == OffloadedTaskType::range_for) { | ||
|
|
@@ -270,11 +274,21 @@ void TaskCodegen::visit(ConstStmt *const_stmt) { | |
|
|
||
| void TaskCodegen::visit(AllocaStmt *alloca) { | ||
| spirv::Value ptr_val; | ||
| // alloca->ret_type is a pointer to the stored type; ptr_removed() gives the | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ✅ |
||
| // stored type itself (e.g. TensorType<32 x f32> for a 32-element array). | ||
| auto alloca_type = alloca->ret_type.ptr_removed(); | ||
| // Shared array is always modeled as a tensor type, i.e. an array of scalars. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ✅ |
||
| if (auto tensor_type = alloca_type->cast<TensorType>()) { | ||
| auto elem_num = tensor_type->get_num_elements(); | ||
| spirv::SType elem_type = | ||
| ir_->get_primitive_type(tensor_type->get_element_type()); | ||
| // Do NOT initialize elem_num/elem_type here - the helper flattens nested | ||
| // tensor types (e.g. vec3 -> 3xf32) before computing them. Pre-initializing | ||
| // with get_primitive_type(tensor_type->get_element_type()) would crash on | ||
| // nested tensor types like Tensor(3) f32. | ||
| int elem_num; | ||
| spirv::SType elem_type; | ||
| maybe_retype_alloca(*ir_, *caps_, alloca, tensor_type, | ||
| shared_float_allocas_with_atomic_rmw_, | ||
| uint_backed_shared_float_ptr_stmts_, elem_num, | ||
| elem_type); | ||
| spirv::SType arr_type = ir_->get_array_type(elem_type, elem_num); | ||
| if (alloca->is_shared) { // for shared memory / workgroup memory | ||
| ptr_val = ir_->alloca_workgroup_array(arr_type); | ||
|
|
@@ -297,12 +311,18 @@ void TaskCodegen::visit(MatrixPtrStmt *stmt) { | |
| spirv::Value offset_val = ir_->query_value(stmt->offset->raw_name()); | ||
| auto dt = stmt->element_type().ptr_removed(); | ||
| if (stmt->offset_used_as_index()) { | ||
| if (stmt->origin->is<AllocaStmt>()) { | ||
| // Origin is a local/shared array allocation or a derived pointer from one | ||
| // - use OpAccessChain or OpPtrAccessChain respectively. | ||
| if (stmt->origin->is<AllocaStmt>() || | ||
| origin_val.stype.flag == TypeKind::kPtr) { | ||
| maybe_retype_derived_ptr(*ir_, stmt->origin, stmt, dt, | ||
| uint_backed_shared_float_ptr_stmts_); | ||
| spirv::SType ptr_type = ir_->get_pointer_type( | ||
| ir_->get_primitive_type(dt), origin_val.stype.storage_class); | ||
| ptr_val = | ||
| ir_->make_value(spv::OpAccessChain, ptr_type, origin_val, offset_val); | ||
| if (stmt->origin->as<AllocaStmt>()->is_shared) { | ||
| auto op = stmt->origin->is<AllocaStmt>() ? spv::OpAccessChain | ||
| : spv::OpPtrAccessChain; | ||
| ptr_val = ir_->make_value(op, ptr_type, origin_val, offset_val); | ||
| if (auto *a = stmt->origin->cast<AllocaStmt>(); a && a->is_shared) { | ||
| ptr_to_buffers_[stmt] = ptr_to_buffers_[stmt->origin]; | ||
| } | ||
| } else if (stmt->origin->is<GlobalTemporaryStmt>()) { | ||
|
|
@@ -324,14 +344,22 @@ void TaskCodegen::visit(MatrixPtrStmt *stmt) { | |
| void TaskCodegen::visit(LocalLoadStmt *stmt) { | ||
| auto ptr = stmt->src; | ||
| spirv::Value ptr_val = ir_->query_value(ptr->raw_name()); | ||
| spirv::Value val = ir_->load_variable( | ||
| ptr_val, ir_->get_primitive_type(stmt->element_type())); | ||
| spirv::Value val; | ||
| if (uint_backed_shared_float_ptr_stmts_.count(ptr)) { | ||
| val = load_uint_backed_shared_float(*ir_, ptr_val, stmt->element_type()); | ||
| } else { | ||
| val = ir_->load_variable(ptr_val, | ||
| ir_->get_primitive_type(stmt->element_type())); | ||
| } | ||
| ir_->register_value(stmt->raw_name(), val); | ||
| } | ||
|
|
||
| void TaskCodegen::visit(LocalStoreStmt *stmt) { | ||
| spirv::Value ptr_val = ir_->query_value(stmt->dest->raw_name()); | ||
| spirv::Value val = ir_->query_value(stmt->val->raw_name()); | ||
| if (uint_backed_shared_float_ptr_stmts_.count(stmt->dest)) { | ||
| val = float_to_shared_uint(*ir_, val, stmt->val->element_type()); | ||
| } | ||
| ir_->store_variable(ptr_val, val); | ||
| } | ||
|
|
||
|
|
@@ -1521,13 +1549,17 @@ void TaskCodegen::visit(AtomicOpStmt *stmt) { | |
|
|
||
| spirv::Value addr_ptr; | ||
| spirv::Value dest_val = ir_->query_value(stmt->dest->raw_name()); | ||
| // Shared arrays have already created an accesschain, use it directly. | ||
| // Shared arrays already have a pointer from OpAccessChain (dest_is_ptr=true). | ||
| // at_buffer() looks up ptr_to_buffers_ to find the StorageBuffer and compute | ||
| // a byte offset - shared/workgroup arrays aren't in ptr_to_buffers_, so | ||
| // at_buffer() would fail on them. | ||
| const bool dest_is_ptr = dest_val.stype.flag == TypeKind::kPtr; | ||
|
|
||
| // The native-add branches originally called at_buffer() directly, but shared | ||
| // arrays can now reach this path, so all branches need the dest_is_ptr guard. | ||
| if (dt->is_primitive(PrimitiveTypeID::f64)) { | ||
| if (caps_->get(DeviceCapability::spirv_has_atomic_float64_add) && | ||
| stmt->op_type == AtomicOpType::add) { | ||
| addr_ptr = at_buffer(stmt->dest, dt); | ||
| addr_ptr = dest_is_ptr ? dest_val : at_buffer(stmt->dest, dt); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why does the uint-backed shared array mean we have to change this line?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added comment. |
||
| } else { | ||
| addr_ptr = dest_is_ptr | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is just refactoring, right? Seems like a nice refactoring, if I've understood correctly.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why are we refactoring, by the way? It's just extra code for me to read. Is this refactoring necessary?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. updated. |
||
| ? dest_val | ||
|
|
@@ -1536,7 +1568,19 @@ void TaskCodegen::visit(AtomicOpStmt *stmt) { | |
| } else if (dt->is_primitive(PrimitiveTypeID::f32)) { | ||
| if (caps_->get(DeviceCapability::spirv_has_atomic_float_add) && | ||
| stmt->op_type == AtomicOpType::add) { | ||
| addr_ptr = at_buffer(stmt->dest, dt); | ||
| addr_ptr = dest_is_ptr ? dest_val : at_buffer(stmt->dest, dt); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why does the uint-backed shared array mean we have to change this line?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added comment. |
||
| } else { | ||
| addr_ptr = dest_is_ptr | ||
| ? dest_val | ||
| : at_buffer(stmt->dest, ir_->get_quadrants_uint_type(dt)); | ||
| } | ||
| } else if (dt->is_primitive(PrimitiveTypeID::f16)) { | ||
| // f16 needs the same uint-typed pointer as f32/f64 for the CAS path. | ||
| // Without this, at_buffer returns pointer-to-f16 but the CAS loop uses | ||
| // OpAtomicLoad(u16, ...) causing a SPIR-V type mismatch. | ||
| if (caps_->get(DeviceCapability::spirv_has_atomic_float16_add) && | ||
| stmt->op_type == AtomicOpType::add) { | ||
| addr_ptr = dest_is_ptr ? dest_val : at_buffer(stmt->dest, dt); | ||
| } else { | ||
| addr_ptr = dest_is_ptr | ||
| ? dest_val | ||
|
|
@@ -1549,7 +1593,9 @@ void TaskCodegen::visit(AtomicOpStmt *stmt) { | |
| auto ret_type = ir_->get_primitive_type(dt); | ||
|
|
||
| if (is_real(dt)) { | ||
| spv::Op atomic_fp_op; | ||
| // Only initialized for add (the only op with native float atomic support). | ||
| // Safe: use_native_atomics is only true when op_type == add. | ||
| spv::Op atomic_fp_op = spv::OpNop; | ||
| if (stmt->op_type == AtomicOpType::add) { | ||
| atomic_fp_op = spv::OpAtomicFAddEXT; | ||
| } | ||
|
|
@@ -1572,11 +1618,26 @@ void TaskCodegen::visit(AtomicOpStmt *stmt) { | |
| use_native_atomics = true; | ||
| } | ||
| } | ||
| // The checks above use buffer capabilities. For shared pointers, override | ||
| // with shared capabilities (buffer and shared support are independent). | ||
| if (dest_is_ptr && stmt->op_type == AtomicOpType::add) { | ||
| use_native_atomics = has_native_float_atomic_add(*caps_, dt, true); | ||
| } | ||
| // Uint-retyped shared arrays have a uint pointer - native float atomics | ||
| // would produce invalid SPIR-V on them. | ||
| if (uint_backed_shared_float_ptr_stmts_.count(stmt->dest)) { | ||
| use_native_atomics = false; | ||
| } | ||
|
|
||
| if (use_native_atomics) { | ||
| val = ir_->make_value(atomic_fp_op, ir_->get_primitive_type(dt), addr_ptr, | ||
| /*scope=*/ir_->const_i32_one_, | ||
| /*semantics=*/ir_->const_i32_zero_, data); | ||
| } else if (dest_is_ptr) { | ||
| // Shared float arrays use uint-backed CAS (width-aware for f16->u32). | ||
| // Integer shared atomics don't need this - they use native OpAtomicIAdd | ||
| // etc. directly on the shared pointer. | ||
| val = shared_float_atomic(*ir_, stmt->op_type, addr_ptr, data, dt); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is dest_is_ptr not true for shared int atomics?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added comment. |
||
| } else { | ||
| val = ir_->float_atomic(stmt->op_type, addr_ptr, data, dt); | ||
| } | ||
|
|
@@ -1590,7 +1651,11 @@ void TaskCodegen::visit(AtomicOpStmt *stmt) { | |
| op = spv::OpAtomicISub; | ||
| use_native_atomics = true; | ||
| } else if (stmt->op_type == AtomicOpType::mul) { | ||
| addr_ptr = at_buffer(stmt->dest, ir_->get_quadrants_uint_type(dt)); | ||
| // dest_is_ptr guard needed here too - at_buffer would crash on shared | ||
| // integer arrays (same reason as the float branches above). | ||
| addr_ptr = dest_is_ptr | ||
| ? dest_val | ||
| : at_buffer(stmt->dest, ir_->get_quadrants_uint_type(dt)); | ||
| val = ir_->integer_atomic(stmt->op_type, addr_ptr, data, dt); | ||
| use_native_atomics = false; | ||
| } else if (stmt->op_type == AtomicOpType::min) { | ||
|
|
@@ -2096,6 +2161,9 @@ void TaskCodegen::generate_struct_for_kernel(OffloadedStmt *stmt) { | |
| task_attribs_.buffer_binds = get_buffer_binds(); | ||
| } | ||
|
|
||
| // Return the address in device memory for a global/storage-buffer access. | ||
| // Only works for device-buffer-backed pointers (via ptr_to_buffers_), not | ||
| // workgroup arrays - those already have a pointer from OpAccessChain. | ||
| spirv::Value TaskCodegen::at_buffer(const Stmt *ptr, DataType dt) { | ||
| spirv::Value ptr_val = ir_->query_value(ptr->raw_name()); | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -63,6 +63,13 @@ void IRBuilder::init_header() { | |
| if (caps_->get(cap::spirv_has_int64)) { | ||
| ib_.begin(spv::OpCapability).add(spv::CapabilityInt64).commit(&header_); | ||
| } | ||
| if (caps_->get(cap::spirv_has_atomic_int64)) { | ||
| // Required for OpAtomicLoad/OpAtomicCompareExchange on u64, used by | ||
| // the CAS-based f64 shared float atomic emulation path. | ||
| ib_.begin(spv::OpCapability) | ||
| .add(spv::CapabilityInt64Atomics) | ||
| .commit(&header_); | ||
| } | ||
| if (caps_->get(cap::spirv_has_float16)) { | ||
| ib_.begin(spv::OpCapability).add(spv::CapabilityFloat16).commit(&header_); | ||
| } | ||
|
|
@@ -1075,6 +1082,8 @@ Value IRBuilder::float_atomic(AtomicOpType op_type, | |
| Value addr_ptr, | ||
| Value data, | ||
| const DataType &dt) { | ||
| // Use dt-derived type instead of t_fp32_ so FMin/FMax work for f16/f64. | ||
| auto float_type = get_primitive_type(dt); | ||
| if (op_type == AtomicOpType::add) { | ||
| return atomic_operation( | ||
| addr_ptr, data, [&](Value lhs, Value rhs) { return add(lhs, rhs); }, | ||
|
|
@@ -1091,14 +1100,14 @@ Value IRBuilder::float_atomic(AtomicOpType op_type, | |
| return atomic_operation( | ||
| addr_ptr, data, | ||
| [&](Value lhs, Value rhs) { | ||
| return call_glsl450(t_fp32_, /*FMin*/ 37, lhs, rhs); | ||
| return call_glsl450(float_type, /*FMin*/ 37, lhs, rhs); | ||
|
hughperkins marked this conversation as resolved.
|
||
| }, | ||
| dt); | ||
| } else if (op_type == AtomicOpType::max) { | ||
| return atomic_operation( | ||
| addr_ptr, data, | ||
| [&](Value lhs, Value rhs) { | ||
| return call_glsl450(t_fp32_, /*FMax*/ 40, lhs, rhs); | ||
| return call_glsl450(float_type, /*FMax*/ 40, lhs, rhs); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ✅ |
||
| }, | ||
| dt); | ||
| } else { | ||
|
|
@@ -1124,7 +1133,13 @@ Value IRBuilder::atomic_operation(Value addr_ptr, | |
| std::function<Value(Value, Value)> op, | ||
| const DataType &dt) { | ||
| SType out_type = get_primitive_type(dt); | ||
| SType res_type = get_primitive_uint_type(dt); | ||
| // Device-buffer pointers are uint-typed (from at_buffer), so CAS uses uint. | ||
| // Workgroup (shared) pointers keep their original type (e.g. i32). Using uint | ||
| // on a signed pointer causes Metal's atomic_compare_exchange to reject the | ||
| // shader due to signed/unsigned type mismatch. | ||
| const bool is_workgroup = | ||
| addr_ptr.stype.storage_class == spv::StorageClassWorkgroup; | ||
| SType res_type = is_workgroup ? out_type : get_primitive_uint_type(dt); | ||
| Value ret_val_int = alloca_variable(res_type); | ||
|
|
||
| // do-while | ||
|
|
@@ -1150,10 +1165,15 @@ Value IRBuilder::atomic_operation(Value addr_ptr, | |
| Value old_val = make_value(spv::OpAtomicLoad, res_type, addr_ptr, | ||
| /*scope=*/const_i32_one_, | ||
| /*semantics=*/const_i32_zero_); | ||
| // int new = dataTypeBitsToInt(atomic_op(intBitsToDataType(old), data)); | ||
| Value old_data_value = make_value(spv::OpBitcast, out_type, old_val); | ||
| // Bitcast uint<->float for the operation. Skip when types already match | ||
| // (integer workgroup path where res_type == out_type). | ||
| Value old_data_value = (out_type.id != res_type.id) | ||
| ? make_value(spv::OpBitcast, out_type, old_val) | ||
| : old_val; | ||
| Value new_data_value = op(old_data_value, data); | ||
| Value new_val = make_value(spv::OpBitcast, res_type, new_data_value); | ||
| Value new_val = (out_type.id != res_type.id) | ||
| ? make_value(spv::OpBitcast, res_type, new_data_value) | ||
| : new_data_value; | ||
| // int loaded = atomicCompSwap(vals[0], old, new); | ||
| /* | ||
| * Don't need this part, theoretically | ||
|
|
@@ -1188,8 +1208,10 @@ Value IRBuilder::atomic_operation(Value addr_ptr, | |
| } | ||
| start_label(exit); | ||
|
|
||
| return make_value(spv::OpBitcast, out_type, | ||
| load_variable(ret_val_int, res_type)); | ||
| Value ret_loaded = load_variable(ret_val_int, res_type); | ||
| return (out_type.id != res_type.id) | ||
| ? make_value(spv::OpBitcast, out_type, ret_loaded) | ||
| : ret_loaded; | ||
| } | ||
|
|
||
| Value IRBuilder::rand_u32(Value global_tmp_) { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
✅