diff --git a/quadrants/codegen/spirv/CMakeLists.txt b/quadrants/codegen/spirv/CMakeLists.txt index 22d311e97b..9449120fe0 100644 --- a/quadrants/codegen/spirv/CMakeLists.txt +++ b/quadrants/codegen/spirv/CMakeLists.txt @@ -6,6 +6,7 @@ target_sources(spirv_codegen kernel_utils.cpp snode_struct_compiler.cpp spirv_codegen.cpp + spirv_shared_array_retyping.cpp spirv_ir_builder.cpp spirv_types.cpp compiled_kernel_data.cpp diff --git a/quadrants/codegen/spirv/detail/spirv_codegen.h b/quadrants/codegen/spirv/detail/spirv_codegen.h index b850d86b79..8f24764ee4 100644 --- a/quadrants/codegen/spirv/detail/spirv_codegen.h +++ b/quadrants/codegen/spirv/detail/spirv_codegen.h @@ -200,6 +200,16 @@ class TaskCodegen : public IRVisitor { std::unordered_map<int, GetRootStmt *> root_stmts_; // maps root id to get root stmt std::unordered_map<const Stmt *, BufferInfo> ptr_to_buffers_; + // Shared float AllocaStmts targeted by atomics, populated by + // scan_shared_atomic_allocs() before codegen. Value = true means the alloca + // has non-add ops (CAS unconditionally needed); false = add-only (native + // shared float atomics can be used if the device supports them). + std::unordered_map<const AllocaStmt *, bool> shared_float_allocas_with_atomic_rmw_; + // Propagated from shared_float_allocas_with_atomic_rmw_ to derived + // MatrixPtrStmt nodes during codegen, so that load/store/atomic visitors + // know to bitcast. E.g. if `sharr` (AllocaStmt) is retyped, then + // `sharr[0]` (MatrixPtrStmt) is added here during visit(MatrixPtrStmt). + std::unordered_set<const Stmt *> uint_backed_shared_float_ptr_stmts_; std::unordered_map<std::vector<int>, Value, hashing::Hasher<std::vector<int>>> argid_to_tex_value_; diff --git a/quadrants/codegen/spirv/spirv_codegen.cpp b/quadrants/codegen/spirv/spirv_codegen.cpp index ec37e836dc..31f9a6a435 100644 --- a/quadrants/codegen/spirv/spirv_codegen.cpp +++ b/quadrants/codegen/spirv/spirv_codegen.cpp @@ -14,6 +14,7 @@ #include "quadrants/codegen/spirv/kernel_utils.h" #include "quadrants/codegen/spirv/spirv_ir_builder.h" #include "quadrants/codegen/spirv/detail/spirv_codegen.h" +#include "quadrants/codegen/spirv/spirv_shared_array_retyping.h" #include "quadrants/ir/transforms.h" #include "quadrants/math/arithmetic.h" #include "quadrants/codegen/ir_dump.h" @@ -119,6 +120,9 @@ TaskCodegen::Result TaskCodegen::run() { kernel_function_ = ir_->new_function(); // void main(); ir_->debug_name(spv::OpName, kernel_function_, "main"); + scan_shared_atomic_allocs(task_ir_->body.get(), + shared_float_allocas_with_atomic_rmw_); + if (task_ir_->task_type == OffloadedTaskType::serial) { generate_serial_kernel(task_ir_); } else if (task_ir_->task_type == OffloadedTaskType::range_for) { @@ -270,11 +274,21 @@ void TaskCodegen::visit(ConstStmt *const_stmt) { void TaskCodegen::visit(AllocaStmt *alloca) { spirv::Value ptr_val; + // alloca->ret_type is a pointer to the stored type; ptr_removed() gives the + // stored type itself (e.g. TensorType<32 x f32> for a 32-element array). auto alloca_type = alloca->ret_type.ptr_removed(); + // Shared array is always modeled as a tensor type, i.e. an array of scalars. if (auto tensor_type = alloca_type->cast<TensorType>()) { - auto elem_num = tensor_type->get_num_elements(); - spirv::SType elem_type = - ir_->get_primitive_type(tensor_type->get_element_type()); + // Do NOT initialize elem_num/elem_type here - the helper flattens nested + // tensor types (e.g. vec3 -> 3xf32) before computing them. Pre-initializing + // with get_primitive_type(tensor_type->get_element_type()) would crash on + // nested tensor types like Tensor(3) f32.
+ int elem_num; + spirv::SType elem_type; + maybe_retype_alloca(*ir_, *caps_, alloca, tensor_type, + shared_float_allocas_with_atomic_rmw_, + uint_backed_shared_float_ptr_stmts_, elem_num, + elem_type); spirv::SType arr_type = ir_->get_array_type(elem_type, elem_num); if (alloca->is_shared) { // for shared memory / workgroup memory ptr_val = ir_->alloca_workgroup_array(arr_type); @@ -297,12 +311,18 @@ void TaskCodegen::visit(MatrixPtrStmt *stmt) { spirv::Value offset_val = ir_->query_value(stmt->offset->raw_name()); auto dt = stmt->element_type().ptr_removed(); if (stmt->offset_used_as_index()) { - if (stmt->origin->is<AllocaStmt>()) { + // Origin is a local/shared array allocation or a derived pointer from one + // - use OpAccessChain or OpPtrAccessChain respectively. + if (stmt->origin->is<AllocaStmt>() || + origin_val.stype.flag == TypeKind::kPtr) { + maybe_retype_derived_ptr(*ir_, stmt->origin, stmt, dt, + uint_backed_shared_float_ptr_stmts_); spirv::SType ptr_type = ir_->get_pointer_type( ir_->get_primitive_type(dt), origin_val.stype.storage_class); - ptr_val = - ir_->make_value(spv::OpAccessChain, ptr_type, origin_val, offset_val); - if (stmt->origin->as<AllocaStmt>()->is_shared) { + auto op = stmt->origin->is<AllocaStmt>() ? spv::OpAccessChain + : spv::OpPtrAccessChain; + ptr_val = ir_->make_value(op, ptr_type, origin_val, offset_val); + if (auto *a = stmt->origin->cast<AllocaStmt>(); a && a->is_shared) { ptr_to_buffers_[stmt] = ptr_to_buffers_[stmt->origin]; } } else if (stmt->origin->is()) { @@ -324,14 +344,22 @@ void TaskCodegen::visit(MatrixPtrStmt *stmt) { void TaskCodegen::visit(LocalLoadStmt *stmt) { auto ptr = stmt->src; spirv::Value ptr_val = ir_->query_value(ptr->raw_name()); - spirv::Value val = ir_->load_variable( - ptr_val, ir_->get_primitive_type(stmt->element_type())); + spirv::Value val; + if (uint_backed_shared_float_ptr_stmts_.count(ptr)) { + val = load_uint_backed_shared_float(*ir_, ptr_val, stmt->element_type()); + } else { + val = ir_->load_variable(ptr_val, + ir_->get_primitive_type(stmt->element_type())); + } ir_->register_value(stmt->raw_name(), val); } void TaskCodegen::visit(LocalStoreStmt *stmt) { spirv::Value ptr_val = ir_->query_value(stmt->dest->raw_name()); spirv::Value val = ir_->query_value(stmt->val->raw_name()); + if (uint_backed_shared_float_ptr_stmts_.count(stmt->dest)) { + val = float_to_shared_uint(*ir_, val, stmt->val->element_type()); + } ir_->store_variable(ptr_val, val); } @@ -1521,13 +1549,17 @@ void TaskCodegen::visit(AtomicOpStmt *stmt) { spirv::Value addr_ptr; spirv::Value dest_val = ir_->query_value(stmt->dest->raw_name()); - // Shared arrays have already created an accesschain, use it directly. + // Shared arrays already have a pointer from OpAccessChain (dest_is_ptr=true). + // at_buffer() looks up ptr_to_buffers_ to find the StorageBuffer and compute + // a byte offset - shared/workgroup arrays aren't in ptr_to_buffers_, so + // at_buffer() would fail on them. const bool dest_is_ptr = dest_val.stype.flag == TypeKind::kPtr; - + // The native-add branches originally called at_buffer() directly, but shared + // arrays can now reach this path, so all branches need the dest_is_ptr guard. if (dt->is_primitive(PrimitiveTypeID::f64)) { if (caps_->get(DeviceCapability::spirv_has_atomic_float64_add) && stmt->op_type == AtomicOpType::add) { - addr_ptr = at_buffer(stmt->dest, dt); + addr_ptr = dest_is_ptr ? dest_val : at_buffer(stmt->dest, dt); } else { addr_ptr = dest_is_ptr ?
dest_val @@ -1536,7 +1568,19 @@ void TaskCodegen::visit(AtomicOpStmt *stmt) { } else if (dt->is_primitive(PrimitiveTypeID::f32)) { if (caps_->get(DeviceCapability::spirv_has_atomic_float_add) && stmt->op_type == AtomicOpType::add) { - addr_ptr = at_buffer(stmt->dest, dt); + addr_ptr = dest_is_ptr ? dest_val : at_buffer(stmt->dest, dt); + } else { + addr_ptr = dest_is_ptr + ? dest_val + : at_buffer(stmt->dest, ir_->get_quadrants_uint_type(dt)); + } + } else if (dt->is_primitive(PrimitiveTypeID::f16)) { + // f16 needs the same uint-typed pointer as f32/f64 for the CAS path. + // Without this, at_buffer returns pointer-to-f16 but the CAS loop uses + // OpAtomicLoad(u16, ...) causing a SPIR-V type mismatch. + if (caps_->get(DeviceCapability::spirv_has_atomic_float16_add) && + stmt->op_type == AtomicOpType::add) { + addr_ptr = dest_is_ptr ? dest_val : at_buffer(stmt->dest, dt); } else { addr_ptr = dest_is_ptr ? dest_val @@ -1549,7 +1593,9 @@ void TaskCodegen::visit(AtomicOpStmt *stmt) { auto ret_type = ir_->get_primitive_type(dt); if (is_real(dt)) { - spv::Op atomic_fp_op; + // Only initialized for add (the only op with native float atomic support). + // Safe: use_native_atomics is only true when op_type == add. + spv::Op atomic_fp_op = spv::OpNop; if (stmt->op_type == AtomicOpType::add) { atomic_fp_op = spv::OpAtomicFAddEXT; } @@ -1572,11 +1618,26 @@ void TaskCodegen::visit(AtomicOpStmt *stmt) { use_native_atomics = true; } } + // The checks above use buffer capabilities. For shared pointers, override + // with shared capabilities (buffer and shared support are independent). + if (dest_is_ptr && stmt->op_type == AtomicOpType::add) { + use_native_atomics = has_native_float_atomic_add(*caps_, dt, true); + } + // Uint-retyped shared arrays have a uint pointer - native float atomics + // would produce invalid SPIR-V on them. + if (uint_backed_shared_float_ptr_stmts_.count(stmt->dest)) { + use_native_atomics = false; + } if (use_native_atomics) { val = ir_->make_value(atomic_fp_op, ir_->get_primitive_type(dt), addr_ptr, /*scope=*/ir_->const_i32_one_, /*semantics=*/ir_->const_i32_zero_, data); + } else if (dest_is_ptr) { + // Shared float arrays use uint-backed CAS (width-aware for f16->u32). + // Integer shared atomics don't need this - they use native OpAtomicIAdd + // etc. directly on the shared pointer. + val = shared_float_atomic(*ir_, stmt->op_type, addr_ptr, data, dt); } else { val = ir_->float_atomic(stmt->op_type, addr_ptr, data, dt); } @@ -1590,7 +1651,11 @@ void TaskCodegen::visit(AtomicOpStmt *stmt) { op = spv::OpAtomicISub; use_native_atomics = true; } else if (stmt->op_type == AtomicOpType::mul) { - addr_ptr = at_buffer(stmt->dest, ir_->get_quadrants_uint_type(dt)); + // dest_is_ptr guard needed here too - at_buffer would crash on shared + // integer arrays (same reason as the float branches above). + addr_ptr = dest_is_ptr + ? dest_val + : at_buffer(stmt->dest, ir_->get_quadrants_uint_type(dt)); val = ir_->integer_atomic(stmt->op_type, addr_ptr, data, dt); use_native_atomics = false; } else if (stmt->op_type == AtomicOpType::min) { @@ -2096,6 +2161,9 @@ void TaskCodegen::generate_struct_for_kernel(OffloadedStmt *stmt) { task_attribs_.buffer_binds = get_buffer_binds(); } +// Return the address in device memory for a global/storage-buffer access. +// Only works for device-buffer-backed pointers (via ptr_to_buffers_), not +// workgroup arrays - those already have a pointer from OpAccessChain. 
spirv::Value TaskCodegen::at_buffer(const Stmt *ptr, DataType dt) { spirv::Value ptr_val = ir_->query_value(ptr->raw_name()); diff --git a/quadrants/codegen/spirv/spirv_ir_builder.cpp b/quadrants/codegen/spirv/spirv_ir_builder.cpp index c1994b599d..978a5019d1 100644 --- a/quadrants/codegen/spirv/spirv_ir_builder.cpp +++ b/quadrants/codegen/spirv/spirv_ir_builder.cpp @@ -63,6 +63,13 @@ void IRBuilder::init_header() { if (caps_->get(cap::spirv_has_int64)) { ib_.begin(spv::OpCapability).add(spv::CapabilityInt64).commit(&header_); } + if (caps_->get(cap::spirv_has_atomic_int64)) { + // Required for OpAtomicLoad/OpAtomicCompareExchange on u64, used by + // the CAS-based f64 shared float atomic emulation path. + ib_.begin(spv::OpCapability) + .add(spv::CapabilityInt64Atomics) + .commit(&header_); + } if (caps_->get(cap::spirv_has_float16)) { ib_.begin(spv::OpCapability).add(spv::CapabilityFloat16).commit(&header_); } @@ -1075,6 +1082,8 @@ Value IRBuilder::float_atomic(AtomicOpType op_type, Value addr_ptr, Value data, const DataType &dt) { + // Use dt-derived type instead of t_fp32_ so FMin/FMax work for f16/f64. + auto float_type = get_primitive_type(dt); if (op_type == AtomicOpType::add) { return atomic_operation( addr_ptr, data, [&](Value lhs, Value rhs) { return add(lhs, rhs); }, @@ -1091,14 +1100,14 @@ Value IRBuilder::float_atomic(AtomicOpType op_type, return atomic_operation( addr_ptr, data, [&](Value lhs, Value rhs) { - return call_glsl450(t_fp32_, /*FMin*/ 37, lhs, rhs); + return call_glsl450(float_type, /*FMin*/ 37, lhs, rhs); }, dt); } else if (op_type == AtomicOpType::max) { return atomic_operation( addr_ptr, data, [&](Value lhs, Value rhs) { - return call_glsl450(t_fp32_, /*FMax*/ 40, lhs, rhs); + return call_glsl450(float_type, /*FMax*/ 40, lhs, rhs); }, dt); } else { @@ -1124,7 +1133,13 @@ Value IRBuilder::atomic_operation(Value addr_ptr, std::function<Value(Value, Value)> op, const DataType &dt) { SType out_type = get_primitive_type(dt); - SType res_type = get_primitive_uint_type(dt); + // Device-buffer pointers are uint-typed (from at_buffer), so CAS uses uint. + // Workgroup (shared) pointers keep their original type (e.g. i32). Using uint + // on a signed pointer causes Metal's atomic_compare_exchange to reject the + // shader due to signed/unsigned type mismatch. + const bool is_workgroup = + addr_ptr.stype.storage_class == spv::StorageClassWorkgroup; + SType res_type = is_workgroup ? out_type : get_primitive_uint_type(dt); Value ret_val_int = alloca_variable(res_type); // do-while @@ -1150,10 +1165,15 @@ Value IRBuilder::atomic_operation(Value addr_ptr, Value old_val = make_value(spv::OpAtomicLoad, res_type, addr_ptr, /*scope=*/const_i32_one_, /*semantics=*/const_i32_zero_); - // int new = dataTypeBitsToInt(atomic_op(intBitsToDataType(old), data)); - Value old_data_value = make_value(spv::OpBitcast, out_type, old_val); + // Bitcast uint<->float for the operation. Skip when types already match + // (integer workgroup path where res_type == out_type). + Value old_data_value = (out_type.id != res_type.id) + ? make_value(spv::OpBitcast, out_type, old_val) + : old_val; Value new_data_value = op(old_data_value, data); - Value new_val = make_value(spv::OpBitcast, res_type, new_data_value); + Value new_val = (out_type.id != res_type.id) + ?
make_value(spv::OpBitcast, res_type, new_data_value) + : new_data_value; // int loaded = atomicCompSwap(vals[0], old, new); /* * Don't need this part, theoretically @@ -1188,8 +1208,10 @@ Value IRBuilder::atomic_operation(Value addr_ptr, } start_label(exit); - return make_value(spv::OpBitcast, out_type, - load_variable(ret_val_int, res_type)); + Value ret_loaded = load_variable(ret_val_int, res_type); + return (out_type.id != res_type.id) + ? make_value(spv::OpBitcast, out_type, ret_loaded) + : ret_loaded; } Value IRBuilder::rand_u32(Value global_tmp_) { diff --git a/quadrants/codegen/spirv/spirv_shared_array_retyping.cpp b/quadrants/codegen/spirv/spirv_shared_array_retyping.cpp new file mode 100644 index 0000000000..1606620dcc --- /dev/null +++ b/quadrants/codegen/spirv/spirv_shared_array_retyping.cpp @@ -0,0 +1,289 @@ +// Note: this module operates only on Quadrants IR types (AllocaStmt, Block, +// etc.) and SPIR-V types - no LLVM types are involved. + +#include "quadrants/codegen/spirv/spirv_shared_array_retyping.h" + +#include "quadrants/ir/type_utils.h" + +namespace quadrants::lang { +namespace spirv { +namespace { + +// Follow MatrixPtrStmt::origin chains back to the source AllocaStmt. +// Assumes only MatrixPtrStmts are on the chain back to the AllocaStmt. +const AllocaStmt *trace_to_alloca(const Stmt *stmt) { + if (auto *alloca = stmt->cast<AllocaStmt>()) + return alloca; + if (auto *matrix_ptr = stmt->cast<MatrixPtrStmt>()) + return trace_to_alloca(matrix_ptr->origin); + return nullptr; +} + +DataType get_atomic_uint_dtype(IRBuilder &ir, const DataType &dt) { + DataType uint_dt = ir.get_quadrants_uint_type(dt); + if (uint_dt == PrimitiveType::u16 || uint_dt == PrimitiveType::u8) { + return PrimitiveType::u32; + } + return uint_dt; + } + +// CAS loop with width-aware uint<->float conversion. The atomic backing type +// (res_type) may be wider than the float type (e.g. u32 for f16), so +// OpUConvert narrows/widens around the bitcasts. +Value atomic_operation_widened(IRBuilder &ir, + Value addr_ptr, + Value data, + std::function<Value(Value, Value)> op, + const DataType &dt, + const DataType &atomic_uint_dt) { + SType float_type = ir.get_primitive_type(dt); + SType narrow_uint = ir.get_primitive_uint_type(dt); + SType res_type = ir.get_primitive_type(atomic_uint_dt); + Value ret_val_int = ir.alloca_variable(res_type); + + Label head = ir.new_label(); + Label body = ir.new_label(); + Label branch_true = ir.new_label(); + Label branch_false = ir.new_label(); + Label merge = ir.new_label(); + Label exit = ir.new_label(); + + ir.make_inst(spv::OpBranch, head); + ir.start_label(head); + ir.make_inst(spv::OpLoopMerge, branch_true, merge, 0); + ir.make_inst(spv::OpBranch, body); + ir.make_inst(spv::OpLabel, body); + { + // See IRBuilder::atomic_operation for why OpAtomicLoad is used here + // instead of OpLoad (prevents SPIRV-Cross inlining on Metal).
+ Value old_val = ir.make_value(spv::OpAtomicLoad, res_type, addr_ptr, + /*scope=*/ir.const_i32_one_, + /*semantics=*/ir.const_i32_zero_); + // uint -> float (narrowing if atomic type is wider) + Value old_narrow = old_val; + if (res_type.id != narrow_uint.id) { + old_narrow = ir.make_value(spv::OpUConvert, narrow_uint, old_val); + } + Value old_data_value = + ir.make_value(spv::OpBitcast, float_type, old_narrow); + Value new_data_value = op(old_data_value, data); + // float -> uint (widening if needed) + Value new_val = ir.make_value(spv::OpBitcast, narrow_uint, new_data_value); + if (res_type.id != narrow_uint.id) { + new_val = ir.make_value(spv::OpUConvert, res_type, new_val); + } + Value loaded = ir.make_value( + spv::OpAtomicCompareExchange, res_type, addr_ptr, + /*scope=*/ir.const_i32_one_, /*semantics if equal=*/ir.const_i32_zero_, + /*semantics if unequal=*/ir.const_i32_zero_, new_val, old_val); + Value ok = ir.make_value(spv::OpIEqual, ir.bool_type(), loaded, old_val); + ir.store_variable(ret_val_int, loaded); + ir.make_inst(spv::OpSelectionMerge, branch_false, 0); + ir.make_inst(spv::OpBranchConditional, ok, branch_true, branch_false); + { + ir.make_inst(spv::OpLabel, branch_true); + ir.make_inst(spv::OpBranch, exit); + } + { + ir.make_inst(spv::OpLabel, branch_false); + ir.make_inst(spv::OpBranch, merge); + } + ir.make_inst(spv::OpLabel, merge); + ir.make_inst(spv::OpBranch, head); + } + ir.start_label(exit); + + Value ret_loaded = ir.load_variable(ret_val_int, res_type); + if (res_type.id != narrow_uint.id) { + ret_loaded = ir.make_value(spv::OpUConvert, narrow_uint, ret_loaded); + } + return ir.make_value(spv::OpBitcast, float_type, ret_loaded); +} + +} // namespace + +void scan_shared_atomic_allocs(Block *ir_block, + std::unordered_map<const AllocaStmt *, bool> &out) { + for (auto &s : ir_block->statements) { + if (auto *atomic_stmt = s->cast<AtomicOpStmt>()) { + if (auto *alloca = trace_to_alloca(atomic_stmt->dest)) { + if (alloca->is_shared) { + // alloca->ret_type is a pointer to the stored type; + // ptr_removed() gives the stored type (e.g. array of 128 floats). + auto alloca_dtype = alloca->ret_type.ptr_removed(); + // Shared array is always modeled as a tensor type. + if (auto *tensor_type = alloca_dtype->cast<TensorType>()) { + auto scalar_dtype = tensor_type->get_element_type(); + if (auto *nested = scalar_dtype->cast<TensorType>()) { + scalar_dtype = nested->get_element_type(); + QD_ASSERT_INFO( + !scalar_dtype->cast<TensorType>(), + "Nested tensor types deeper than 2 levels not supported"); + } + if (is_real(scalar_dtype)) { + bool has_non_add = (atomic_stmt->op_type != AtomicOpType::add); + auto [it, inserted] = out.emplace(alloca, has_non_add); + if (!inserted) + it->second = it->second || has_non_add; + } + } + } + } + } + // Recurse into sub-blocks. + // StructForStmt and MeshForStmt are lowered before codegen. + QD_ASSERT(!s->cast<StructForStmt>()); + QD_ASSERT(!s->cast<MeshForStmt>()); + if (auto *if_stmt = s->cast<IfStmt>()) { + if (if_stmt->true_statements) + scan_shared_atomic_allocs(if_stmt->true_statements.get(), out); + if (if_stmt->false_statements) + scan_shared_atomic_allocs(if_stmt->false_statements.get(), out); + } else if (auto *range_for = s->cast<RangeForStmt>()) { + scan_shared_atomic_allocs(range_for->body.get(), out); + } else if (auto *while_stmt = s->cast<WhileStmt>()) { + scan_shared_atomic_allocs(while_stmt->body.get(), out); + } + } +} + +// Callers must NOT pre-initialize elem_num/elem_type - this function handles +// nested tensor flattening (e.g. array of vec3 -> flat array of f32) which must +// happen before get_primitive_type is called on the element dtype.
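+// For example, a shared f16 array hit by atomic_min (a non-add op) ends up +// u32-backed: the CAS path is required, and get_atomic_uint_dtype promotes the +// u16 backing to u32 because sub-32-bit shared atomics are not assumed to be +// available.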
+void maybe_retype_alloca( + IRBuilder &ir, + const DeviceCapabilityConfig &caps, + const AllocaStmt *alloca, + const TensorType *tensor_type, + const std::unordered_map<const AllocaStmt *, bool> &alloc_map, + std::unordered_set<const Stmt *> &retyped_stmts, + int &elem_num, + SType &elem_type) { + elem_num = tensor_type->get_num_elements(); + DataType scalar_dtype = tensor_type->get_element_type(); + // Flatten nested tensor types (e.g., array of vec3 to flat array of f32) + if (auto nested = scalar_dtype->cast<TensorType>()) { + elem_num *= nested->get_num_elements(); + scalar_dtype = nested->get_element_type(); + QD_ASSERT_INFO(!scalar_dtype->cast<TensorType>(), + "Nested tensor types deeper than 2 levels not supported"); + } + elem_type = ir.get_primitive_type(scalar_dtype); + // Retype to uint if this alloca is targeted by float atomics and the device + // lacks native shared float atomic support for all ops used. + auto it = alloc_map.find(alloca); + if (it != alloc_map.end()) { + bool needs_cas = it->second; + if (needs_cas || !has_native_float_atomic_add(caps, scalar_dtype, true)) { + elem_type = + ir.get_primitive_type(get_atomic_uint_dtype(ir, scalar_dtype)); + retyped_stmts.insert(alloca); + } + } +} + +void maybe_retype_derived_ptr(IRBuilder &ir, + const Stmt *origin, + const Stmt *stmt, + DataType &dt, + std::unordered_set<const Stmt *> &retyped_stmts) { + // Flatten nested tensor types to scalar (e.g., vec3 to f32). + // This must happen for ALL shared array pointers, not just retyped ones. + // Note: this only changes the SPIR-V element type for the access chain, not + // the index. The frontend IR (make_tensor_access_single_element) already + // emits flat scalar indices (i*vec_size+component), so there is no stride + // mismatch despite the storage being flattened in maybe_retype_alloca. + if (auto nested = dt->cast<TensorType>()) { + dt = nested->get_element_type(); + } + if (retyped_stmts.count(origin)) { + dt = get_atomic_uint_dtype(ir, dt); + retyped_stmts.insert(stmt); + } +} + +Value load_uint_backed_shared_float(IRBuilder &ir, + Value ptr_val, + const DataType &element_type) { + auto shared_type = + ir.get_primitive_type(get_atomic_uint_dtype(ir, element_type)); + Value val = ir.load_variable(ptr_val, shared_type); + SType narrow_uint = ir.get_primitive_uint_type(element_type); + if (shared_type.id != narrow_uint.id) { + val = ir.make_value(spv::OpUConvert, narrow_uint, val); + } + return ir.make_value(spv::OpBitcast, ir.get_primitive_type(element_type), + val); +} + +Value float_to_shared_uint(IRBuilder &ir, Value val, const DataType &dt) { + SType narrow_uint = ir.get_primitive_uint_type(dt); + val = ir.make_value(spv::OpBitcast, narrow_uint, val); + SType atomic_uint = ir.get_primitive_type(get_atomic_uint_dtype(ir, dt)); + if (atomic_uint.id != narrow_uint.id) { + val = ir.make_value(spv::OpUConvert, atomic_uint, val); + } + return val; +} + +Value shared_float_atomic(IRBuilder &ir, + AtomicOpType op_type, + Value addr_ptr, + Value data, + const DataType &dt) { + auto atomic_uint_dt = get_atomic_uint_dtype(ir, dt); + auto float_type = ir.get_primitive_type(dt); + if (op_type == AtomicOpType::add) { + return atomic_operation_widened( + ir, addr_ptr, data, + [&](Value lhs, Value rhs) { return ir.add(lhs, rhs); }, dt, + atomic_uint_dt); + } else if (op_type == AtomicOpType::sub) { + return atomic_operation_widened( + ir, addr_ptr, data, + [&](Value lhs, Value rhs) { return ir.sub(lhs, rhs); }, dt, + atomic_uint_dt); + } else if (op_type == AtomicOpType::mul) { + return atomic_operation_widened( + ir, addr_ptr, data, + [&](Value lhs, Value rhs) { return
ir.mul(lhs, rhs); }, dt, + atomic_uint_dt); + } else if (op_type == AtomicOpType::min) { + return atomic_operation_widened( + ir, addr_ptr, data, + [&](Value lhs, Value rhs) { + return ir.call_glsl450(float_type, /*FMin*/ 37, lhs, rhs); + }, + dt, atomic_uint_dt); + } else if (op_type == AtomicOpType::max) { + return atomic_operation_widened( + ir, addr_ptr, data, + [&](Value lhs, Value rhs) { + return ir.call_glsl450(float_type, /*FMax*/ 40, lhs, rhs); + }, + dt, atomic_uint_dt); + } else { + QD_NOT_IMPLEMENTED + } +} + +bool has_native_float_atomic_add(const DeviceCapabilityConfig &caps, + const DataType &dt, + bool is_shared) { + if (dt->is_primitive(PrimitiveTypeID::f32)) + return caps.get(is_shared + ? DeviceCapability::spirv_has_shared_atomic_float_add + : DeviceCapability::spirv_has_atomic_float_add); + if (dt->is_primitive(PrimitiveTypeID::f64)) + return caps.get(is_shared + ? DeviceCapability::spirv_has_shared_atomic_float64_add + : DeviceCapability::spirv_has_atomic_float64_add); + if (dt->is_primitive(PrimitiveTypeID::f16)) + return caps.get(is_shared + ? DeviceCapability::spirv_has_shared_atomic_float16_add + : DeviceCapability::spirv_has_atomic_float16_add); + return false; +} + +} // namespace spirv +} // namespace quadrants::lang diff --git a/quadrants/codegen/spirv/spirv_shared_array_retyping.h b/quadrants/codegen/spirv/spirv_shared_array_retyping.h new file mode 100644 index 0000000000..3009d1a14c --- /dev/null +++ b/quadrants/codegen/spirv/spirv_shared_array_retyping.h @@ -0,0 +1,73 @@ +#pragma once + +#include <unordered_map> +#include <unordered_set> + +#include "quadrants/ir/statements.h" +#include "quadrants/codegen/spirv/spirv_ir_builder.h" +#include "quadrants/rhi/device.h" + +namespace quadrants::lang { +namespace spirv { + +// Pre-scan the IR block tree to find shared float AllocaStmts targeted by +// atomic operations. These arrays may need uint-backing so that CAS-based +// atomic emulation can use integer atomics (Metal/MoltenVK lack threadgroup +// float atomics). +// +// out[alloca] = true -> has non-add atomic ops, CAS needed unconditionally +// out[alloca] = false -> only add ops, native shared float atomics can be used +// if the device supports them +void scan_shared_atomic_allocs(Block *ir_block, + std::unordered_map<const AllocaStmt *, bool> &out); + +// Initialize elem_num and elem_type from tensor_type, flattening nested tensor +// types (e.g. vec3 -> 3xf32) and retyping to uint for CAS-based atomics when +// the alloca is targeted by float atomic operations. +void maybe_retype_alloca( + IRBuilder &ir, + const DeviceCapabilityConfig &caps, + const AllocaStmt *alloca, + const TensorType *tensor_type, + const std::unordered_map<const AllocaStmt *, bool> &alloc_map, + std::unordered_set<const Stmt *> &retyped_stmts, + int &elem_num, + SType &elem_type); + +// If origin is in retyped_stmts, propagate retyping to stmt and change dt +// to the uint-backed DataType (flattening nested tensor types first). +// Otherwise just flatten dt if it is a nested tensor type. +void maybe_retype_derived_ptr(IRBuilder &ir, + const Stmt *origin, + const Stmt *stmt, + DataType &dt, + std::unordered_set<const Stmt *> &retyped_stmts); + +// Load from a uint-backed shared float pointer: loads as uint, bitcasts to +// float. Only call when ptr is known to be in retyped_stmts. +Value load_uint_backed_shared_float(IRBuilder &ir, + Value ptr_val, + const DataType &element_type); + +// Convert a float value to uint for storing into a uint-backed shared array. +// Only call when dest is known to be in retyped_stmts.
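+// The value is bitcast to the matching-width uint (e.g. f16 -> u16) and then +// zero-extended with OpUConvert to the wider backing word (u32) when needed.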
+Value float_to_shared_uint(IRBuilder &ir, Value val, const DataType &dt); + +// CAS-based float atomic for shared (workgroup) arrays. Unlike +// IRBuilder::float_atomic, this handles width-mismatched uint backing +// (e.g. u32 backing for f16 arrays, since Metal/Vulkan lack 16-bit atomics). +Value shared_float_atomic(IRBuilder &ir, + AtomicOpType op_type, + Value addr_ptr, + Value data, + const DataType &dt); + +// Check whether the device has native float atomic add for dt. +// When is_shared=true, checks shared/workgroup capabilities; +// when is_shared=false, checks buffer capabilities. +bool has_native_float_atomic_add(const DeviceCapabilityConfig &caps, + const DataType &dt, + bool is_shared); + +} // namespace spirv +} // namespace quadrants::lang diff --git a/quadrants/inc/rhi_constants.inc.h b/quadrants/inc/rhi_constants.inc.h index 8617ecafe0..725171d5db 100644 --- a/quadrants/inc/rhi_constants.inc.h +++ b/quadrants/inc/rhi_constants.inc.h @@ -24,6 +24,9 @@ PER_DEVICE_CAPABILITY(spirv_has_atomic_float_minmax) PER_DEVICE_CAPABILITY(spirv_has_atomic_float64) // load, store, exchange PER_DEVICE_CAPABILITY(spirv_has_atomic_float64_add) PER_DEVICE_CAPABILITY(spirv_has_atomic_float64_minmax) +PER_DEVICE_CAPABILITY(spirv_has_shared_atomic_float_add) +PER_DEVICE_CAPABILITY(spirv_has_shared_atomic_float64_add) +PER_DEVICE_CAPABILITY(spirv_has_shared_atomic_float16_add) PER_DEVICE_CAPABILITY(spirv_has_variable_ptr) PER_DEVICE_CAPABILITY(spirv_has_physical_storage_buffer) PER_DEVICE_CAPABILITY(spirv_has_subgroup_basic) diff --git a/quadrants/rhi/metal/metal_device.mm b/quadrants/rhi/metal/metal_device.mm index e5713a7d34..9c7f07d7b0 100644 --- a/quadrants/rhi/metal/metal_device.mm +++ b/quadrants/rhi/metal/metal_device.mm @@ -1071,6 +1071,10 @@ DeviceCapabilityConfig collect_metal_device_caps(MTLDevice_id mtl_device) { caps.set(DeviceCapability::spirv_has_int64, 1); caps.set(DeviceCapability::spirv_has_physical_storage_buffer, 1); } + // Metal supports 64-bit integer atomics on Apple7+ and Mac2+. + if (feature_floating_point_atomics) { + caps.set(DeviceCapability::spirv_has_atomic_int64, 1); + } if (feature_floating_point_atomics) { // FIXME: (penguinliong) For some reason floating point atomics doesn't // work and breaks the FEM99/FEM128 examples.
Should consider add them back diff --git a/quadrants/rhi/vulkan/vulkan_device_creator.cpp b/quadrants/rhi/vulkan/vulkan_device_creator.cpp index 7a51c5c6a6..fa6d04a641 100644 --- a/quadrants/rhi/vulkan/vulkan_device_creator.cpp +++ b/quadrants/rhi/vulkan/vulkan_device_creator.cpp @@ -684,6 +684,9 @@ void VulkanDeviceCreator::create_logical_device(bool manual_create) { VkPhysicalDeviceVariablePointersFeaturesKHR variable_ptr_feature{}; variable_ptr_feature.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VARIABLE_POINTERS_FEATURES_KHR; + VkPhysicalDeviceShaderAtomicInt64Features shader_atomic_int64_feature{}; + shader_atomic_int64_feature.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_INT64_FEATURES; VkPhysicalDeviceShaderAtomicFloatFeaturesEXT shader_atomic_float_feature{}; shader_atomic_float_feature.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_FEATURES_EXT; @@ -736,6 +739,19 @@ void VulkanDeviceCreator::create_logical_device(bool manual_create) { pNextEnd = &variable_ptr_feature.pNext; } + // Atomic int64 (promoted to Vulkan 1.2 core) + if (CHECK_VERSION(1, 2) || + CHECK_EXTENSION(VK_KHR_SHADER_ATOMIC_INT64_EXTENSION_NAME)) { + features2.pNext = &shader_atomic_int64_feature; + vkGetPhysicalDeviceFeatures2KHR(physical_device_, &features2); + if (shader_atomic_int64_feature.shaderBufferInt64Atomics || + shader_atomic_int64_feature.shaderSharedInt64Atomics) { + caps.set(DeviceCapability::spirv_has_atomic_int64, true); + } + *pNextEnd = &shader_atomic_int64_feature; + pNextEnd = &shader_atomic_int64_feature.pNext; + } + // Atomic float if (CHECK_EXTENSION(VK_EXT_SHADER_ATOMIC_FLOAT_EXTENSION_NAME)) { features2.pNext = &shader_atomic_float_feature; @@ -752,6 +768,12 @@ void VulkanDeviceCreator::create_logical_device(bool manual_create) { if (shader_atomic_float_feature.shaderBufferFloat64Atomics) { caps.set(DeviceCapability::spirv_has_atomic_float64, true); } + if (shader_atomic_float_feature.shaderSharedFloat32AtomicAdd) { + caps.set(DeviceCapability::spirv_has_shared_atomic_float_add, true); + } + if (shader_atomic_float_feature.shaderSharedFloat64AtomicAdd) { + caps.set(DeviceCapability::spirv_has_shared_atomic_float64_add, true); + } *pNextEnd = &shader_atomic_float_feature; pNextEnd = &shader_atomic_float_feature.pNext; } @@ -769,6 +791,9 @@ void VulkanDeviceCreator::create_logical_device(bool manual_create) { if (shader_atomic_float_2_feature.shaderBufferFloat16Atomics) { caps.set(DeviceCapability::spirv_has_atomic_float16, true); } + if (shader_atomic_float_2_feature.shaderSharedFloat16AtomicAdd) { + caps.set(DeviceCapability::spirv_has_shared_atomic_float16_add, true); + } if (shader_atomic_float_2_feature.shaderBufferFloat32AtomicMinMax) { caps.set(DeviceCapability::spirv_has_atomic_float_minmax, true); } diff --git a/tests/python/test_atomic.py b/tests/python/test_atomic.py index c4b1fa00ff..fa3cb22f1a 100644 --- a/tests/python/test_atomic.py +++ b/tests/python/test_atomic.py @@ -426,6 +426,47 @@ def max_kernel() -> qd.f32: assert max_kernel() == -1.0 +@pytest.mark.parametrize("op", ["add", "sub", "min", "max"]) +@pytest.mark.parametrize("dtype", [qd.f16, qd.f32, qd.f64]) +@test_utils.test() +def test_atomic_float_ops(op, dtype): + if qd.cfg.arch in (qd.vulkan, qd.metal): + caps = qd.lang.impl.get_runtime().prog.get_device_caps() + # f16 CAS requires 16-bit integer atomics, unsupported on MoltenVK/Metal + if dtype == qd.f16 and not caps.get(qd._lib.core.DeviceCapability.spirv_has_atomic_float16): + pytest.skip("Device does not support f16 atomics") + if dtype == 
qd.f64 and not caps.get(qd._lib.core.DeviceCapability.spirv_has_float64): + pytest.skip("Device does not support f64") + block_dim = 32 + N = block_dim * 4 + SCALE = 0.1523 + atomic_op = getattr(qd, f"atomic_{op}") + + @qd.kernel + def kern(out: qd.types.ndarray()): + # Use multiple threads to test concurrent atomicity + qd.loop_config(block_dim=block_dim) + for i in range(N): + tid = i % block_dim + val = qd.cast(tid * SCALE, dtype) + atomic_op(out[0], val) + + arr = qd.ndarray(dtype, (1,)) + arr[0] = 0.0 + kern(arr) + # 4 blocks each contributing SCALE * (0 + 1 + ... + 31) + nblocks = N // block_dim + per_block_sum = SCALE * block_dim * (block_dim - 1) / 2.0 + expected = { + "add": per_block_sum * nblocks, + "sub": -per_block_sum * nblocks, + "min": 0.0, + "max": (block_dim - 1) * SCALE, + } + rtol = {qd.f16: 1e-3, qd.f64: 1e-10}.get(dtype, 1e-6) + assert arr[0] == test_utils.approx(expected[op], rel=rtol) + + @test_utils.test() def test_atomic_mul_f32(): @qd.kernel diff --git a/tests/python/test_shared_array.py b/tests/python/test_shared_array.py index 761aa075ab..6999075d6e 100644 --- a/tests/python/test_shared_array.py +++ b/tests/python/test_shared_array.py @@ -22,7 +22,7 @@ (1, 1, qd.i8, qd.f32), # different shape, different dtype ], ) -@test_utils.test(arch=[qd.cuda, qd.metal]) +@test_utils.test(arch=qd.gpu) def test_shared_array_not_accumulated_across_offloads(num_dim, first_shape_delta_size, dtype1, dtype2): # Execute 2 successive offloaded tasks both allocating more than half of # the maximum shared memory available on the device to make sure shared @@ -214,7 +214,7 @@ def scaled_reduce_shared( assert np.allclose(reference.to_numpy(), a_arr.to_numpy()) -@test_utils.test(arch=[qd.cuda, qd.vulkan, qd.amdgpu]) +@test_utils.test(arch=qd.gpu) def test_multiple_shared_array(): assert qd.cfg is not None if qd.cfg.arch == qd.amdgpu: @@ -222,8 +222,11 @@ def test_multiple_shared_array(): block_dim = 128 nBlocks = 64 N = nBlocks * block_dim * 4 - v_arr = np.random.randn(N).astype(np.float32) - d_arr = np.random.randn(N).astype(np.float32) + # Seed the RNG to avoid flaky failures from FP accumulation-order + # differences between the reference and tiled shared-memory kernels. + rng = np.random.RandomState(42) + v_arr = rng.randn(N).astype(np.float32) + d_arr = rng.randn(N).astype(np.float32) a_arr = np.zeros(N).astype(np.float32) reference = np.zeros(N).astype(np.float32) @@ -274,7 +277,7 @@ def calc_shared_array( assert np.allclose(reference, a_arr, rtol=1e-4) -@test_utils.test(arch=[qd.cuda, qd.vulkan, qd.amdgpu]) +@test_utils.test(arch=qd.gpu) def test_shared_array_atomics(): N = 256 block_dim = 32 @@ -302,7 +305,133 @@ def atomic_test(out: qd.types.ndarray()): assert arr[224] == sum -@test_utils.test(arch=[qd.cuda]) +@test_utils.test(arch=qd.gpu) +def test_shared_array_int_atomic_mul(): + # Regression test for dest_is_ptr guard on integer atomic_mul. + # Without the guard, at_buffer() crashes on shared pointers. 
+ N = 64 + block_dim = 4 + + @qd.kernel + def kern(out: qd.types.ndarray()): + qd.loop_config(block_dim=block_dim) + for i in range(N): + tid = i % block_dim + sharr = qd.simt.block.SharedArray((block_dim,), qd.i32) + sharr[tid] = tid + 1 + qd.simt.block.sync() + qd.atomic_mul(sharr[0], tid + 1)  # not sharr[tid]: tid 0 would read the accumulator slot mid-update + qd.simt.block.sync() + out[i] = sharr[0] + + arr = qd.ndarray(qd.i32, (N,)) + kern(arr) + # sharr[0] starts as 1, then *= 1, *= 2, *= 3, *= 4 -> 24 + for idx in (0, 4, 8, 60): + assert arr[idx] == 24 + + +@pytest.mark.parametrize("op", ["add", "sub", "min", "max"]) +@pytest.mark.parametrize("dtype", [qd.f16, qd.f32, qd.f64]) +@test_utils.test(arch=qd.gpu) +def test_shared_array_float_atomics(op, dtype): + if dtype == qd.f64: + if qd.cfg.arch in (qd.vulkan, qd.metal): + caps = qd.lang.impl.get_runtime().prog.get_device_caps() + if not caps.get(qd._lib.core.DeviceCapability.spirv_has_float64): + pytest.skip("Device does not support f64") + N = 256 + block_dim = 32 + SCALE = 0.1523 # fractional so values are truly non-integer floats + # Arithmetic sum: SCALE * (0 + 1 + ... + block_dim-1) + expected_sum = SCALE * block_dim * (block_dim - 1) / 2.0 + atomic_op = getattr(qd, f"atomic_{op}") + + def make_kernel(atomic_fn): + @qd.kernel + def kern(out: qd.types.ndarray()): + qd.loop_config(block_dim=block_dim) + for i in range(N): + tid = i % block_dim + sharr = qd.simt.block.SharedArray((block_dim,), dtype) + val = qd.cast(tid * SCALE, dtype) + sharr[tid] = val + qd.simt.block.sync() + atomic_fn(sharr[0], val) + qd.simt.block.sync() # wait for all threads' atomics to complete + out[i] = sharr[0] + + return kern + + expected = { + "add": expected_sum, + "sub": -expected_sum, + "min": 0.0, + "max": (block_dim - 1) * SCALE, + } + rtol = 1e-3 if dtype == qd.f16 else 1e-6 + arr = qd.ndarray(qd.f32, (N,)) + make_kernel(atomic_op)(arr) + for idx in (0, 31, 32, 255): + assert arr[idx] == test_utils.approx(expected[op], rel=rtol) + + +@pytest.mark.parametrize( + "dtype", + [qd.i8, qd.i16, qd.i32, qd.u8, qd.u16, qd.u32, qd.f16, qd.f32, qd.f64, qd.u1], +) +@test_utils.test(arch=qd.gpu) +def test_shared_array_dtypes(dtype): + if dtype == qd.f64: + if qd.cfg.arch in (qd.vulkan, qd.metal): + caps = qd.lang.impl.get_runtime().prog.get_device_caps() + if not caps.get(qd._lib.core.DeviceCapability.spirv_has_float64): + pytest.skip("Device does not support f64") + N = 128 + block_dim = 32 + SCALE = 0.1523 + # Use f32 as output type for sub-32-bit types that ndarrays can't represent + out_dtype = qd.f32 if dtype in (qd.f16, qd.u1) else dtype + + @qd.kernel + def kern(inp: qd.types.ndarray(), out: qd.types.ndarray()): + qd.loop_config(block_dim=block_dim) + for i in range(N): + tid = i % block_dim + sharr = qd.simt.block.SharedArray((block_dim,), dtype) + # Store from input tensor (not just tid) to prevent shortcutting + sharr[tid] = qd.cast(inp[i], dtype) + qd.simt.block.sync() + # Read from a different thread's slot to force shared memory use + out[i] = sharr[(tid + 1) % block_dim] + + if dtype in (qd.f16, qd.f32, qd.f64): + inp_vals = np.array([i % block_dim * SCALE for i in range(N)], dtype=np.float32) + inp_dtype = out_dtype + else: + inp_vals = np.array([i % block_dim for i in range(N)], dtype=np.int32) + # Use i32 for bool input - SPIR-V doesn't support f32 -> u1 cast + inp_dtype = qd.i32 if dtype == qd.u1 else out_dtype + inp = qd.ndarray(inp_dtype, (N,)) + inp.from_numpy(inp_vals) + arr = qd.ndarray(out_dtype, (N,)) + kern(inp, arr) + + rtol = {qd.f16: 1e-3, qd.f64: 1e-10}.get(dtype, 1e-6) + for block_start in
(0, 32, 64, 96): + for tid in range(block_dim): + neighbor = (tid + 1) % block_dim + if dtype in (qd.f16, qd.f32, qd.f64): + expected = neighbor * SCALE + assert arr[block_start + tid] == test_utils.approx(expected, rel=rtol) + elif dtype == qd.u1: + # qd.cast to u1 maps nonzero -> 1, zero -> 0 + assert arr[block_start + tid] == (0 if neighbor == 0 else 1) + else: + assert arr[block_start + tid] == neighbor + + +@test_utils.test(arch=qd.gpu) def test_shared_array_tensor_type(): data_type = vec4 block_dim = 16 @@ -324,10 +453,14 @@ def test(): y[tid] += shared_mem[tid] test() - assert (y.to_numpy()[0] == [4.0, 8.0, 12.0, 16.0]).all() + # Check all tids, not just tid=0. The shared array is flattened from + # vec4[16] to f32[64] in SPIR-V, so a stride bug (e.g. accessing element + # tid+c instead of tid*4+c) would produce wrong values for tid>0 but + # correct values for tid=0 since 0*anything==0. + assert (y.to_numpy() == [[4.0, 8.0, 12.0, 16.0]] * block_dim).all() -@test_utils.test(arch=[qd.cuda], debug=True) +@test_utils.test(arch=qd.gpu, debug=True) def test_shared_array_matrix(): @qd.kernel def foo():