From d70fe2c482568b21db9f0f37bf3f14285e69ded9 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Fri, 8 May 2026 16:53:39 +0200 Subject: [PATCH 1/2] [Perf] Adstack: skip max-reducer recognizer on CPU + lift host-eval cap; simplify cache invalidation to gen counter --- docs/source/user_guide/autodiff.md | 10 +- quadrants/codegen/llvm/codegen_llvm.cpp | 16 ++- quadrants/program/adstack/cache.cpp | 128 ++++++------------------ quadrants/program/adstack/cache.h | 13 +-- quadrants/program/adstack/eval.cpp | 120 +++++++++++++++++++--- tests/python/test_adstack.py | 52 ++++++---- 6 files changed, 193 insertions(+), 146 deletions(-) diff --git a/docs/source/user_guide/autodiff.md b/docs/source/user_guide/autodiff.md index 9f436445bc..423e034fc9 100644 --- a/docs/source/user_guide/autodiff.md +++ b/docs/source/user_guide/autodiff.md @@ -465,23 +465,23 @@ Quadrants computes that worst case at launch time - in this example, the max of ### Evaluation paths -The compiler picks one of two evaluation paths to compute the maximum based on the bound expression's structure: +On GPU backends the compiler picks one of two evaluation paths to compute the maximum based on the bound expression's structure; on CPU the sequential path is always taken since the runtime's CPU max-reducer is single-threaded and the parallel dispatch's per-launch setup would be pure overhead: -- **Parallel:** the maximum is computed with a tiny parallel reduction kernel for efficiency. The reducer accepts a common subset of bound expressions: +- **Parallel:** the maximum is computed with a tiny parallel reduction kernel on the GPU for efficiency. The reducer accepts a common subset of bound expressions: - **Integer ndarray or field read** up to 32 bits wide, indexed by literal constants or outer-loop variables: `arr[i, j]`, `field[i]`. - **Shape term**: `arr.shape[k]`. - **Literal integer constant**: `42`. - **Arithmetic combinator**: any `+`, `-`, `*`, `max` of the above. -- **Sequential:** the fallback path, used whenever the parallel path doesn't support the bound expression. Quadrants walks the bound expression one outer-loop iteration at a time on a single thread; the adstack is sized identically, only the upfront cost differs. This path accepts everything the parallel path does, plus: +- **Sequential:** the fallback path, used whenever the parallel path doesn't support the bound expression. Quadrants walks the bound expression one outer-loop iteration at a time on a single thread (host-side on CPU, single-thread on-device kernel on GPU); the adstack is sized identically, only the upfront cost differs. This path accepts everything the parallel path does, plus: - **Arithmetic-indexed read**: `arr[i // 2]`, `arr[i % 4]`. - **Indirect / nested read**: `arr1[arr2[i]]`, `my_field[arr[i]]`. ### Nested loops -Quadrants supports arbitrarily nested loops. When the bound expression itself contains another enclosed loop whose own bound expression must be reduced first, the enclosing bound expression takes the parallel path only if every nested bound expression also fits the parallel-path grammar; otherwise it falls back to the sequential walk. This keeps the runtime from mixing parallel and sequential evaluators inside a single bound expression, which would otherwise force per-iteration kernel launches. +Quadrants supports arbitrarily nested loops. When the bound expression itself contains another enclosed loop whose own bound expression must be reduced first, the enclosing bound expression takes the parallel path only if every nested bound expression is also supported by the parallel path; otherwise it falls back to the sequential walk. This keeps the runtime from mixing parallel and sequential evaluators inside a single bound expression, which would otherwise force per-iteration kernel launches. ### Sequential walk cap -The sequential walk's outer loop is artificially capped at 2^24 = 16 777 216 iterations to keep both the walk time and the read-tracking memory bounded; past that the kernel raises `RuntimeError: ... iteration count ... exceeds the 16777216 guard`. In the example above, the iteration count of the enclosed loop takes the sequential path because of the `i // 2` index, so it would raise at launch if `arr.shape[0] > (1 << 24)`. +The sequential walk's outer loop is artificially capped at 2^24 = 16 777 216 iterations on GPU backends to keep the walk time bounded; past that the kernel raises `RuntimeError: ... iteration count ... exceeds the 16777216 guard`. In the example above, the iteration count of the enclosed loop takes the sequential path because of the `i // 2` index, so it would raise at launch on GPU backends if `arr.shape[0] > (1 << 24)`. To circumvent this limitation, rewrite the bound expression to unlock the parallel path (e.g. precompute `bounds[i] = arr[i // 2]` into a persistent separate buffer, pass `bounds` in as an input, and use `for j in range(bounds[i]):`), or keep the outer loop count below 2^24. diff --git a/quadrants/codegen/llvm/codegen_llvm.cpp b/quadrants/codegen/llvm/codegen_llvm.cpp index 6b7a49c91c..132928474a 100644 --- a/quadrants/codegen/llvm/codegen_llvm.cpp +++ b/quadrants/codegen/llvm/codegen_llvm.cpp @@ -1994,10 +1994,18 @@ void TaskCodeGenLLVM::finalize_offloaded_task_function() { current_task->ad_stack.allocas = ad_stack_allocas_info_; current_task->ad_stack.size_exprs = ad_stack_size_exprs_; current_task->ad_stack.bound_expr = ad_stack_static_bound_expr_; - // recognize `MaxOverRange` nodes that the runtime can reduce in parallel via the dedicated max-reducer dispatch - // instead of letting the per-thread sizer enumerate. Indexing matches `ad_stack_size_exprs_` (same iteration order - // as the pre-scan above). - current_task->ad_stack.max_reducer_specs = recognize_adstack_max_reducer_specs(ad_stack_size_exprs_); + // Recognize `MaxOverRange` nodes the runtime can reduce in parallel via the dedicated max-reducer dispatch instead + // of letting the per-thread sizer enumerate. Indexing matches `ad_stack_size_exprs_` (same iteration order as the + // pre-scan above). Skip on CPU: `runtime_eval_adstack_max_reduce_serial` walks single-threaded just like the host + // evaluator's `MaxOverRange` loop in `program/adstack/eval.cpp`, so the dispatch's per-launch setup overhead + // (params blob encode, body bytecode encode, observation bookkeeping, JIT call) is pure cost without compute + // parallelism to offset it - measured ~28 % wallclock regression on the rigid-step CPU bench. The host evaluator + // handles every iteration count up to its own cap (raised to UINT32_MAX on CPU in `eval.cpp`) so above-cap shapes + // still resolve correctly. On CUDA / AMDGPU the parallel reducer is the whole point of the dispatch and the + // recognizer stays active. + if (!arch_is_cpu(compile_config.arch)) { + current_task->ad_stack.max_reducer_specs = recognize_adstack_max_reducer_specs(ad_stack_size_exprs_); + } // Snodes the task body mutates. Persisted on `OffloadedTask::snode_writes` so the LLVM // launcher can invalidate the per-task adstack metadata cache when a kernel that runs in // between mutated a SNode an enclosing `size_expr::FieldLoad` reads. Mirrors the SPIR-V diff --git a/quadrants/program/adstack/cache.cpp b/quadrants/program/adstack/cache.cpp index f8f3b9aaee..c32ee42712 100644 --- a/quadrants/program/adstack/cache.cpp +++ b/quadrants/program/adstack/cache.cpp @@ -20,97 +20,42 @@ namespace quadrants::lang { namespace { -// Read the input that `obs` describes against the live state and `ctx`. Caller compares the result to -// `obs.observed_value` to decide whether the cached `SizeExprCacheEntry` is still valid. Each `obs.kind` -// mirrors the corresponding leaf in `evaluate_field_load` / `evaluate_external_tensor_shape` / -// `evaluate_external_tensor_read`. -int64_t replay_one_observation(const AdStackCache::SizeExprReadObservation &obs, - Program *prog, - LaunchContextBuilder *ctx) { +// Decide whether the input that `obs` describes is still consistent with the recorded state. Returns true iff the +// cached `SizeExprCacheEntry` is still valid for this observation. FieldLoadObs / ExternalReadObs use the per-buffer +// gen counter (`snode_write_gen` / `ndarray_data_gen`) as the sole staleness signal: a gen-counter advance forces a +// re-walk regardless of whether the read cells themselves changed. ExternalShapeObs has no gen counter (shapes are +// launch-arg metadata, not buffer content), so it falls back to value comparison against `observed_value`. +bool replay_observation_is_fresh(const AdStackCache::SizeExprReadObservation &obs, + Program *prog, + LaunchContextBuilder *ctx) { using Obs = AdStackCache::SizeExprReadObservation; switch (obs.kind) { - case Obs::FieldLoadObs: { - // Gen-counter fast skip: when no kernel has bumped this SNode's write generation since record time, the - // underlying field value cannot have changed and we can return the recorded `observed_value` without dispatching - // a reader kernel. The dispatch is the dominant per-launch cost on the hot path for steady-state reverse-mode - // loops with stable bounds. - if (prog != nullptr && prog->adstack_cache().snode_write_gen(obs.snode_id) == obs.observed_gen) { - return obs.observed_value; - } - // Max-reducer body FieldLoadObs (bound-var-indexed leaves) records `indices = {}` since the body is evaluated at - // every cross-product point and there is no single canonical index to re-read. The gen counter is the only valid - // staleness signal in that mode; a gen mismatch unconditionally invalidates the cache. - if (obs.indices.empty()) { - return obs.observed_value + 1; - } - int64_t v = read_field_with_launch_cache(obs.snode_id, obs.indices, prog); - if (v == std::numeric_limits::min()) { - return obs.observed_value + 1; // force a mismatch if SNode disappeared - } - return v; - } - case Obs::ExternalShapeObs: { - if (ctx == nullptr) { - return obs.observed_value + 1; - } - std::vector arg_indices(obs.arg_id_path.begin(), obs.arg_id_path.end()); - arg_indices.push_back(TypeFactory::SHAPE_POS_IN_NDARRAY); - arg_indices.push_back(obs.arg_shape_axis); - return static_cast(ctx->get_struct_arg_host(arg_indices)); - } + case Obs::FieldLoadObs: + return prog != nullptr && prog->adstack_cache().snode_write_gen(obs.snode_id) == obs.observed_gen; case Obs::ExternalReadObs: { - if (ctx == nullptr || obs.arg_id_path.empty()) { - return obs.observed_value + 1; + if (ctx == nullptr || obs.arg_id_path.empty() || prog == nullptr) { + return false; } int arg_id = obs.arg_id_path[0]; ArgArrayPtrKey key{arg_id, TypeFactory::DATA_PTR_POS_IN_NDARRAY}; auto it = ctx->array_ptrs.find(key); if (it == ctx->array_ptrs.end()) { - return obs.observed_value + 1; + return false; } void *data_ptr = it->second; - // Gen-counter fast skip: when the data pointer is the same `DeviceAllocation *` we observed at record - // time AND its data generation has not been bumped since (no kernel write, no host-side `Ndarray.write` - // / `fill`), the underlying scalar cannot have changed and we can return the recorded value without - // dereferencing the device pointer (which on GPU would be a DtoH copy, on CPU a host load). - if (prog != nullptr && data_ptr == obs.observed_devalloc && - prog->adstack_cache().ndarray_data_gen(data_ptr) == obs.observed_gen) { - return obs.observed_value; - } - int64_t linear = 0; - int64_t stride = 1; - for (std::size_t i = obs.indices.size(); i > 0; --i) { - linear += static_cast(obs.indices[i - 1]) * stride; - if (i - 1 > 0) { - std::vector sh_idx(obs.arg_id_path.begin(), obs.arg_id_path.end()); - sh_idx.push_back(TypeFactory::SHAPE_POS_IN_NDARRAY); - sh_idx.push_back(static_cast(i - 1)); - stride *= static_cast(ctx->get_struct_arg_host(sh_idx)); - } - } - switch (static_cast(obs.prim_dt)) { - case PrimitiveTypeID::i32: - return static_cast(static_cast(data_ptr)[linear]); - case PrimitiveTypeID::i64: - return static_cast(data_ptr)[linear]; - case PrimitiveTypeID::u32: - return static_cast(static_cast(data_ptr)[linear]); - case PrimitiveTypeID::u64: - return static_cast(static_cast(data_ptr)[linear]); - case PrimitiveTypeID::i16: - return static_cast(static_cast(data_ptr)[linear]); - case PrimitiveTypeID::u16: - return static_cast(static_cast(data_ptr)[linear]); - case PrimitiveTypeID::i8: - return static_cast(static_cast(data_ptr)[linear]); - case PrimitiveTypeID::u8: - return static_cast(static_cast(data_ptr)[linear]); - default: - return obs.observed_value + 1; + return data_ptr == obs.observed_devalloc && prog->adstack_cache().ndarray_data_gen(data_ptr) == obs.observed_gen; + } + case Obs::ExternalShapeObs: { + if (ctx == nullptr) { + return false; } + std::vector arg_indices(obs.arg_id_path.begin(), obs.arg_id_path.end()); + arg_indices.push_back(TypeFactory::SHAPE_POS_IN_NDARRAY); + arg_indices.push_back(obs.arg_shape_axis); + return static_cast(ctx->get_struct_arg_host(arg_indices)) == obs.observed_value; } } - return obs.observed_value + 1; + return false; } } // namespace @@ -125,8 +70,7 @@ bool AdStackCache::try_size_expr_cache_hit(Program *prog, } const auto &entry = it->second; for (const auto &obs : entry.reads) { - int64_t now = replay_one_observation(obs, prog, ctx); - if (now != obs.observed_value) { + if (!replay_observation_is_fresh(obs, prog, ctx)) { size_expr_cache_.erase(it); return false; } @@ -163,8 +107,7 @@ bool AdStackCache::try_max_reducer_cache_hit(uint32_t registry_id, } const auto &entry = it->second; for (const auto &obs : entry.reads) { - int64_t now = replay_one_observation(obs, prog_, ctx); - if (now != obs.observed_value) { + if (!replay_observation_is_fresh(obs, prog_, ctx)) { max_reducer_cache_.erase(it); return false; } @@ -178,11 +121,9 @@ void populate_max_reducer_body_observations(std::vector::min(); + // Snapshot the snode write generation so a subsequent launch that has not mutated the SNode replays as a cache + // hit via the gen-counter check in `replay_observation_is_fresh`. `observed_value` is unused for FieldLoadObs + // (gen counter is the sole staleness signal) and left at its default-constructed value. if (cache != nullptr) { obs.observed_gen = cache->snode_write_gen(obs.snode_id); } @@ -201,13 +142,9 @@ void populate_max_reducer_body_observations(std::vectorsecond; - // Pick an `observed_value` that no in-range ndarray scalar can equal (`INT64_MIN`). The replay code returns - // `obs.observed_value` verbatim when `ndarray_data_gen` still matches the recorded snapshot, so an `INT64_MIN` - // record is a self-equal cache hit. On gen mismatch the replay re-dereferences `data[0]` instead, which (under any - // sub-i64 prim_dt the recognizer admits) widens to an i64 strictly greater than `INT64_MIN` and forces the cache to - // invalidate. The dispatched max itself lives in `MaxReducerCacheEntry::result`; this observation only gates - // whether the cache stays warm. - obs.observed_value = std::numeric_limits::min(); + // Snapshot the data-pointer + ndarray data generation so a subsequent launch with the same `DeviceAllocation` and + // an unbumped gen replays as a cache hit. `observed_value` is unused for ExternalReadObs (gen counter is the sole + // staleness signal) and left at its default-constructed value. if (cache != nullptr) { obs.observed_gen = cache->ndarray_data_gen(it->second); } @@ -243,8 +180,7 @@ bool AdStackCache::try_spirv_bytecode_cache_hit(Program *prog, } const auto &entry = it->second; for (const auto &obs : entry.reads) { - int64_t now = replay_one_observation(obs, prog, ctx); - if (now != obs.observed_value) { + if (!replay_observation_is_fresh(obs, prog, ctx)) { spirv_bytecode_cache_.erase(it); return false; } diff --git a/quadrants/program/adstack/cache.h b/quadrants/program/adstack/cache.h index 33024dbfaf..db8322e491 100644 --- a/quadrants/program/adstack/cache.h +++ b/quadrants/program/adstack/cache.h @@ -38,12 +38,13 @@ class AdStackCache { } // One input read observed during a `evaluate_adstack_size_expr` walk. The cache entry records these so a subsequent - // lookup re-reads the same inputs and compares to `observed_value`; a single mismatch forces a full re-walk. - // `observed_gen` snapshots `snode_write_gen` (FieldLoadObs) or `ndarray_data_gen` (ExternalReadObs) at record - // time. The replay walk uses it as a fast-path short-circuit: if the gen counter has not advanced, the value - // cannot have changed and the dispatch (reader kernel for SNode reads, device-pointer deref for ndarray reads) - // is skipped. ExternalShapeObs reads the args buffer per launch (cheap host memory access), so it does not need - // a gen and leaves this field at 0. + // lookup checks whether each recorded input is still consistent; a single mismatch forces a full re-walk. + // FieldLoadObs / ExternalReadObs use the per-buffer gen counter (`snode_write_gen` / `ndarray_data_gen`) snapshotted + // in `observed_gen` as the sole staleness signal: a gen-counter advance unconditionally invalidates the cache, even + // if the read cells happen to be untouched. The bump invariant is therefore required: every kernel write, + // `Ndarray.write` / `fill`, and `SNodeRwAccessorsBank` writer kernel must bump the matching gen counter for the + // cache to stay correct. `observed_value` is unused for these kinds. ExternalShapeObs has no gen counter (shapes + // live in launch-arg metadata, not buffer content) and falls back to value comparison against `observed_value`. struct SizeExprReadObservation { enum Kind : uint8_t { FieldLoadObs, ExternalShapeObs, ExternalReadObs }; Kind kind; diff --git a/quadrants/program/adstack/eval.cpp b/quadrants/program/adstack/eval.cpp index 064d14a121..faa0fa5d13 100644 --- a/quadrants/program/adstack/eval.cpp +++ b/quadrants/program/adstack/eval.cpp @@ -1,6 +1,8 @@ #include "quadrants/program/adstack/eval.h" #include +#include +#include #include #include #include @@ -10,6 +12,8 @@ #include "quadrants/common/logging.h" #include "quadrants/ir/snode.h" #include "quadrants/ir/type.h" +#include "quadrants/program/program.h" +#include "quadrants/rhi/arch.h" #include "quadrants/ir/type_factory.h" #include "quadrants/ir/type_utils.h" #include "quadrants/program/launch_context_builder.h" @@ -80,9 +84,9 @@ int64_t evaluate_field_load(const SerializedSizeExprNode &node, obs.indices = std::move(indices); obs.arg_shape_axis = 0; obs.prim_dt = 0; - obs.observed_value = v; - // Snapshot the SNode's write gen so the next replay can fast-skip when no kernel has written this SNode - // since record time (the dominant case for a steady-state reverse-mode loop with stable bounds). + // Snapshot the SNode's write gen so the next replay can short-circuit when no kernel has written this SNode since + // record time (the dominant case for a steady-state reverse-mode loop with stable bounds). The gen counter is the + // sole staleness signal for FieldLoadObs; `observed_value` is unused for this kind. obs.observed_gen = prog->adstack_cache().snode_write_gen(node.snode_id); reads->push_back(std::move(obs)); } @@ -177,12 +181,11 @@ int64_t evaluate_external_tensor_read(const SerializedSizeExprNode &node, obs.arg_id_path = node.arg_id_path; obs.arg_shape_axis = 0; obs.prim_dt = static_cast(prim_dt); - obs.observed_value = v; obs.observed_devalloc = data_ptr; if (prog != nullptr) { - // Snapshot the ndarray's data gen so the next replay can fast-skip when no kernel / Ndarray API write - // has touched the underlying buffer since record time. Mirrors the FieldLoad fast-skip; covers the same - // steady-state hot path for ndarray-bounded reverse-mode loops. + // Snapshot the ndarray's data gen so the next replay can short-circuit when no kernel / Ndarray API write has + // touched the underlying buffer since record time. Mirrors the FieldLoad fast-skip; the gen counter is the sole + // staleness signal for ExternalReadObs (`observed_value` is unused for this kind). obs.observed_gen = prog->adstack_cache().ndarray_data_gen(data_ptr); } reads->push_back(std::move(obs)); @@ -218,6 +221,84 @@ int64_t evaluate_external_tensor_shape(const SerializedSizeExprNode &node, Launc return v; } +// Enumerate one observation per static FieldLoad / ExternalTensorRead / ExternalTensorShape leaf in the subtree rooted +// at `node_idx`, regardless of whether the evaluator would visit it during a `MaxOverRange` walk. A nested +// `MaxOverRange` whose `end <= begin` for some outer iterations does not visit its body, but its leaves still need to +// be registered with the cache so a subsequent launch where that range becomes non-empty correctly invalidates on a +// buffer mutation. Used by the `MaxOverRange` arm of `evaluate_node` to pre-walk the body once before the per-iteration +// evaluation runs with `reads = nullptr`, so the observation count stays at O(unique body leaves) regardless of N. +void enumerate_static_observations(const SerializedSizeExpr &expr, + int32_t node_idx, + Program *prog, + LaunchContextBuilder *ctx, + ReadSink *reads) { + if (reads == nullptr) { + return; + } + QD_ASSERT_INFO(node_idx >= 0 && static_cast(node_idx) < expr.nodes.size(), + "SerializedSizeExpr enumerate_static_observations node_idx {} out of bounds (size={})", node_idx, + expr.nodes.size()); + const auto &node = expr.nodes[node_idx]; + switch (static_cast(node.kind)) { + case SizeExpr::Kind::FieldLoad: { + AdStackCache::SizeExprReadObservation obs; + obs.kind = AdStackCache::SizeExprReadObservation::FieldLoadObs; + obs.snode_id = node.snode_id; + obs.arg_shape_axis = 0; + obs.prim_dt = 0; + if (prog != nullptr) { + obs.observed_gen = prog->adstack_cache().snode_write_gen(node.snode_id); + } + reads->push_back(std::move(obs)); + break; + } + case SizeExpr::Kind::ExternalTensorRead: { + if (ctx == nullptr || node.arg_id_path.empty()) { + break; + } + int arg_id = node.arg_id_path[0]; + ArgArrayPtrKey key{arg_id, TypeFactory::DATA_PTR_POS_IN_NDARRAY}; + auto it = ctx->array_ptrs.find(key); + if (it == ctx->array_ptrs.end()) { + break; + } + AdStackCache::SizeExprReadObservation obs; + obs.kind = AdStackCache::SizeExprReadObservation::ExternalReadObs; + obs.snode_id = 0; + obs.arg_id_path = node.arg_id_path; + obs.arg_shape_axis = 0; + obs.prim_dt = static_cast(node.const_value); + obs.observed_devalloc = it->second; + if (prog != nullptr) { + obs.observed_gen = prog->adstack_cache().ndarray_data_gen(it->second); + } + reads->push_back(std::move(obs)); + break; + } + case SizeExpr::Kind::ExternalTensorShape: + // Reuse the regular shape evaluator: it pushes the obs with the current shape value, which is what the + // value-comparison replay path needs since shapes have no gen counter. The shape value depends only on the + // launch context, not on `bound_vars`, so recording it outside the iteration loop is safe. + evaluate_external_tensor_shape(node, ctx, reads); + break; + case SizeExpr::Kind::Add: + case SizeExpr::Kind::Sub: + case SizeExpr::Kind::Mul: + case SizeExpr::Kind::Max: + enumerate_static_observations(expr, node.operand_a, prog, ctx, reads); + enumerate_static_observations(expr, node.operand_b, prog, ctx, reads); + break; + case SizeExpr::Kind::MaxOverRange: + enumerate_static_observations(expr, node.operand_a, prog, ctx, reads); + enumerate_static_observations(expr, node.operand_b, prog, ctx, reads); + enumerate_static_observations(expr, node.body_node_idx, prog, ctx, reads); + break; + case SizeExpr::Kind::Const: + case SizeExpr::Kind::BoundVariable: + break; + } +} + int64_t evaluate_node(const SerializedSizeExpr &expr, int32_t node_idx, std::unordered_map &bound_vars, @@ -248,11 +329,18 @@ int64_t evaluate_node(const SerializedSizeExpr &expr, case SizeExpr::Kind::MaxOverRange: { int64_t begin = evaluate_node(expr, node.operand_a, bound_vars, prog, ctx, reads); int64_t end = evaluate_node(expr, node.operand_b, bound_vars, prog, ctx, reads); - // Guard against pathological trip counts. The evaluator walks `[begin, end)` linearly and re-evaluates the - // body at every i; a range of several million would stall the launch hot path for seconds. Real reverse-mode - // trip counts sit well below this cap (a few hundred to a few thousand in practice); anything above is - // almost certainly a pre-pass grammar bug the user should file, and a clear QD_ERROR beats a silent hang. - constexpr int64_t kMaxOverRangeIterations = int64_t{1} << 24; + // Guard against pathological trip counts. The evaluator walks `[begin, end)` linearly and re-evaluates the body + // at every i; on Metal / Vulkan / CUDA / AMDGPU the recognizer-captured `MaxOverRange` shapes are dispatched in + // parallel by the max-reducer and substituted to a `Const` before the host evaluator walks the tree, so any + // shape that lands here above the `1<<24` cap is out-of-grammar and a clear QD_ERROR beats a silent hang. On + // CPU the recognizer is intentionally skipped (see the matching `arch_is_cpu` gate in + // `codegen/llvm/codegen_llvm.cpp::finalize_offloaded_task_function`: the CPU max-reducer walks single-threaded + // just like this loop, so the dispatch's per-launch setup overhead is pure cost), and the CPU max-reducer + // itself has no equivalent cap (`runtime_eval_adstack_max_reduce_serial` at + // `runtime/llvm/runtime_module/adstack_runtime.cpp:622-636` iterates `total_length` unconditionally), so this + // evaluator must match - lift the cap to `UINT32_MAX` on CPU so legitimate above-cap workloads still complete. + const bool prog_is_cpu = (prog != nullptr) && arch_is_cpu(prog->compile_config().arch); + const int64_t kMaxOverRangeIterations = prog_is_cpu ? int64_t{UINT32_MAX} : (int64_t{1} << 24); QD_ERROR_IF(end > begin && end - begin > kMaxOverRangeIterations, "SerializedSizeExpr MaxOverRange iteration count {} exceeds the {} guard; refusing to enumerate. " "Shrink the enclosing reverse-mode loop or restructure the `SizeExpr` source kernel.", @@ -264,9 +352,15 @@ int64_t evaluate_node(const SerializedSizeExpr &expr, auto prev_it = bound_vars.find(node.var_id); bool had_prev = prev_it != bound_vars.end(); int64_t prev_val = had_prev ? prev_it->second : 0; + // Pre-walk the body subtree once to register every static read leaf with the cache (see + // `enumerate_static_observations` above). The per-iteration evaluation below runs with `reads = nullptr`, so it + // does not push observations during the loop, keeping memory bounded at O(unique body leaves) regardless of N. + // The structural enumeration is independent of which iterations execute, so a nested `MaxOverRange` whose body + // is conditionally visited (e.g., empty inner range on some outer iterations) still gets its leaves registered. + enumerate_static_observations(expr, node.body_node_idx, prog, ctx, reads); for (int64_t i = begin; i < end; ++i) { bound_vars[node.var_id] = i; - int64_t v = evaluate_node(expr, node.body_node_idx, bound_vars, prog, ctx, reads); + int64_t v = evaluate_node(expr, node.body_node_idx, bound_vars, prog, ctx, nullptr); if (v > result) { result = v; } diff --git a/tests/python/test_adstack.py b/tests/python/test_adstack.py index 42d657a568..b958644e92 100644 --- a/tests/python/test_adstack.py +++ b/tests/python/test_adstack.py @@ -4735,21 +4735,25 @@ def compute(x: qd.types.NDArray, selector: qd.types.NDArray, out: qd.types.NDArr "above_cap_arith_combine", ], ) -@test_utils.test(require=qd.extension.adstack, cfg_optimization=False) +@test_utils.test(arch=[qd.cuda, qd.amdgpu, qd.vulkan, qd.metal], require=qd.extension.adstack, cfg_optimization=False) def test_max_reducer_pins_stride_for_oversized_axis(shape, body_kind): - # A reverse-mode kernel with a parallel-for over an arbitrarily large ndarray axis and an inner range-for bound to a - # recognizer-accepted trip-count expression sizes its adstack at launch time and computes the right gradient, - # without the per-task sizer's `1<<24` cap firing. + # A reverse-mode kernel with a parallel-for over an arbitrarily large ndarray axis and an inner range-for bound to + # a recognizer-accepted trip-count expression sizes its adstack at launch time and computes the right gradient, + # without the per-task sizer's `1<<24` cap firing. GPU only: the max-reducer dispatch is GPU-specific - the host + # evaluator handles equivalent shapes on CPU. # # Internal details: the kernel lowers to `MaxOverRange(0, a.shape[0], )` in the per-stack `SizeExpr`. # `recognize_adstack_max_reducer_specs` captures the spec; the launcher dispatches the parallel max-reducer before - # the per-task sizer walks the tree; `substitute_precomputed_max_over_range` rewrites the captured `MaxOverRange` to - # `Const`. The above-cap variants place the only non-zero cell at `arr_np[-1] = N_X` so heap-stride correctness + # the per-task sizer walks the tree; `substitute_precomputed_max_over_range` rewrites the captured `MaxOverRange` + # to `Const`. The above-cap variants place the only non-zero cell at `arr_np[-1] = N_X` so heap-stride correctness # depends on the dispatch walking every element of the axis rather than relying on a partial host-eval walk. The # `shape_in_body` / `field_in_body` variants additionally pin that closed leaves (`ExternalTensorShape`, # `FieldLoad`) host-fold to `kConst` at encode time and never reach the device interpreter; `arith_combine` # exercises every binary combinator (`Add`, `Sub`, `Mul`, `Max`) and `Const` leaf in a single body expression that - # algebraically reduces to `a[i_e]`. + # algebraically reduces to `a[i_e]`. The CPU codegen gate lives in + # `codegen/llvm/codegen_llvm.cpp::finalize_offloaded_task_function`; the lifted host-eval cap lives in + # `program/adstack/eval.cpp::evaluate_node`. On CPU `_get_max_reducer_dispatch_count` stays at 0 (no dispatch + # fires), which is why this test pins it on GPU arches only. N_X = 4 arr_np = np.zeros(shape, dtype=np.int32) arr_np[-1] = N_X @@ -4813,17 +4817,18 @@ def compute(a: qd.types.ndarray(dtype=qd.i32, ndim=1)): assert x.grad[k] == pytest.approx(2 * 0.1, rel=1e-5) -@test_utils.test(require=qd.extension.adstack, cfg_optimization=False) +@test_utils.test(arch=[qd.cuda, qd.amdgpu, qd.vulkan, qd.metal], require=qd.extension.adstack, cfg_optimization=False) def test_max_reducer_dispatch_counts_advance_on_input_mutation(): - # Pins the dispatch + cache invalidation pipeline. The first launch must fire at least one max-reducer dispatch (the - # kernel's `MaxOverRange(0, a.shape[0], a[var])` matches the recognizer grammar so the recognizer captures the spec; - # the launcher dispatches once and bumps `Program.max_reducer_dispatch_count`). A subsequent host mutation of the - # gating ndarray must bump `ndarray_data_gen` and force the next launch to re-dispatch, advancing the counter beyond - # its post-first-launch value. Steady-state cache short-circuit on an unchanged ndarray is backend-dependent (the - # CPU launcher's `set_host_accessible_ndarray_ptrs` path converts qd.ndarray reads to `kNone` semantics and - # `bump_writes_for_kernel_llvm` then bumps the gen on every read; the SPIR-V launchers preserve the qd.ndarray - # dev-alloc-type and only bump on writes), so this test asserts only the mutation-triggers-redispatch contract that - # holds uniformly. + # Pins the dispatch + cache invalidation pipeline. The first launch must fire at least one max-reducer dispatch + # (the kernel's `MaxOverRange(0, a.shape[0], a[var])` matches the recognizer grammar so the recognizer captures + # the spec; the launcher dispatches once and bumps `Program.max_reducer_dispatch_count`). A subsequent host + # mutation of the gating ndarray must bump `ndarray_data_gen` and force the next launch to re-dispatch, advancing + # the counter beyond its post-first-launch value. Steady-state cache short-circuit on an unchanged ndarray is + # backend-dependent (the CPU launcher's `set_host_accessible_ndarray_ptrs` path converts qd.ndarray reads to + # `kNone` semantics and `bump_writes_for_kernel_llvm` then bumps the gen on every read; the SPIR-V launchers + # preserve the qd.ndarray dev-alloc-type and only bump on writes), so this test asserts only the + # mutation-triggers-redispatch contract that holds uniformly. GPU only: the max-reducer dispatch is GPU-specific - + # the host evaluator handles equivalent shapes on CPU. N = 4 x = qd.field(qd.f32, shape=(N,), needs_grad=True) @@ -4921,13 +4926,14 @@ def compute(): "arr_bv_indexed_by_field_load", ], ) -@test_utils.test(require=qd.extension.adstack) +@test_utils.test(arch=[qd.cuda, qd.amdgpu, qd.vulkan, qd.metal], require=qd.extension.adstack) def test_max_reducer_field_load_bound_var_dispatch(body_kind): # A reverse-mode kernel whose inner range-for trip count reads a `qd.field` indexed by the outer chain variable # captures via the parallel max-reducer dispatch and produces the analytical gradient. The body-shape # parametrization exercises every supported composition: bound-var FieldLoad on its own, mixed with bound-var ETR # via `Add` / `Max`, combined with `Const` / arithmetic, and the nested-load worst-case form (`field[field[i]]` / - # `arr[field[i]]`). + # `arr[field[i]]`). GPU only: the max-reducer dispatch is GPU-specific - the host evaluator handles equivalent + # shapes on CPU. # # Internal details: each variant lowers to `MaxOverRange(0, M, body)` where `body` is bound-var-indexed # `FieldLoad(field_a, [bound_var])` or a recognizer-accepted composition that includes one. The relaxed @@ -5028,10 +5034,11 @@ def compute(a: qd.types.ndarray(dtype=qd.i32, ndim=1)): assert x.grad[k] == pytest.approx(expected, rel=1e-5) -@test_utils.test(require=qd.extension.adstack) +@test_utils.test(arch=[qd.cuda, qd.amdgpu, qd.vulkan, qd.metal], require=qd.extension.adstack) def test_max_reducer_field_load_bound_var_cache_invalidates_on_snode_mutation(): # A reverse-mode kernel whose inner trip count reads a `qd.field` indexed by the outer chain variable redispatches - # the max-reducer when the gating field is mutated between launches. + # the max-reducer when the gating field is mutated between launches. GPU only: the max-reducer dispatch is + # GPU-specific - the host evaluator handles equivalent shapes on CPU. # # Internal details: the encoder emits a `kFieldLoad` device node and pushes a `FieldLoadObs` carrying the snode id # and the live `snode_write_gen` snapshot. On the second launch's `try_max_reducer_cache_hit`, @@ -5087,10 +5094,11 @@ def compute(): assert prog._get_max_reducer_dispatch_count() > pre_mutation -@test_utils.test(require=qd.extension.adstack, cfg_optimization=False) +@test_utils.test(arch=[qd.cuda, qd.amdgpu, qd.vulkan, qd.metal], require=qd.extension.adstack, cfg_optimization=False) def test_above_cap_out_of_grammar_kernel_raises(): # A reverse-mode kernel whose inner `range(...)` trip count is bound to an out-of-grammar `MaxOverRange` body and # whose iteration count exceeds the `1<<24` adstack-sizer cap surfaces a `QuadrantsAssertionError` at `qd.sync()`. + # GPU only: on CPU the host-eval cap is lifted to UINT32_MAX, so a shape of `(1<<24)+1` resolves without raising. # # Internal details: the recognizer's body grammar accepts only `Const / ExternalTensorRead / Add / Sub / Mul / Max # / ExternalTensorShape / FieldLoad(literal-or-bound-var indices)`, and `max_reducer_body_is_recognizable` further From 93b0dbc9dd75c0e7ae79a2cc0fe224c8a1d69963 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Fri, 8 May 2026 18:37:47 +0200 Subject: [PATCH 2/2] [Refactor] Adstack max-reducer: drop dead CPU `_serial` path; rename `_parallel` to `runtime_eval_adstack_max_reduce` --- docs/source/user_guide/autodiff.md | 4 +- quadrants/codegen/llvm/codegen_llvm.cpp | 12 ++--- quadrants/program/adstack/eval.cpp | 9 ++-- .../llvm/adstack_lazy_claim/bound_eval.cpp | 42 +++++++-------- .../runtime/llvm/llvm_runtime_executor.h | 15 +++--- .../llvm/runtime_module/adstack_runtime.cpp | 53 ++----------------- .../llvm/runtime_module/adstack_runtime.h | 22 +++----- 7 files changed, 49 insertions(+), 108 deletions(-) diff --git a/docs/source/user_guide/autodiff.md b/docs/source/user_guide/autodiff.md index 423e034fc9..dfbe1c5750 100644 --- a/docs/source/user_guide/autodiff.md +++ b/docs/source/user_guide/autodiff.md @@ -465,9 +465,9 @@ Quadrants computes that worst case at launch time - in this example, the max of ### Evaluation paths -On GPU backends the compiler picks one of two evaluation paths to compute the maximum based on the bound expression's structure; on CPU the sequential path is always taken since the runtime's CPU max-reducer is single-threaded and the parallel dispatch's per-launch setup would be pure overhead: +The compiler picks one of two evaluation paths to compute the maximum based on the backend and the bound expression's structure: -- **Parallel:** the maximum is computed with a tiny parallel reduction kernel on the GPU for efficiency. The reducer accepts a common subset of bound expressions: +- **Parallel (GPU only):** the maximum is computed with a tiny parallel reduction kernel for efficiency. The reducer accepts a common subset of bound expressions: - **Integer ndarray or field read** up to 32 bits wide, indexed by literal constants or outer-loop variables: `arr[i, j]`, `field[i]`. - **Shape term**: `arr.shape[k]`. - **Literal integer constant**: `42`. diff --git a/quadrants/codegen/llvm/codegen_llvm.cpp b/quadrants/codegen/llvm/codegen_llvm.cpp index 132928474a..5b9a30bb15 100644 --- a/quadrants/codegen/llvm/codegen_llvm.cpp +++ b/quadrants/codegen/llvm/codegen_llvm.cpp @@ -1996,13 +1996,11 @@ void TaskCodeGenLLVM::finalize_offloaded_task_function() { current_task->ad_stack.bound_expr = ad_stack_static_bound_expr_; // Recognize `MaxOverRange` nodes the runtime can reduce in parallel via the dedicated max-reducer dispatch instead // of letting the per-thread sizer enumerate. Indexing matches `ad_stack_size_exprs_` (same iteration order as the - // pre-scan above). Skip on CPU: `runtime_eval_adstack_max_reduce_serial` walks single-threaded just like the host - // evaluator's `MaxOverRange` loop in `program/adstack/eval.cpp`, so the dispatch's per-launch setup overhead - // (params blob encode, body bytecode encode, observation bookkeeping, JIT call) is pure cost without compute - // parallelism to offset it - measured ~28 % wallclock regression on the rigid-step CPU bench. The host evaluator - // handles every iteration count up to its own cap (raised to UINT32_MAX on CPU in `eval.cpp`) so above-cap shapes - // still resolve correctly. On CUDA / AMDGPU the parallel reducer is the whole point of the dispatch and the - // recognizer stays active. + // pre-scan above). Skip on CPU: the host evaluator's `MaxOverRange` loop in `program/adstack/eval.cpp` does the + // same serial walk, and dispatching the runtime helper would only add per-launch setup cost (params blob encode, + // body bytecode encode, observation bookkeeping, JIT call) with no compute parallelism to amortize. The host + // evaluator handles every iteration count up to its own cap (`UINT32_MAX` on CPU; see `eval.cpp`). On CUDA / + // AMDGPU the parallel reducer is the whole point of the dispatch and the recognizer stays active. if (!arch_is_cpu(compile_config.arch)) { current_task->ad_stack.max_reducer_specs = recognize_adstack_max_reducer_specs(ad_stack_size_exprs_); } diff --git a/quadrants/program/adstack/eval.cpp b/quadrants/program/adstack/eval.cpp index faa0fa5d13..23191f9eda 100644 --- a/quadrants/program/adstack/eval.cpp +++ b/quadrants/program/adstack/eval.cpp @@ -334,11 +334,10 @@ int64_t evaluate_node(const SerializedSizeExpr &expr, // parallel by the max-reducer and substituted to a `Const` before the host evaluator walks the tree, so any // shape that lands here above the `1<<24` cap is out-of-grammar and a clear QD_ERROR beats a silent hang. On // CPU the recognizer is intentionally skipped (see the matching `arch_is_cpu` gate in - // `codegen/llvm/codegen_llvm.cpp::finalize_offloaded_task_function`: the CPU max-reducer walks single-threaded - // just like this loop, so the dispatch's per-launch setup overhead is pure cost), and the CPU max-reducer - // itself has no equivalent cap (`runtime_eval_adstack_max_reduce_serial` at - // `runtime/llvm/runtime_module/adstack_runtime.cpp:622-636` iterates `total_length` unconditionally), so this - // evaluator must match - lift the cap to `UINT32_MAX` on CPU so legitimate above-cap workloads still complete. + // `codegen/llvm/codegen_llvm.cpp::finalize_offloaded_task_function`: the runtime max-reducer would do the same + // serial walk this loop already does, so the dispatch's per-launch setup overhead is pure cost). With the + // recognizer skipped, every CPU `MaxOverRange` lands here, and observation memory is bounded by the + // structural pre-walk above so lifting the cap to `UINT32_MAX` is memory-safe. const bool prog_is_cpu = (prog != nullptr) && arch_is_cpu(prog->compile_config().arch); const int64_t kMaxOverRangeIterations = prog_is_cpu ? int64_t{UINT32_MAX} : (int64_t{1} << 24); QD_ERROR_IF(end > begin && end - begin > kMaxOverRangeIterations, diff --git a/quadrants/runtime/llvm/adstack_lazy_claim/bound_eval.cpp b/quadrants/runtime/llvm/adstack_lazy_claim/bound_eval.cpp index 35432be613..b8064956c5 100644 --- a/quadrants/runtime/llvm/adstack_lazy_claim/bound_eval.cpp +++ b/quadrants/runtime/llvm/adstack_lazy_claim/bound_eval.cpp @@ -542,11 +542,11 @@ std::unordered_map LlvmRuntimeExecutor::dispatch_max_reducers void *outputs_dev_ptr = get_device_alloc_info_ptr(*adstack_max_reducer_outputs_alloc_); copy_h2d(runtime_adstack_max_reducer_outputs_field_ptr_, &outputs_dev_ptr, sizeof(void *)); - // GPU parallel reducer: seed every round-local output slot with INT64_MIN so the per-thread `atomic_max_i64` - // reductions inside `runtime_eval_adstack_max_reduce_parallel` can publish the cross-product max into a known - // sentinel. CPU's `_serial` variant writes the slot directly inside the kernel (no host-side seeding needed). - const bool use_gpu_parallel_reducer = config_.arch == Arch::cuda || config_.arch == Arch::amdgpu; - if (use_gpu_parallel_reducer) { + // Seed every round-local output slot with INT64_MIN so the per-thread `atomic_max_i64` reductions inside + // `runtime_eval_adstack_max_reduce` can publish the cross-product max into a known sentinel. The + // recognizer is skipped on CPU (see the `arch_is_cpu` gate in `codegen/llvm/codegen_llvm.cpp`), so this dispatch + // loop only runs on CUDA / AMDGPU and the parallel reducer is the only variant invoked. + { std::vector sentinel_slots(level_dispatch.size(), static_cast(0x8000000000000000ll)); copy_h2d(outputs_dev_ptr, sentinel_slots.data(), sentinel_slots.size() * sizeof(int64_t)); } @@ -594,26 +594,20 @@ std::unordered_map LlvmRuntimeExecutor::dispatch_max_reducers void *bytecode_dev_ptr = get_device_alloc_info_ptr(*adstack_max_reducer_bytecode_alloc_); copy_h2d(bytecode_dev_ptr, pending[k].body_bytecode.data(), needed_bytecode_bytes); - if (use_gpu_parallel_reducer) { - // Grid-strided parallel reducer. Cap `grid_dim` at the launcher's saturating grid_dim so the dispatch - // shares the heap-row budget the rest of the launcher uses; `block_dim` matches the codegen default. The - // grid-stride loop handles arbitrary `total_length` so under-dispatching is harmless. - std::uint64_t cross_product = 1; - for (std::size_t a = 0; a < pending[k].params.num_axes; ++a) { - cross_product *= static_cast(pending[k].params.per_axis_length[a]); - } - const std::size_t block_dim = static_cast(std::max(1, config_.max_block_dim)); - const std::size_t needed_threads = std::max(1, static_cast(cross_product)); - const std::size_t grid_dim_cap = static_cast(std::max(1, config_.saturating_grid_dim)); - const std::size_t grid_dim = std::min(grid_dim_cap, (needed_threads + block_dim - 1) / block_dim); - runtime_jit->launch( - "runtime_eval_adstack_max_reduce_parallel", grid_dim, block_dim, 0, llvm_runtime_, - runtime_context_ptr_for_reducer, params_dev_ptr, bytecode_dev_ptr); - } else { - runtime_jit->call("runtime_eval_adstack_max_reduce_serial", llvm_runtime_, - runtime_context_ptr_for_reducer, params_dev_ptr, - bytecode_dev_ptr); + // Grid-strided parallel reducer. Cap `grid_dim` at the launcher's saturating grid_dim so the dispatch shares + // the heap-row budget the rest of the launcher uses; `block_dim` matches the codegen default. The grid-stride + // loop handles arbitrary `total_length` so under-dispatching is harmless. + std::uint64_t cross_product = 1; + for (std::size_t a = 0; a < pending[k].params.num_axes; ++a) { + cross_product *= static_cast(pending[k].params.per_axis_length[a]); } + const std::size_t block_dim = static_cast(std::max(1, config_.max_block_dim)); + const std::size_t needed_threads = std::max(1, static_cast(cross_product)); + const std::size_t grid_dim_cap = static_cast(std::max(1, config_.saturating_grid_dim)); + const std::size_t grid_dim = std::min(grid_dim_cap, (needed_threads + block_dim - 1) / block_dim); + runtime_jit->launch("runtime_eval_adstack_max_reduce", grid_dim, block_dim, 0, + llvm_runtime_, runtime_context_ptr_for_reducer, + params_dev_ptr, bytecode_dev_ptr); } // Read back this round's output slots. The runtime function writes int64 values at `outputs[output_slot]`; each diff --git a/quadrants/runtime/llvm/llvm_runtime_executor.h b/quadrants/runtime/llvm/llvm_runtime_executor.h index b642bb232d..a0a61d8d5d 100644 --- a/quadrants/runtime/llvm/llvm_runtime_executor.h +++ b/quadrants/runtime/llvm/llvm_runtime_executor.h @@ -185,13 +185,14 @@ class LlvmRuntimeExecutor { // Max-reducer dispatch on LLVM. For each captured `StaticAdStackMaxReducerSpec` across every task in `tasks`, hits // `AdStackCache::try_max_reducer_cache_hit` first; on miss h2d-copies the params blob + body bytecode and invokes - // `runtime_eval_adstack_max_reduce` via the runtime JIT. Single dispatch path covers CPU (host call), CUDA, and - // AMDGPU. The returned map is keyed by `(registry_id, stack_id, mor_node_idx)` packed via the same encoding the gfx - // variant uses, so `substitute_precomputed_max_over_range` works backend-agnostically. Caller invokes this BEFORE the - // per-task `publish_adstack_metadata` loop and passes the result map down to each per-task `publish` call so the - // encoder substitutes captured `MaxOverRange`s before walking the tree. `MaxReducerResultMap` is defined in - // `quadrants/program/adstack_size_expr_eval.h`; declared inline here to avoid pulling that header into every - // translation unit that includes `llvm_runtime_executor.h`. + // `runtime_eval_adstack_max_reduce` via the runtime JIT as a grid-strided launch with an `atomic_max_i64` + // reduction. CUDA and AMDGPU only; on CPU the recognizer is skipped at codegen time so this path runs zero + // dispatches. The returned map is keyed by `(registry_id, stack_id, mor_node_idx)` packed via the same encoding the + // gfx variant uses, so `substitute_precomputed_max_over_range` works backend-agnostically. Caller invokes this + // BEFORE the per-task `publish_adstack_metadata` loop and passes the result map down to each per-task `publish` + // call so the encoder substitutes captured `MaxOverRange`s before walking the tree. `MaxReducerResultMap` is + // defined in `quadrants/program/adstack_size_expr_eval.h`; declared inline here to avoid pulling that header into + // every translation unit that includes `llvm_runtime_executor.h`. std::unordered_map dispatch_max_reducers_for_tasks(const std::vector &ad_stacks, LaunchContextBuilder *ctx, void *device_runtime_context_ptr); diff --git a/quadrants/runtime/llvm/runtime_module/adstack_runtime.cpp b/quadrants/runtime/llvm/runtime_module/adstack_runtime.cpp index d96e2c088d..9377a63278 100644 --- a/quadrants/runtime/llvm/runtime_module/adstack_runtime.cpp +++ b/quadrants/runtime/llvm/runtime_module/adstack_runtime.cpp @@ -592,59 +592,14 @@ __attribute__((always_inline)) inline i64 runtime_eval_adstack_max_reduce_one_st } } // namespace -extern "C" void runtime_eval_adstack_max_reduce_serial(LLVMRuntime *runtime, - RuntimeContext *ctx, - Ptr params_blob, - Ptr body_bytecode) { - using quadrants::lang::AdStackSizeExprDeviceNode; - using quadrants::lang::kAdStackMaxReducerMaxAxes; - using quadrants::lang::LlvmAdStackMaxReducerDeviceParams; - - const auto *params = reinterpret_cast(params_blob); - const auto *nodes = reinterpret_cast(body_bytecode); - const auto *indices = reinterpret_cast(reinterpret_cast(nodes) + - sizeof(AdStackSizeExprDeviceNode) * params->body_node_count); - - const char *arg_buffer = ctx->arg_buffer; - DeviceEvalScope scope; - for (i32 k = 0; k < kDeviceBoundVarCap; ++k) { - scope.values[k] = 0; - } - - // Sentinel start: INT64_MIN so the first body value always wins over an empty cross-product. Caller normalises the - // empty case (writes 0 / floors at compile-time) when reading the slot back. - i64 running_max = (i64)0x8000000000000000ll; - const u32 num_axes = params->num_axes; - if (num_axes == 0 || num_axes > (u32)kAdStackMaxReducerMaxAxes) { - runtime->adstack_max_reducer_outputs[params->output_slot] = running_max; - return; - } - u64 total_length = 1; - for (u32 a = 0; a < num_axes; ++a) { - if (params->per_axis_length[a] == 0u) { - runtime->adstack_max_reducer_outputs[params->output_slot] = running_max; - return; - } - total_length *= (u64)params->per_axis_length[a]; - } - for (u64 i = 0; i < total_length; ++i) { - i64 v = runtime_eval_adstack_max_reduce_one_step(runtime, nodes, indices, params->body_node_count, params, i, - num_axes, &scope, arg_buffer); - if (v > running_max) { - running_max = v; - } - } - runtime->adstack_max_reducer_outputs[params->output_slot] = running_max; -} - #if ARCH_cuda || ARCH_amdgpu // Forward decl: defined later in runtime.cpp; included here in single-TU layout means we forward-decl rather than // reorder the function definitions. void block_barrier(); -extern "C" void runtime_eval_adstack_max_reduce_parallel(LLVMRuntime *runtime, - RuntimeContext *ctx, - Ptr params_blob, - Ptr body_bytecode) { +extern "C" void runtime_eval_adstack_max_reduce(LLVMRuntime *runtime, + RuntimeContext *ctx, + Ptr params_blob, + Ptr body_bytecode) { using quadrants::lang::AdStackSizeExprDeviceNode; using quadrants::lang::kAdStackMaxReducerMaxAxes; using quadrants::lang::LlvmAdStackMaxReducerDeviceParams; diff --git a/quadrants/runtime/llvm/runtime_module/adstack_runtime.h b/quadrants/runtime/llvm/runtime_module/adstack_runtime.h index 17480df113..82ecba2bf1 100644 --- a/quadrants/runtime/llvm/runtime_module/adstack_runtime.h +++ b/quadrants/runtime/llvm/runtime_module/adstack_runtime.h @@ -35,21 +35,15 @@ void runtime_get_adstack_max_reducer_field_ptr(LLVMRuntime *runtime); // Per-launch device-resident reducers / interpreters consumed by the host launcher right before each adstack-bearing // kernel dispatch. `runtime_eval_static_bound_count` walks a captured gating ndarray / SNode field and writes the // gate-passing count into `runtime->adstack_bound_row_capacities[task_index]`. -// `runtime_eval_adstack_max_reduce_{serial,parallel}` walk a captured `StaticAdStackMaxReducerSpec` body over its -// multi-axis cross-product and reduce-max into `runtime->adstack_max_reducer_outputs[output_slot]`. The `_serial` -// variant is a single-thread call (used on CPU); the `_parallel` variant is a grid-strided launch with an -// `atomic_max_i64` reduction (used on CUDA / AMDGPU). `runtime_eval_adstack_size_expr` walks every alloca's SizeExpr -// tree and publishes per-stack offsets / max_sizes plus the per-thread strides into `LLVMRuntime`. The blob layouts -// are defined in the `quadrants/ir/...adstack...device.h` headers. +// `runtime_eval_adstack_max_reduce` walks a captured `StaticAdStackMaxReducerSpec` body over its multi-axis +// cross-product and reduce-maxes into `runtime->adstack_max_reducer_outputs[output_slot]` via a grid-strided launch +// with an `atomic_max_i64` reduction. CPU is not supported: the recognizer is skipped on CPU (see +// `codegen/llvm/codegen_llvm.cpp::finalize_offloaded_task_function`) so this entry point is only reached on CUDA / +// AMDGPU. `runtime_eval_adstack_size_expr` walks every alloca's SizeExpr tree and publishes per-stack offsets / +// max_sizes plus the per-thread strides into `LLVMRuntime`. The blob layouts are defined in the +// `quadrants/ir/...adstack...device.h` headers. void runtime_eval_static_bound_count(LLVMRuntime *runtime, RuntimeContext *ctx, Ptr params_blob); -void runtime_eval_adstack_max_reduce_serial(LLVMRuntime *runtime, - RuntimeContext *ctx, - Ptr params_blob, - Ptr body_bytecode); -void runtime_eval_adstack_max_reduce_parallel(LLVMRuntime *runtime, - RuntimeContext *ctx, - Ptr params_blob, - Ptr body_bytecode); +void runtime_eval_adstack_max_reduce(LLVMRuntime *runtime, RuntimeContext *ctx, Ptr params_blob, Ptr body_bytecode); void runtime_eval_adstack_size_expr(LLVMRuntime *runtime, RuntimeContext *ctx, Ptr bytecode); // Publish the device-mapped addresses of the pinned-host overflow flag / task-id slots the host allocated at