diff --git a/src/scheduler/src/collective.rs b/src/scheduler/src/collective.rs index ad37d47..f86ca85 100644 --- a/src/scheduler/src/collective.rs +++ b/src/scheduler/src/collective.rs @@ -157,52 +157,48 @@ impl AllReduce for Topology { // honours the per-card-contiguous buffer contract that // the real-device DMA path (PR #6b) will rely on. // - // Initialise `reduced` from card 0, then fold cards - // 1..n into it according to a per-op combinator chosen - // once up-front. - let n_cards = per_card.len(); - let mut reduced = per_card[0].clone(); - let rest = per_card.iter().skip(1); - - // Bit-exact-identical to the previous per-element match: - // Sum and Avg both accumulate into reduced (Avg divides - // at the end), Max/Min do an in-place fold. - match op { - ReduceOp::Sum | ReduceOp::Avg => { - for v in rest { - for (r, x) in reduced.iter_mut().zip(v.iter()) { - *r += *x; - } + // Pick the per-op combinator once up-front as a `fn` + // pointer, then run a single generic outer-card / + // inner-element fold. This both eliminates the previous + // `let rest = ...` fragility (a single `iter().skip(1)` + // was consumed inside one match arm; future variants + // would have surfaced a "value used after move" error) + // AND keeps the dispatch a single decision per call. + // NB: `ReduceOp` is `#[non_exhaustive]`, but in-crate + // matches are still exhaustive — when adding a new + // variant (Product, LogSumExp, BitwiseOr…), extend + // this `match` instead of falling back to a `_` arm + // (which would silently drop unimplemented ops). + let combine: fn(&mut f32, f32) = match op { + ReduceOp::Sum | ReduceOp::Avg => |acc, x| *acc += x, + ReduceOp::Max => |acc, x| { + if x > *acc { + *acc = x; } - if matches!(op, ReduceOp::Avg) { - let n = n_cards as f32; - for r in reduced.iter_mut() { - *r /= n; - } + }, + ReduceOp::Min => |acc, x| { + if x < *acc { + *acc = x; } + }, + }; + + // Bit-exact-identical to the prior per-op-arm impl: + // initialise `reduced` from card 0, then fold cards + // 1..n into it. Avg divides after the fold so the + // accumulation order matches the previous Sum path. + let n_cards = per_card.len(); + let mut reduced = per_card[0].clone(); + for v in per_card.iter().skip(1) { + for (r, x) in reduced.iter_mut().zip(v.iter()) { + combine(r, *x); } - ReduceOp::Max => { - for v in rest { - for (r, x) in reduced.iter_mut().zip(v.iter()) { - if *x > *r { - *r = *x; - } - } - } + } + if matches!(op, ReduceOp::Avg) { + let n = n_cards as f32; + for r in reduced.iter_mut() { + *r /= n; } - ReduceOp::Min => { - for v in rest { - for (r, x) in reduced.iter_mut().zip(v.iter()) { - if *x < *r { - *r = *x; - } - } - } - } // NB: `ReduceOp` is `#[non_exhaustive]`, but in-crate - // matches are still exhaustive — when adding a new - // variant (Product, LogSumExp, BitwiseOr…), extend - // this `match` instead of falling back to a `_` arm - // (which would silently drop unimplemented ops). } for v in per_card.iter_mut() { @@ -268,54 +264,15 @@ mod tests { )); } - /// Regression sentinel for the `#[non_exhaustive]` attribute - /// on `ReduceOp`. The `compile_fail` doctest on the type is - /// the load-bearing test (it actually runs `rustc` and - /// asserts compilation fails). This unit test is a - /// human-readable sentinel: if it ever fails to *compile*, - /// someone removed the `_ =>` catch-all from a downstream- - /// shaped match (we simulate "downstream" by including a - /// `_ =>` arm here, which the compiler accepts as redundant - /// in-crate but which is *required* downstream — see - /// `compile_fail` doctest). - #[test] - #[allow(unreachable_patterns)] - fn reduce_op_non_exhaustive_requires_catchall_downstream() { - // In-crate: this match without `_ =>` would compile, - // because `#[non_exhaustive]` only restricts downstream - // crates. Including a `_ =>` arm here mirrors what every - // downstream consumer MUST write. - fn classify(op: ReduceOp) -> &'static str { - match op { - ReduceOp::Sum => "sum", - ReduceOp::Max => "max", - ReduceOp::Min => "min", - ReduceOp::Avg => "avg", - _ => "unknown (future variant)", - } - } - assert_eq!(classify(ReduceOp::Sum), "sum"); - assert_eq!(classify(ReduceOp::Avg), "avg"); - } - - /// Same shape as the `ReduceOp` regression sentinel above, - /// but for `crate::Error`. See the `compile_fail` doctest on - /// the `Error` enum for the load-bearing assertion. - #[test] - #[allow(unreachable_patterns)] - fn error_non_exhaustive_requires_catchall_downstream() { - fn label(e: &Error) -> &'static str { - match e { - Error::NoSails => "no sails", - Error::TopologyMismatch { .. } => "topology", - Error::ShapeMismatch { .. } => "shape", - Error::NotImplemented(_) => "not impl", - Error::Runtime(_) => "runtime", - _ => "unknown (future variant)", - } - } - assert_eq!(label(&Error::NoSails), "no sails"); - } + // NB: the `*_non_exhaustive_requires_catchall_downstream` + // unit tests previously lived here. They were removed in + // PR #11 because `#[allow(unreachable_patterns)]` made them + // pass even if `#[non_exhaustive]` was removed from the + // enum, so they did not actually guard the attribute. The + // load-bearing guards are the `compile_fail` doctests on + // [`ReduceOp`] and [`Error`] (in `lib.rs`), which run + // `rustc` and assert downstream-shape matches without a + // `_ =>` arm fail to compile. #[test] fn all_reduce_sum_shape_mismatch() { diff --git a/src/scheduler/src/topology.rs b/src/scheduler/src/topology.rs index 524334f..c19a10f 100644 --- a/src/scheduler/src/topology.rs +++ b/src/scheduler/src/topology.rs @@ -84,7 +84,7 @@ impl Topology { // n*(n-1) directed edges in a fully-meshed graph // (each of n nodes has a directed link to each of the // n-1 other nodes). - let mut links = Vec::with_capacity(n_sails.saturating_sub(1) * n_sails); + let mut links = Vec::with_capacity(n_sails * n_sails.saturating_sub(1)); for local in 0..n_sails { for remote in 0..n_sails { if local == remote {