Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions docs/source/user_guide/autodiff.md
Original file line number Diff line number Diff line change
Expand Up @@ -465,23 +465,23 @@ Quadrants computes that worst case at launch time - in this example, the max of

### Evaluation paths

The compiler picks one of two evaluation paths to compute the maximum based on the bound expression's structure:
The compiler picks one of two evaluation paths to compute the maximum based on the backend and the bound expression's structure:

- **Parallel:** the maximum is computed with a tiny parallel reduction kernel for efficiency. The reducer accepts a common subset of bound expressions:
- **Parallel (GPU only):** the maximum is computed with a tiny parallel reduction kernel for efficiency. The reducer accepts a common subset of bound expressions:
- **Integer ndarray or field read** up to 32 bits wide, indexed by literal constants or outer-loop variables: `arr[i, j]`, `field[i]`.
- **Shape term**: `arr.shape[k]`.
- **Literal integer constant**: `42`.
- **Arithmetic combinator**: any `+`, `-`, `*`, `max` of the above.
- **Sequential:** the fallback path, used whenever the parallel path doesn't support the bound expression. Quadrants walks the bound expression one outer-loop iteration at a time on a single thread; the adstack is sized identically, only the upfront cost differs. This path accepts everything the parallel path does, plus:
- **Sequential:** the fallback path, used whenever the parallel path doesn't support the bound expression. Quadrants walks the bound expression one outer-loop iteration at a time on a single thread (host-side on CPU, single-thread on-device kernel on GPU); the adstack is sized identically, only the upfront cost differs. This path accepts everything the parallel path does, plus:
- **Arithmetic-indexed read**: `arr[i // 2]`, `arr[i % 4]`.
- **Indirect / nested read**: `arr1[arr2[i]]`, `my_field[arr[i]]`.

### Nested loops

Quadrants supports arbitrarily nested loops. When the bound expression itself contains another enclosed loop whose own bound expression must be reduced first, the enclosing bound expression takes the parallel path only if every nested bound expression also fits the parallel-path grammar; otherwise it falls back to the sequential walk. This keeps the runtime from mixing parallel and sequential evaluators inside a single bound expression, which would otherwise force per-iteration kernel launches.
Quadrants supports arbitrarily nested loops. When the bound expression itself contains another enclosed loop whose own bound expression must be reduced first, the enclosing bound expression takes the parallel path only if every nested bound expression is also supported by the parallel path; otherwise it falls back to the sequential walk. This keeps the runtime from mixing parallel and sequential evaluators inside a single bound expression, which would otherwise force per-iteration kernel launches.

### Sequential walk cap

The sequential walk's outer loop is artificially capped at 2^24 = 16 777 216 iterations to keep both the walk time and the read-tracking memory bounded; past that the kernel raises `RuntimeError: ... iteration count ... exceeds the 16777216 guard`. In the example above, the iteration count of the enclosed loop takes the sequential path because of the `i // 2` index, so it would raise at launch if `arr.shape[0] > (1 << 24)`.
The sequential walk's outer loop is artificially capped at 2^24 = 16 777 216 iterations on GPU backends to keep the walk time bounded; past that the kernel raises `RuntimeError: ... iteration count ... exceeds the 16777216 guard`. In the example above, the iteration count of the enclosed loop takes the sequential path because of the `i // 2` index, so it would raise at launch on GPU backends if `arr.shape[0] > (1 << 24)`.

