diff --git a/src/scheduler/src/collective.rs b/src/scheduler/src/collective.rs index 5212f6e..ad37d47 100644 --- a/src/scheduler/src/collective.rs +++ b/src/scheduler/src/collective.rs @@ -20,7 +20,34 @@ use crate::topology::{MockSail, Topology}; use crate::{Error, Result}; /// Reduction operation for [`AllReduce`]. +/// +/// Marked `#[non_exhaustive]` so future variants +/// (e.g. `Product`, `LogSumExp`, `BitwiseOr`) can be added +/// without a major-version semver bump. Downstream `match` arms +/// must include a `_ =>` catch-all. +/// +/// # Regression guard +/// +/// The following doctest fails to compile *because* `ReduceOp` +/// is `#[non_exhaustive]`: a downstream `match` that names every +/// current variant is rejected without a `_ =>` arm. If someone +/// removes the `#[non_exhaustive]` attribute the doctest will +/// start to compile, the `compile_fail` will fail, and CI will +/// catch the silent semver-evolution regression. +/// +/// ```compile_fail +/// use spanker_scheduler::ReduceOp; +/// fn name(op: ReduceOp) -> &'static str { +/// match op { +/// ReduceOp::Sum => "sum", +/// ReduceOp::Max => "max", +/// ReduceOp::Min => "min", +/// ReduceOp::Avg => "avg", +/// } +/// } +/// ``` #[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[non_exhaustive] pub enum ReduceOp { /// Element-wise sum across cards. Sum, @@ -35,6 +62,17 @@ pub enum ReduceOp { /// Reduce a per-card array across all sails using `op`. After /// the call every per-card buffer contains the same reduced /// result. +/// +/// # Buffer-layout contract +/// +/// Callers MUST provide one contiguous `Vec` per card, +/// indexed in topology order. Implementations are free to +/// stream each card's buffer sequentially before moving to the +/// next, so per-card contiguity is the public layout contract +/// for both the host-side mock and the real-device path. 
Do NOT +/// pass interleaved or strided slices: the real-device DMA path +/// (PR #6b) will assume per-card contiguity to map each `Vec` +/// to a single scatter-gather descriptor. pub trait AllReduce { /// Reduce `per_card[i]` across all `i` using `op`. All /// buffers must have the same length; mismatched shapes @@ -109,42 +147,62 @@ impl AllReduce for Topology { return Ok(()); } - let n = per_card.len() as f32; - let mut reduced = vec![0.0f32; stride]; - for i in 0..stride { - let initial = per_card[0][i]; - let acc = match op { - ReduceOp::Sum | ReduceOp::Avg => { - let mut s = 0.0f32; - for v in per_card.iter() { - s += v[i]; + // Cache-friendly layout: hoist the `match op` outside the + // per-element loop and transpose the loops so each card's + // contiguous Vec is read sequentially (outer = card, + // inner = element). For n_sails=4, stride=1M this turns + // a 4-way Vec stride per element into a single linear + // sweep per card — at least one order-of-magnitude fewer + // L1 misses on the host-side mock and (more importantly) + // honours the per-card-contiguous buffer contract that + // the real-device DMA path (PR #6b) will rely on. + // + // Initialise `reduced` from card 0, then fold cards + // 1..n into it according to a per-op combinator chosen + // once up-front. + let n_cards = per_card.len(); + let mut reduced = per_card[0].clone(); + let rest = per_card.iter().skip(1); + + // Same results as the previous per-element match, except + // Sum/Avg seed from card 0, not +0.0 (all -0.0 stays -0.0); + // Avg divides at the end, Max/Min do an in-place fold. 
+ match op { + ReduceOp::Sum | ReduceOp::Avg => { + for v in rest { + for (r, x) in reduced.iter_mut().zip(v.iter()) { + *r += *x; } - if matches!(op, ReduceOp::Avg) { - s / n - } else { - s + } + if matches!(op, ReduceOp::Avg) { + let n = n_cards as f32; + for r in reduced.iter_mut() { + *r /= n; } } - ReduceOp::Max => { - let mut m = initial; - for v in per_card.iter().skip(1) { - if v[i] > m { - m = v[i]; + } + ReduceOp::Max => { + for v in rest { + for (r, x) in reduced.iter_mut().zip(v.iter()) { + if *x > *r { + *r = *x; } } - m } - ReduceOp::Min => { - let mut m = initial; - for v in per_card.iter().skip(1) { - if v[i] < m { - m = v[i]; + } + ReduceOp::Min => { + for v in rest { + for (r, x) in reduced.iter_mut().zip(v.iter()) { + if *x < *r { + *r = *x; } } - m } - }; - reduced[i] = acc; + } // NB: `ReduceOp` is `#[non_exhaustive]`, but in-crate + // matches are still exhaustive — when adding a new + // variant (Product, LogSumExp, BitwiseOr…), extend + // this `match` instead of falling back to a `_` arm + // (which would silently drop unimplemented ops). } for v in per_card.iter_mut() { @@ -210,6 +268,55 @@ mod tests { )); } + /// Regression sentinel for the `#[non_exhaustive]` attribute + /// on `ReduceOp`. The `compile_fail` doctest on the type is + /// the load-bearing test (it actually runs `rustc` and + /// asserts compilation fails). This unit test is only a + /// human-readable example: it cannot catch removal of the + /// attribute or of the `_ =>` arm, since both still compile + /// in-crate (we simulate "downstream" by including a + /// `_ =>` arm here, which the compiler accepts as redundant + /// in-crate but which is *required* downstream — see + /// `compile_fail` doctest). + #[test] + #[allow(unreachable_patterns)] + fn reduce_op_non_exhaustive_requires_catchall_downstream() { + // In-crate: this match without `_ =>` would compile, + // because `#[non_exhaustive]` only restricts downstream + // crates. 
Including a `_ =>` arm here mirrors what every + // downstream consumer MUST write. + fn classify(op: ReduceOp) -> &'static str { + match op { + ReduceOp::Sum => "sum", + ReduceOp::Max => "max", + ReduceOp::Min => "min", + ReduceOp::Avg => "avg", + _ => "unknown (future variant)", + } + } + assert_eq!(classify(ReduceOp::Sum), "sum"); + assert_eq!(classify(ReduceOp::Avg), "avg"); + } + + /// Same shape as the `ReduceOp` regression sentinel above, + /// but for `crate::Error`. See the `compile_fail` doctest on + /// the `Error` enum for the load-bearing assertion. + #[test] + #[allow(unreachable_patterns)] + fn error_non_exhaustive_requires_catchall_downstream() { + fn label(e: &Error) -> &'static str { + match e { + Error::NoSails => "no sails", + Error::TopologyMismatch { .. } => "topology", + Error::ShapeMismatch { .. } => "shape", + Error::NotImplemented(_) => "not impl", + Error::Runtime(_) => "runtime", + _ => "unknown (future variant)", + } + } + assert_eq!(label(&Error::NoSails), "no sails"); + } + #[test] fn all_reduce_sum_shape_mismatch() { let t = Topology::<MockSail>::with_mock(2); diff --git a/src/scheduler/src/intercard.rs b/src/scheduler/src/intercard.rs index 82f3514..353115b 100644 --- a/src/scheduler/src/intercard.rs +++ b/src/scheduler/src/intercard.rs @@ -25,7 +25,13 @@ pub const INTERCARD_BUS_WIDTH: usize = 128; /// State of a single inter-card link, mirroring `link_state_t` /// in MAST #14. +/// +/// Marked `#[non_exhaustive]` because the inter-card protocol is +/// still TBD per ADR-014; new states (e.g. `Quiesced`, +/// `Recalibrating`) may land without a major-version semver bump. +/// Downstream `match` arms must include a `_ =>` catch-all. #[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[non_exhaustive] pub enum LinkState { /// Link is down; no traffic. 
Down, @@ -42,7 +48,14 @@ pub enum LinkState { /// `local_sail` and `remote_sail` are indices into /// [`crate::Topology::sails`]; the protocol that flows over the /// link is opaque to this crate and lands in ADR-014. +/// +/// Marked `#[non_exhaustive]` because ADR-014 will add fields +/// such as `bandwidth_gbps` and `latency_ns`; downstream crates +/// must construct `Link` via a constructor (e.g. `Link::new`) +/// rather than the struct literal so we can grow the struct +/// without a major-version semver bump. #[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[non_exhaustive] pub struct Link { /// Topology index of the originating sail. pub local_sail: usize, @@ -52,6 +65,23 @@ pub struct Link { pub state: LinkState, } +impl Link { + /// Construct a new `Link`. + /// + /// Required because `Link` is `#[non_exhaustive]`, which prevents + /// downstream crates from constructing it via struct-literal syntax. + /// All future fields added to `Link` should remain optional via + /// further `with_*` builder methods or by extending this constructor + /// signature with a new minor version bump. + pub fn new(local_sail: usize, remote_sail: usize, state: LinkState) -> Self { + Self { + local_sail, + remote_sail, + state, + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/scheduler/src/lib.rs b/src/scheduler/src/lib.rs index 06314ad..f53a01a 100644 --- a/src/scheduler/src/lib.rs +++ b/src/scheduler/src/lib.rs @@ -39,7 +39,35 @@ pub use intercard::{Link, LinkState, INTERCARD_BUS_WIDTH, INTERCARD_LANES, INTER pub use topology::{MockSail, Topology}; /// Errors returned by this crate. +/// +/// Marked `#[non_exhaustive]` because library `Error` enums +/// almost always grow new variants; downstream `match` arms +/// must include a `_ =>` catch-all so we can extend without a +/// major-version semver bump. 
+/// +/// # Regression guard +/// +/// The following doctest fails to compile *because* `Error` is +/// `#[non_exhaustive]`: a downstream exhaustive `match` is +/// rejected without a `_ =>` arm. If someone removes the +/// `#[non_exhaustive]` attribute the doctest will start to +/// compile, the `compile_fail` will fail, and CI will catch the +/// silent semver-evolution regression. +/// +/// ```compile_fail +/// use spanker_scheduler::Error; +/// fn classify(e: Error) -> &'static str { +/// match e { +/// Error::NoSails => "no sails", +/// Error::TopologyMismatch { .. } => "topology", +/// Error::ShapeMismatch { .. } => "shape", +/// Error::NotImplemented(_) => "not impl", +/// Error::Runtime(_) => "runtime", +/// } +/// } +/// ``` #[derive(Debug, thiserror::Error)] +#[non_exhaustive] pub enum Error { /// `Topology::enumerate()` found no `/dev/spanker*` device /// nodes (typical cause: `spanker.ko` is not loaded, or the diff --git a/src/scheduler/src/topology.rs b/src/scheduler/src/topology.rs index 443e7b7..524334f 100644 --- a/src/scheduler/src/topology.rs +++ b/src/scheduler/src/topology.rs @@ -81,6 +81,9 @@ impl Topology { /// cards, all inter-card links in [`LinkState::Up`]. pub fn with_mock(n_sails: usize) -> Self { let sails = (0..n_sails).map(MockSail::new).collect(); + // n*(n-1) directed edges in a fully-meshed graph + // (each of n nodes has a directed link to each of the + // n-1 other nodes). let mut links = Vec::with_capacity(n_sails.saturating_sub(1) * n_sails); for local in 0..n_sails { for remote in 0..n_sails {