Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions quadrants/codegen/spirv/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ target_sources(spirv_codegen
kernel_utils.cpp
snode_struct_compiler.cpp
spirv_codegen.cpp
spirv_shared_array_retyping.cpp
spirv_ir_builder.cpp
spirv_types.cpp
compiled_kernel_data.cpp
Expand Down
10 changes: 10 additions & 0 deletions quadrants/codegen/spirv/detail/spirv_codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,16 @@ class TaskCodegen : public IRVisitor {
std::unordered_map<int, GetRootStmt *>
root_stmts_; // maps root id to get root stmt
std::unordered_map<const Stmt *, BufferInfo> ptr_to_buffers_;
// Shared float AllocaStmts targeted by atomics, populated by
// scan_shared_atomic_allocs() before codegen. Value = true means the alloca
// has non-add ops (CAS unconditionally needed); false = add-only (native
// shared float atomics can be used if the device supports them).
std::unordered_map<const Stmt *, bool> shared_float_allocas_with_atomic_rmw_;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// Propagated from shared_float_allocas_with_atomic_rmw_ to derived
// MatrixPtrStmt nodes during codegen, so that load/store/atomic visitors
// know to bitcast. E.g. if `sharr` (AllocaStmt) is retyped, then
// `sharr[0]` (MatrixPtrStmt) is added here during visit(MatrixPtrStmt).
Comment thread
hughperkins marked this conversation as resolved.
std::unordered_set<const Stmt *> uint_backed_shared_float_ptr_stmts_;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

std::unordered_map<std::vector<int>, Value, hashing::Hasher<std::vector<int>>>
argid_to_tex_value_;

Expand Down
98 changes: 83 additions & 15 deletions quadrants/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "quadrants/codegen/spirv/kernel_utils.h"
#include "quadrants/codegen/spirv/spirv_ir_builder.h"
#include "quadrants/codegen/spirv/detail/spirv_codegen.h"
#include "quadrants/codegen/spirv/spirv_shared_array_retyping.h"
#include "quadrants/ir/transforms.h"
#include "quadrants/math/arithmetic.h"
#include "quadrants/codegen/ir_dump.h"
Expand Down Expand Up @@ -119,6 +120,9 @@ TaskCodegen::Result TaskCodegen::run() {
kernel_function_ = ir_->new_function(); // void main();
ir_->debug_name(spv::OpName, kernel_function_, "main");

scan_shared_atomic_allocs(task_ir_->body.get(),
shared_float_allocas_with_atomic_rmw_);

if (task_ir_->task_type == OffloadedTaskType::serial) {
generate_serial_kernel(task_ir_);
} else if (task_ir_->task_type == OffloadedTaskType::range_for) {
Expand Down Expand Up @@ -270,11 +274,21 @@ void TaskCodegen::visit(ConstStmt *const_stmt) {

void TaskCodegen::visit(AllocaStmt *alloca) {
spirv::Value ptr_val;
// alloca->ret_type is a pointer to the stored type; ptr_removed() gives the
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// stored type itself (e.g. TensorType<32 x f32> for a 32-element array).
auto alloca_type = alloca->ret_type.ptr_removed();
// Shared array is always modeled as a tensor type, i.e. an array of scalars.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (auto tensor_type = alloca_type->cast<TensorType>()) {
auto elem_num = tensor_type->get_num_elements();
spirv::SType elem_type =
ir_->get_primitive_type(tensor_type->get_element_type());
// Do NOT initialize elem_num/elem_type here - the helper flattens nested
// tensor types (e.g. vec3 -> 3xf32) before computing them. Pre-initializing
// with get_primitive_type(tensor_type->get_element_type()) would crash on
// nested tensor types like Tensor(3) f32.
int elem_num;
spirv::SType elem_type;
maybe_retype_alloca(*ir_, *caps_, alloca, tensor_type,
shared_float_allocas_with_atomic_rmw_,
uint_backed_shared_float_ptr_stmts_, elem_num,
elem_type);
spirv::SType arr_type = ir_->get_array_type(elem_type, elem_num);
if (alloca->is_shared) { // for shared memory / workgroup memory
ptr_val = ir_->alloca_workgroup_array(arr_type);
Expand All @@ -297,12 +311,18 @@ void TaskCodegen::visit(MatrixPtrStmt *stmt) {
spirv::Value offset_val = ir_->query_value(stmt->offset->raw_name());
auto dt = stmt->element_type().ptr_removed();
if (stmt->offset_used_as_index()) {
if (stmt->origin->is<AllocaStmt>()) {
// Origin is a local/shared array allocation or a derived pointer from one
// - use OpAccessChain or OpPtrAccessChain respectively.
if (stmt->origin->is<AllocaStmt>() ||
origin_val.stype.flag == TypeKind::kPtr) {
maybe_retype_derived_ptr(*ir_, stmt->origin, stmt, dt,
uint_backed_shared_float_ptr_stmts_);
spirv::SType ptr_type = ir_->get_pointer_type(
ir_->get_primitive_type(dt), origin_val.stype.storage_class);
ptr_val =
ir_->make_value(spv::OpAccessChain, ptr_type, origin_val, offset_val);
if (stmt->origin->as<AllocaStmt>()->is_shared) {
auto op = stmt->origin->is<AllocaStmt>() ? spv::OpAccessChain
: spv::OpPtrAccessChain;
ptr_val = ir_->make_value(op, ptr_type, origin_val, offset_val);
if (auto *a = stmt->origin->cast<AllocaStmt>(); a && a->is_shared) {
ptr_to_buffers_[stmt] = ptr_to_buffers_[stmt->origin];
}
} else if (stmt->origin->is<GlobalTemporaryStmt>()) {
Expand All @@ -324,14 +344,22 @@ void TaskCodegen::visit(MatrixPtrStmt *stmt) {
void TaskCodegen::visit(LocalLoadStmt *stmt) {
auto ptr = stmt->src;
spirv::Value ptr_val = ir_->query_value(ptr->raw_name());
spirv::Value val = ir_->load_variable(
ptr_val, ir_->get_primitive_type(stmt->element_type()));
spirv::Value val;
if (uint_backed_shared_float_ptr_stmts_.count(ptr)) {
val = load_uint_backed_shared_float(*ir_, ptr_val, stmt->element_type());
} else {
val = ir_->load_variable(ptr_val,
ir_->get_primitive_type(stmt->element_type()));
}
ir_->register_value(stmt->raw_name(), val);
}

void TaskCodegen::visit(LocalStoreStmt *stmt) {
spirv::Value ptr_val = ir_->query_value(stmt->dest->raw_name());
spirv::Value val = ir_->query_value(stmt->val->raw_name());
if (uint_backed_shared_float_ptr_stmts_.count(stmt->dest)) {
val = float_to_shared_uint(*ir_, val, stmt->val->element_type());
}
ir_->store_variable(ptr_val, val);
}

Expand Down Expand Up @@ -1521,13 +1549,17 @@ void TaskCodegen::visit(AtomicOpStmt *stmt) {

spirv::Value addr_ptr;
spirv::Value dest_val = ir_->query_value(stmt->dest->raw_name());
// Shared arrays have already created an accesschain, use it directly.
// Shared arrays already have a pointer from OpAccessChain (dest_is_ptr=true).
// at_buffer() looks up ptr_to_buffers_ to find the StorageBuffer and compute
// a byte offset - shared/workgroup arrays aren't in ptr_to_buffers_, so
// at_buffer() would fail on them.
const bool dest_is_ptr = dest_val.stype.flag == TypeKind::kPtr;

// The native-add branches originally called at_buffer() directly, but shared
// arrays can now reach this path, so all branches need the dest_is_ptr guard.
if (dt->is_primitive(PrimitiveTypeID::f64)) {
if (caps_->get(DeviceCapability::spirv_has_atomic_float64_add) &&
stmt->op_type == AtomicOpType::add) {
addr_ptr = at_buffer(stmt->dest, dt);
addr_ptr = dest_is_ptr ? dest_val : at_buffer(stmt->dest, dt);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does the uint-backed shared array mean we have to change this line?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added comment.

} else {
addr_ptr = dest_is_ptr
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is just refactoring, right? Seems like a nice refactoring, if I've understood correctly.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we refactoring, by the way? It's just extra code for me to read. Is this refactoring necessary?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated.

? dest_val
Expand All @@ -1536,7 +1568,19 @@ void TaskCodegen::visit(AtomicOpStmt *stmt) {
} else if (dt->is_primitive(PrimitiveTypeID::f32)) {
if (caps_->get(DeviceCapability::spirv_has_atomic_float_add) &&
stmt->op_type == AtomicOpType::add) {
addr_ptr = at_buffer(stmt->dest, dt);
addr_ptr = dest_is_ptr ? dest_val : at_buffer(stmt->dest, dt);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does the uint-backed shared array mean we have to change this line?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added comment.

} else {
addr_ptr = dest_is_ptr
? dest_val
: at_buffer(stmt->dest, ir_->get_quadrants_uint_type(dt));
}
} else if (dt->is_primitive(PrimitiveTypeID::f16)) {
// f16 needs the same uint-typed pointer as f32/f64 for the CAS path.
// Without this, at_buffer returns pointer-to-f16 but the CAS loop uses
// OpAtomicLoad(u16, ...) causing a SPIR-V type mismatch.
if (caps_->get(DeviceCapability::spirv_has_atomic_float16_add) &&
stmt->op_type == AtomicOpType::add) {
addr_ptr = dest_is_ptr ? dest_val : at_buffer(stmt->dest, dt);
} else {
addr_ptr = dest_is_ptr
? dest_val
Expand All @@ -1549,7 +1593,9 @@ void TaskCodegen::visit(AtomicOpStmt *stmt) {
auto ret_type = ir_->get_primitive_type(dt);

if (is_real(dt)) {
spv::Op atomic_fp_op;
// Only initialized for add (the only op with native float atomic support).
// Safe: use_native_atomics is only true when op_type == add.
spv::Op atomic_fp_op = spv::OpNop;
if (stmt->op_type == AtomicOpType::add) {
atomic_fp_op = spv::OpAtomicFAddEXT;
}
Expand All @@ -1572,11 +1618,26 @@ void TaskCodegen::visit(AtomicOpStmt *stmt) {
use_native_atomics = true;
}
}
// The checks above use buffer capabilities. For shared pointers, override
// with shared capabilities (buffer and shared support are independent).
if (dest_is_ptr && stmt->op_type == AtomicOpType::add) {
use_native_atomics = has_native_float_atomic_add(*caps_, dt, true);
}
// Uint-retyped shared arrays have a uint pointer - native float atomics
// would produce invalid SPIR-V on them.
if (uint_backed_shared_float_ptr_stmts_.count(stmt->dest)) {
use_native_atomics = false;
}

if (use_native_atomics) {
val = ir_->make_value(atomic_fp_op, ir_->get_primitive_type(dt), addr_ptr,
/*scope=*/ir_->const_i32_one_,
/*semantics=*/ir_->const_i32_zero_, data);
} else if (dest_is_ptr) {
// Shared float arrays use uint-backed CAS (width-aware for f16->u32).
// Integer shared atomics don't need this - they use native OpAtomicIAdd
// etc. directly on the shared pointer.
val = shared_float_atomic(*ir_, stmt->op_type, addr_ptr, data, dt);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is dest_is_ptr not true for shared int atomics?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added comment.

} else {
val = ir_->float_atomic(stmt->op_type, addr_ptr, data, dt);
}
Expand All @@ -1590,7 +1651,11 @@ void TaskCodegen::visit(AtomicOpStmt *stmt) {
op = spv::OpAtomicISub;
use_native_atomics = true;
} else if (stmt->op_type == AtomicOpType::mul) {
addr_ptr = at_buffer(stmt->dest, ir_->get_quadrants_uint_type(dt));
// dest_is_ptr guard needed here too - at_buffer would crash on shared
// integer arrays (same reason as the float branches above).
addr_ptr = dest_is_ptr
? dest_val
: at_buffer(stmt->dest, ir_->get_quadrants_uint_type(dt));
val = ir_->integer_atomic(stmt->op_type, addr_ptr, data, dt);
use_native_atomics = false;
} else if (stmt->op_type == AtomicOpType::min) {
Expand Down Expand Up @@ -2096,6 +2161,9 @@ void TaskCodegen::generate_struct_for_kernel(OffloadedStmt *stmt) {
task_attribs_.buffer_binds = get_buffer_binds();
}

// Return the address in device memory for a global/storage-buffer access.
// Only works for device-buffer-backed pointers (via ptr_to_buffers_), not
// workgroup arrays - those already have a pointer from OpAccessChain.
spirv::Value TaskCodegen::at_buffer(const Stmt *ptr, DataType dt) {
spirv::Value ptr_val = ir_->query_value(ptr->raw_name());

Expand Down
38 changes: 30 additions & 8 deletions quadrants/codegen/spirv/spirv_ir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ void IRBuilder::init_header() {
if (caps_->get(cap::spirv_has_int64)) {
ib_.begin(spv::OpCapability).add(spv::CapabilityInt64).commit(&header_);
}
if (caps_->get(cap::spirv_has_atomic_int64)) {
// Required for OpAtomicLoad/OpAtomicCompareExchange on u64, used by
// the CAS-based f64 shared float atomic emulation path.
ib_.begin(spv::OpCapability)
.add(spv::CapabilityInt64Atomics)
.commit(&header_);
}
if (caps_->get(cap::spirv_has_float16)) {
ib_.begin(spv::OpCapability).add(spv::CapabilityFloat16).commit(&header_);
}
Expand Down Expand Up @@ -1075,6 +1082,8 @@ Value IRBuilder::float_atomic(AtomicOpType op_type,
Value addr_ptr,
Value data,
const DataType &dt) {
// Use dt-derived type instead of t_fp32_ so FMin/FMax work for f16/f64.
auto float_type = get_primitive_type(dt);
if (op_type == AtomicOpType::add) {
return atomic_operation(
addr_ptr, data, [&](Value lhs, Value rhs) { return add(lhs, rhs); },
Expand All @@ -1091,14 +1100,14 @@ Value IRBuilder::float_atomic(AtomicOpType op_type,
return atomic_operation(
addr_ptr, data,
[&](Value lhs, Value rhs) {
return call_glsl450(t_fp32_, /*FMin*/ 37, lhs, rhs);
return call_glsl450(float_type, /*FMin*/ 37, lhs, rhs);
Comment thread
hughperkins marked this conversation as resolved.
},
dt);
} else if (op_type == AtomicOpType::max) {
return atomic_operation(
addr_ptr, data,
[&](Value lhs, Value rhs) {
return call_glsl450(t_fp32_, /*FMax*/ 40, lhs, rhs);
return call_glsl450(float_type, /*FMax*/ 40, lhs, rhs);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

},
dt);
} else {
Expand All @@ -1124,7 +1133,13 @@ Value IRBuilder::atomic_operation(Value addr_ptr,
std::function<Value(Value, Value)> op,
const DataType &dt) {
SType out_type = get_primitive_type(dt);
SType res_type = get_primitive_uint_type(dt);
// Device-buffer pointers are uint-typed (from at_buffer), so CAS uses uint.
// Workgroup (shared) pointers keep their original type (e.g. i32). Using uint
// on a signed pointer causes Metal's atomic_compare_exchange to reject the
// shader due to signed/unsigned type mismatch.
const bool is_workgroup =
addr_ptr.stype.storage_class == spv::StorageClassWorkgroup;
SType res_type = is_workgroup ? out_type : get_primitive_uint_type(dt);
Value ret_val_int = alloca_variable(res_type);

// do-while
Expand All @@ -1150,10 +1165,15 @@ Value IRBuilder::atomic_operation(Value addr_ptr,
Value old_val = make_value(spv::OpAtomicLoad, res_type, addr_ptr,
/*scope=*/const_i32_one_,
/*semantics=*/const_i32_zero_);
// int new = dataTypeBitsToInt(atomic_op(intBitsToDataType(old), data));
Value old_data_value = make_value(spv::OpBitcast, out_type, old_val);
// Bitcast uint<->float for the operation. Skip when types already match
// (integer workgroup path where res_type == out_type).
Value old_data_value = (out_type.id != res_type.id)
? make_value(spv::OpBitcast, out_type, old_val)
: old_val;
Value new_data_value = op(old_data_value, data);
Value new_val = make_value(spv::OpBitcast, res_type, new_data_value);
Value new_val = (out_type.id != res_type.id)
? make_value(spv::OpBitcast, res_type, new_data_value)
: new_data_value;
// int loaded = atomicCompSwap(vals[0], old, new);
/*
* Don't need this part, theoretically
Expand Down Expand Up @@ -1188,8 +1208,10 @@ Value IRBuilder::atomic_operation(Value addr_ptr,
}
start_label(exit);

return make_value(spv::OpBitcast, out_type,
load_variable(ret_val_int, res_type));
Value ret_loaded = load_variable(ret_val_int, res_type);
return (out_type.id != res_type.id)
? make_value(spv::OpBitcast, out_type, ret_loaded)
: ret_loaded;
}

Value IRBuilder::rand_u32(Value global_tmp_) {
Expand Down
Loading
Loading