From c9b0394d846320ef4131b66695fb869f53fc6205 Mon Sep 17 00:00:00 2001 From: Marcos Date: Wed, 6 May 2026 02:15:14 -0300 Subject: [PATCH] chore(scheduler): closure-based all_reduce dispatch + sentinel cleanup + comment reorder MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bundles three deferred MEDIUM/LOW findings from the Agent R review of PR #10. None were merge-blocking; this is a focused cleanup. 1. Sentinel unit tests deleted (Option A). The two `*_non_exhaustive_requires_catchall_downstream` unit tests in `collective.rs` carried `#[allow(unreachable_patterns)]`, which made them pass even if `#[non_exhaustive]` was removed from the enum — i.e. they never actually guarded the attribute. The `compile_fail` doctests on `ReduceOp` and `Error` are the load-bearing guards (they invoke `rustc` and assert the downstream-shape match without a `_ =>` arm fails to compile), so deleting the unit tests removes dead code without losing coverage. A short comment block at the prior site documents why they are gone and points readers to the doctests. 2. Closure-based all_reduce dispatch. Replaces the per-arm `match op { Sum|Avg => ..., Max => ..., Min => ... }` — which consumed `let rest = per_card.iter() .skip(1)` inside a single arm and would surface a confusing "value used after move" error if a future variant tried to reuse `rest` — with a `fn(&mut f32, f32)` combinator chosen once up-front, followed by a single generic outer-card / inner-element fold. Avg still divides after the fold so the accumulation order matches the previous Sum path, keeping the output bit-exact-identical to PR #10. Verified by re-running all 9 `tests/topology_mock.rs` cases (Sum/Avg/Max/Min) — all pass with no numerical drift. 3. Topology capacity expression reordered. `Vec::with_capacity(n_sails.saturating_sub(1) * n_sails)` is now `Vec::with_capacity(n_sails * n_sails.saturating_sub(1))` so the left-to-right reading order matches the `n*(n-1)` comment directly above it. Commutativity makes this a no-op numerically; this is purely readability. Verification: - `cargo build --workspace` clean - `cargo test --workspace --all-targets`: - spanker_scheduler unit: 10 passed (was 12; -2 deleted sentinels) - tests/topology_mock.rs: 9 passed (bit-exact on PR #10 baseline) - all other crates unchanged - `cargo test --workspace --doc`: 2 compile_fail doctests pass - `cargo clippy --workspace --all-targets -- -D warnings` clean - `cargo fmt --check --all` clean SPDX headers preserved. Crate version stays 0.1.0. Closes #11. Authored by Agent 3 (Software Stack — Spanker). Signed-off-by: Marcos --- src/scheduler/src/collective.rs | 137 +++++++++++--------------------- src/scheduler/src/topology.rs | 2 +- 2 files changed, 48 insertions(+), 91 deletions(-) diff --git a/src/scheduler/src/collective.rs b/src/scheduler/src/collective.rs index ad37d47..f86ca85 100644 --- a/src/scheduler/src/collective.rs +++ b/src/scheduler/src/collective.rs @@ -157,52 +157,48 @@ impl AllReduce for Topology { // honours the per-card-contiguous buffer contract that // the real-device DMA path (PR #6b) will rely on. // - // Initialise `reduced` from card 0, then fold cards - // 1..n into it according to a per-op combinator chosen - // once up-front. - let n_cards = per_card.len(); - let mut reduced = per_card[0].clone(); - let rest = per_card.iter().skip(1); - - // Bit-exact-identical to the previous per-element match: - // Sum and Avg both accumulate into reduced (Avg divides - // at the end), Max/Min do an in-place fold. - match op { - ReduceOp::Sum | ReduceOp::Avg => { - for v in rest { - for (r, x) in reduced.iter_mut().zip(v.iter()) { - *r += *x; - } + // Pick the per-op combinator once up-front as a `fn` + // pointer, then run a single generic outer-card / + // inner-element fold. This both eliminates the previous + // `let rest = ...` fragility (a single `iter().skip(1)` + // was consumed inside one match arm; future variants + // would have surfaced a "value used after move" error) + // AND keeps the dispatch a single decision per call. + // NB: `ReduceOp` is `#[non_exhaustive]`, but in-crate + // matches are still exhaustive — when adding a new + // variant (Product, LogSumExp, BitwiseOr…), extend + // this `match` instead of falling back to a `_` arm + // (which would silently drop unimplemented ops). + let combine: fn(&mut f32, f32) = match op { + ReduceOp::Sum | ReduceOp::Avg => |acc, x| *acc += x, + ReduceOp::Max => |acc, x| { + if x > *acc { + *acc = x; } - if matches!(op, ReduceOp::Avg) { - let n = n_cards as f32; - for r in reduced.iter_mut() { - *r /= n; - } + }, + ReduceOp::Min => |acc, x| { + if x < *acc { + *acc = x; } + }, + }; + + // Bit-exact-identical to the prior per-op-arm impl: + // initialise `reduced` from card 0, then fold cards + // 1..n into it. Avg divides after the fold so the + // accumulation order matches the previous Sum path. + let n_cards = per_card.len(); + let mut reduced = per_card[0].clone(); + for v in per_card.iter().skip(1) { + for (r, x) in reduced.iter_mut().zip(v.iter()) { + combine(r, *x); } - ReduceOp::Max => { - for v in rest { - for (r, x) in reduced.iter_mut().zip(v.iter()) { - if *x > *r { - *r = *x; - } - } - } + } + if matches!(op, ReduceOp::Avg) { + let n = n_cards as f32; + for r in reduced.iter_mut() { + *r /= n; } - ReduceOp::Min => { - for v in rest { - for (r, x) in reduced.iter_mut().zip(v.iter()) { - if *x < *r { - *r = *x; - } - } - } - } // NB: `ReduceOp` is `#[non_exhaustive]`, but in-crate - // matches are still exhaustive — when adding a new - // variant (Product, LogSumExp, BitwiseOr…), extend - // this `match` instead of falling back to a `_` arm - // (which would silently drop unimplemented ops). } for v in per_card.iter_mut() { @@ -268,54 +264,15 @@ mod tests { )); } - /// Regression sentinel for the `#[non_exhaustive]` attribute - /// on `ReduceOp`. The `compile_fail` doctest on the type is - /// the load-bearing test (it actually runs `rustc` and - /// asserts compilation fails). This unit test is a - /// human-readable sentinel: if it ever fails to *compile*, - /// someone removed the `_ =>` catch-all from a downstream- - /// shaped match (we simulate "downstream" by including a - /// `_ =>` arm here, which the compiler accepts as redundant - /// in-crate but which is *required* downstream — see - /// `compile_fail` doctest). - #[test] - #[allow(unreachable_patterns)] - fn reduce_op_non_exhaustive_requires_catchall_downstream() { - // In-crate: this match without `_ =>` would compile, - // because `#[non_exhaustive]` only restricts downstream - // crates. Including a `_ =>` arm here mirrors what every - // downstream consumer MUST write. - fn classify(op: ReduceOp) -> &'static str { - match op { - ReduceOp::Sum => "sum", - ReduceOp::Max => "max", - ReduceOp::Min => "min", - ReduceOp::Avg => "avg", - _ => "unknown (future variant)", - } - } - assert_eq!(classify(ReduceOp::Sum), "sum"); - assert_eq!(classify(ReduceOp::Avg), "avg"); - } - - /// Same shape as the `ReduceOp` regression sentinel above, - /// but for `crate::Error`. See the `compile_fail` doctest on - /// the `Error` enum for the load-bearing assertion. - #[test] - #[allow(unreachable_patterns)] - fn error_non_exhaustive_requires_catchall_downstream() { - fn label(e: &Error) -> &'static str { - match e { - Error::NoSails => "no sails", - Error::TopologyMismatch { .. } => "topology", - Error::ShapeMismatch { .. } => "shape", - Error::NotImplemented(_) => "not impl", - Error::Runtime(_) => "runtime", - _ => "unknown (future variant)", - } - } - assert_eq!(label(&Error::NoSails), "no sails"); - } + // NB: the `*_non_exhaustive_requires_catchall_downstream` + // unit tests previously lived here. They were removed in + // PR #11 because `#[allow(unreachable_patterns)]` made them + // pass even if `#[non_exhaustive]` was removed from the + // enum, so they did not actually guard the attribute. The + // load-bearing guards are the `compile_fail` doctests on + // [`ReduceOp`] and [`Error`] (in `lib.rs`), which run + // `rustc` and assert downstream-shape matches without a + // `_ =>` arm fail to compile. #[test] fn all_reduce_sum_shape_mismatch() { diff --git a/src/scheduler/src/topology.rs b/src/scheduler/src/topology.rs index 524334f..c19a10f 100644 --- a/src/scheduler/src/topology.rs +++ b/src/scheduler/src/topology.rs @@ -84,7 +84,7 @@ impl Topology { // n*(n-1) directed edges in a fully-meshed graph // (each of n nodes has a directed link to each of the // n-1 other nodes). - let mut links = Vec::with_capacity(n_sails.saturating_sub(1) * n_sails); + let mut links = Vec::with_capacity(n_sails * n_sails.saturating_sub(1)); for local in 0..n_sails { for remote in 0..n_sails { if local == remote {