To circumvent this limitation, rewrite the bound expression to unlock the parallel path (e.g. precompute `bounds[i] = arr[i // 2]` into a persistent separate buffer, pass `bounds` in as an input, and use `for j in range(bounds[i]):`), or keep the outer loop count below 2^24.
14 changes: 10 additions & 4 deletions quadrants/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1994,10 +1994,16 @@ void TaskCodeGenLLVM::finalize_offloaded_task_function() {
current_task->ad_stack.allocas = ad_stack_allocas_info_;
current_task->ad_stack.size_exprs = ad_stack_size_exprs_;
current_task->ad_stack.bound_expr = ad_stack_static_bound_expr_;
// recognize `MaxOverRange` nodes that the runtime can reduce in parallel via the dedicated max-reducer dispatch
// instead of letting the per-thread sizer enumerate. Indexing matches `ad_stack_size_exprs_` (same iteration order
// as the pre-scan above).
current_task->ad_stack.max_reducer_specs = recognize_adstack_max_reducer_specs(ad_stack_size_exprs_);
// Recognize `MaxOverRange` nodes the runtime can reduce in parallel via the dedicated max-reducer dispatch instead
// of letting the per-thread sizer enumerate. Indexing matches `ad_stack_size_exprs_` (same iteration order as the
// pre-scan above). Skip on CPU: the host evaluator's `MaxOverRange` loop in `program/adstack/eval.cpp` does the
// same serial walk, and dispatching the runtime helper would only add per-launch setup cost (params blob encode,
// body bytecode encode, observation bookkeeping, JIT call) with no compute parallelism to amortize. The host
// evaluator handles every iteration count up to its own cap (`UINT32_MAX` on CPU; see `eval.cpp`). On CUDA /
// AMDGPU the parallel reducer is the whole point of the dispatch and the recognizer stays active.
if (!arch_is_cpu(compile_config.arch)) {
current_task->ad_stack.max_reducer_specs = recognize_adstack_max_reducer_specs(ad_stack_size_exprs_);
}
// Snodes the task body mutates. Persisted on `OffloadedTask::snode_writes` so the LLVM
// launcher can invalidate the per-task adstack metadata cache when a kernel that runs in
// between mutated a SNode an enclosing `size_expr::FieldLoad` reads. Mirrors the SPIR-V
Expand Down
128 changes: 32 additions & 96 deletions quadrants/program/adstack/cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,97 +20,42 @@ namespace quadrants::lang {

namespace {

// Read the input that `obs` describes against the live state and `ctx`. Caller compares the result to
// `obs.observed_value` to decide whether the cached `SizeExprCacheEntry` is still valid. Each `obs.kind`
// mirrors the corresponding leaf in `evaluate_field_load` / `evaluate_external_tensor_shape` /
// `evaluate_external_tensor_read`.
int64_t replay_one_observation(const AdStackCache::SizeExprReadObservation &obs,
Program *prog,
LaunchContextBuilder *ctx) {
// Decide whether the input that `obs` describes is still consistent with the recorded state. Returns true iff the
// cached `SizeExprCacheEntry` is still valid for this observation. FieldLoadObs / ExternalReadObs use the per-buffer
// gen counter (`snode_write_gen` / `ndarray_data_gen`) as the sole staleness signal: a gen-counter advance forces a
// re-walk regardless of whether the read cells themselves changed. ExternalShapeObs has no gen counter (shapes are
// launch-arg metadata, not buffer content), so it falls back to value comparison against `observed_value`.
bool replay_observation_is_fresh(const AdStackCache::SizeExprReadObservation &obs,
Program *prog,
LaunchContextBuilder *ctx) {
using Obs = AdStackCache::SizeExprReadObservation;
switch (obs.kind) {
case Obs::FieldLoadObs: {
// Gen-counter fast skip: when no kernel has bumped this SNode's write generation since record time, the
// underlying field value cannot have changed and we can return the recorded `observed_value` without dispatching
// a reader kernel. The dispatch is the dominant per-launch cost on the hot path for steady-state reverse-mode
// loops with stable bounds.
if (prog != nullptr && prog->adstack_cache().snode_write_gen(obs.snode_id) == obs.observed_gen) {
return obs.observed_value;
}
// Max-reducer body FieldLoadObs (bound-var-indexed leaves) records `indices = {}` since the body is evaluated at
// every cross-product point and there is no single canonical index to re-read. The gen counter is the only valid
// staleness signal in that mode; a gen mismatch unconditionally invalidates the cache.
if (obs.indices.empty()) {
return obs.observed_value + 1;
}
int64_t v = read_field_with_launch_cache(obs.snode_id, obs.indices, prog);
if (v == std::numeric_limits<int64_t>::min()) {
return obs.observed_value + 1; // force a mismatch if SNode disappeared
}
return v;
}
case Obs::ExternalShapeObs: {
if (ctx == nullptr) {
return obs.observed_value + 1;
}
std::vector<int> arg_indices(obs.arg_id_path.begin(), obs.arg_id_path.end());
arg_indices.push_back(TypeFactory::SHAPE_POS_IN_NDARRAY);
arg_indices.push_back(obs.arg_shape_axis);
return static_cast<int64_t>(ctx->get_struct_arg_host<int32_t>(arg_indices));
}
case Obs::FieldLoadObs:
return prog != nullptr && prog->adstack_cache().snode_write_gen(obs.snode_id) == obs.observed_gen;
case Obs::ExternalReadObs: {
if (ctx == nullptr || obs.arg_id_path.empty()) {
return obs.observed_value + 1;
if (ctx == nullptr || obs.arg_id_path.empty() || prog == nullptr) {
return false;
}
int arg_id = obs.arg_id_path[0];
ArgArrayPtrKey key{arg_id, TypeFactory::DATA_PTR_POS_IN_NDARRAY};
auto it = ctx->array_ptrs.find(key);
if (it == ctx->array_ptrs.end()) {
return obs.observed_value + 1;
return false;
}
void *data_ptr = it->second;
// Gen-counter fast skip: when the data pointer is the same `DeviceAllocation *` we observed at record
// time AND its data generation has not been bumped since (no kernel write, no host-side `Ndarray.write`
// / `fill`), the underlying scalar cannot have changed and we can return the recorded value without
// dereferencing the device pointer (which on GPU would be a DtoH copy, on CPU a host load).
if (prog != nullptr && data_ptr == obs.observed_devalloc &&
prog->adstack_cache().ndarray_data_gen(data_ptr) == obs.observed_gen) {
return obs.observed_value;
}
int64_t linear = 0;
int64_t stride = 1;
for (std::size_t i = obs.indices.size(); i > 0; --i) {
linear += static_cast<int64_t>(obs.indices[i - 1]) * stride;
if (i - 1 > 0) {
std::vector<int> sh_idx(obs.arg_id_path.begin(), obs.arg_id_path.end());
sh_idx.push_back(TypeFactory::SHAPE_POS_IN_NDARRAY);
sh_idx.push_back(static_cast<int>(i - 1));
stride *= static_cast<int64_t>(ctx->get_struct_arg_host<int32_t>(sh_idx));
}
}
switch (static_cast<PrimitiveTypeID>(obs.prim_dt)) {
case PrimitiveTypeID::i32:
return static_cast<int64_t>(static_cast<int32_t *>(data_ptr)[linear]);
case PrimitiveTypeID::i64:
return static_cast<int64_t *>(data_ptr)[linear];
case PrimitiveTypeID::u32:
return static_cast<int64_t>(static_cast<uint32_t *>(data_ptr)[linear]);
case PrimitiveTypeID::u64:
return static_cast<int64_t>(static_cast<uint64_t *>(data_ptr)[linear]);
case PrimitiveTypeID::i16:
return static_cast<int64_t>(static_cast<int16_t *>(data_ptr)[linear]);
case PrimitiveTypeID::u16:
return static_cast<int64_t>(static_cast<uint16_t *>(data_ptr)[linear]);
case PrimitiveTypeID::i8:
return static_cast<int64_t>(static_cast<int8_t *>(data_ptr)[linear]);
case PrimitiveTypeID::u8:
return static_cast<int64_t>(static_cast<uint8_t *>(data_ptr)[linear]);
default:
return obs.observed_value + 1;
return data_ptr == obs.observed_devalloc && prog->adstack_cache().ndarray_data_gen(data_ptr) == obs.observed_gen;
}
case Obs::ExternalShapeObs: {
if (ctx == nullptr) {
return false;
}
std::vector<int> arg_indices(obs.arg_id_path.begin(), obs.arg_id_path.end());
arg_indices.push_back(TypeFactory::SHAPE_POS_IN_NDARRAY);
arg_indices.push_back(obs.arg_shape_axis);
return static_cast<int64_t>(ctx->get_struct_arg_host<int32_t>(arg_indices)) == obs.observed_value;
}
}
return obs.observed_value + 1;
return false;
}

} // namespace
Expand All @@ -125,8 +70,7 @@ bool AdStackCache::try_size_expr_cache_hit(Program *prog,
}
const auto &entry = it->second;
for (const auto &obs : entry.reads) {
int64_t now = replay_one_observation(obs, prog, ctx);
if (now != obs.observed_value) {
if (!replay_observation_is_fresh(obs, prog, ctx)) {
size_expr_cache_.erase(it);
return false;
}
Expand Down Expand Up @@ -163,8 +107,7 @@ bool AdStackCache::try_max_reducer_cache_hit(uint32_t registry_id,
}
const auto &entry = it->second;
for (const auto &obs : entry.reads) {
int64_t now = replay_one_observation(obs, prog_, ctx);
if (now != obs.observed_value) {
if (!replay_observation_is_fresh(obs, prog_, ctx)) {
max_reducer_cache_.erase(it);
return false;
}
Expand All @@ -178,11 +121,9 @@ void populate_max_reducer_body_observations(std::vector<AdStackCache::SizeExprRe
AdStackCache *cache) {
for (auto &obs : reads) {
if (obs.kind == AdStackCache::SizeExprReadObservation::FieldLoadObs) {
// `FieldLoadObs` from a bound-var-indexed body leaf: snapshot the snode write generation so a subsequent launch
// that has not mutated the SNode replays the cached max via `replay_one_observation`'s gen-fast-skip arm. Same
// sentinel rationale as `ExternalReadObs` below: the recognizer restricts the leaf dtype so an `INT64_MIN`
// recorded value cannot equal a freshly-loaded one on cache miss.
obs.observed_value = std::numeric_limits<int64_t>::min();
// Snapshot the snode write generation so a subsequent launch that has not mutated the SNode replays as a cache
// hit via the gen-counter check in `replay_observation_is_fresh`. `observed_value` is unused for FieldLoadObs
// (gen counter is the sole staleness signal) and left at its default-constructed value.
if (cache != nullptr) {
obs.observed_gen = cache->snode_write_gen(obs.snode_id);
}
Expand All @@ -201,13 +142,9 @@ void populate_max_reducer_body_observations(std::vector<AdStackCache::SizeExprRe
continue;
}
obs.observed_devalloc = it->second;
// Pick an `observed_value` that no in-range ndarray scalar can equal (`INT64_MIN`). The replay code returns
// `obs.observed_value` verbatim when `ndarray_data_gen` still matches the recorded snapshot, so an `INT64_MIN`
// record is a self-equal cache hit. On gen mismatch the replay re-dereferences `data[0]` instead, which (under any
// sub-i64 prim_dt the recognizer admits) widens to an i64 strictly greater than `INT64_MIN` and forces the cache to
// invalidate. The dispatched max itself lives in `MaxReducerCacheEntry::result`; this observation only gates
// whether the cache stays warm.
obs.observed_value = std::numeric_limits<int64_t>::min();
// Snapshot the data-pointer + ndarray data generation so a subsequent launch with the same `DeviceAllocation` and
// an unbumped gen replays as a cache hit. `observed_value` is unused for ExternalReadObs (gen counter is the sole
// staleness signal) and left at its default-constructed value.
if (cache != nullptr) {
obs.observed_gen = cache->ndarray_data_gen(it->second);
}
Expand Down Expand Up @@ -243,8 +180,7 @@ bool AdStackCache::try_spirv_bytecode_cache_hit(Program *prog,
}
const auto &entry = it->second;
for (const auto &obs : entry.reads) {
int64_t now = replay_one_observation(obs, prog, ctx);
if (now != obs.observed_value) {
if (!replay_observation_is_fresh(obs, prog, ctx)) {
spirv_bytecode_cache_.erase(it);
return false;
}
Expand Down
13 changes: 7 additions & 6 deletions quadrants/program/adstack/cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,13 @@ class AdStackCache {
}

// One input read observed during a `evaluate_adstack_size_expr` walk. The cache entry records these so a subsequent
// lookup re-reads the same inputs and compares to `observed_value`; a single mismatch forces a full re-walk.
// `observed_gen` snapshots `snode_write_gen` (FieldLoadObs) or `ndarray_data_gen` (ExternalReadObs) at record
// time. The replay walk uses it as a fast-path short-circuit: if the gen counter has not advanced, the value
// cannot have changed and the dispatch (reader kernel for SNode reads, device-pointer deref for ndarray reads)
// is skipped. ExternalShapeObs reads the args buffer per launch (cheap host memory access), so it does not need
// a gen and leaves this field at 0.
// lookup checks whether each recorded input is still consistent; a single mismatch forces a full re-walk.
// FieldLoadObs / ExternalReadObs use the per-buffer gen counter (`snode_write_gen` / `ndarray_data_gen`) snapshotted
// in `observed_gen` as the sole staleness signal: a gen-counter advance unconditionally invalidates the cache, even
// if the read cells happen to be untouched. The bump invariant is therefore required: every kernel write,
// `Ndarray.write` / `fill`, and `SNodeRwAccessorsBank` writer kernel must bump the matching gen counter for the
// cache to stay correct. `observed_value` is unused for these kinds. ExternalShapeObs has no gen counter (shapes
// live in launch-arg metadata, not buffer content) and falls back to value comparison against `observed_value`.
struct SizeExprReadObservation {
enum Kind : uint8_t { FieldLoadObs, ExternalShapeObs, ExternalReadObs };
Kind kind;
Expand Down
Loading
Loading