Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 134 additions & 27 deletions src/scheduler/src/collective.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,34 @@ use crate::topology::{MockSail, Topology};
use crate::{Error, Result};

/// Reduction operation for [`AllReduce`].
///
/// Marked `#[non_exhaustive]` so future variants
/// (e.g. `Product`, `LogSumExp`, `BitwiseOr`) can be added
/// without a major-version semver bump. Downstream `match` arms
/// must include a `_ =>` catch-all.
///
/// # Regression guard
///
/// The following doctest fails to compile *because* `ReduceOp`
/// is `#[non_exhaustive]`: a downstream `match` that names every
/// current variant is rejected without a `_ =>` arm. If someone
/// removes the `#[non_exhaustive]` attribute the doctest will
/// start to compile, the `compile_fail` will fail, and CI will
/// catch the silent semver-evolution regression.
///
/// ```compile_fail
/// use spanker_scheduler::ReduceOp;
/// fn name(op: ReduceOp) -> &'static str {
/// match op {
/// ReduceOp::Sum => "sum",
/// ReduceOp::Max => "max",
/// ReduceOp::Min => "min",
/// ReduceOp::Avg => "avg",
/// }
/// }
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ReduceOp {
/// Element-wise sum across cards.
Sum,
Expand All @@ -35,6 +62,17 @@ pub enum ReduceOp {
/// Reduce a per-card array across all sails using `op`. After
/// the call every per-card buffer contains the same reduced
/// result.
///
/// # Buffer-layout contract
///
/// Callers MUST provide one contiguous `Vec<f32>` per card,
/// indexed in topology order. Implementations are free to
/// stream each card's buffer sequentially before moving to the
/// next, so per-card contiguity is the public layout contract
/// for both the host-side mock and the real-device path. Do NOT
/// pass interleaved or strided slices: the real-device DMA path
/// (PR #6b) will assume per-card contiguity to map each `Vec`
/// to a single scatter-gather descriptor.
pub trait AllReduce {
/// Reduce `per_card[i]` across all `i` using `op`. All
/// buffers must have the same length; mismatched shapes
Expand Down Expand Up @@ -109,42 +147,62 @@ impl AllReduce for Topology<MockSail> {
return Ok(());
}

let n = per_card.len() as f32;
let mut reduced = vec![0.0f32; stride];
for i in 0..stride {
let initial = per_card[0][i];
let acc = match op {
ReduceOp::Sum | ReduceOp::Avg => {
let mut s = 0.0f32;
for v in per_card.iter() {
s += v[i];
// Cache-friendly layout: hoist the `match op` outside the
// per-element loop and transpose the loops so each card's
// contiguous Vec is read sequentially (outer = card,
// inner = element). For n_sails=4, stride=1M this turns
// a 4-way Vec stride per element into a single linear
// sweep per card — at least one order-of-magnitude fewer
// L1 misses on the host-side mock and (more importantly)
// honours the per-card-contiguous buffer contract that
// the real-device DMA path (PR #6b) will rely on.
//
// Initialise `reduced` from card 0, then fold cards
// 1..n into it according to a per-op combinator chosen
// once up-front.
let n_cards = per_card.len();
let mut reduced = per_card[0].clone();
let rest = per_card.iter().skip(1);

// Bit-exact-identical to the previous per-element match:
// Sum and Avg both accumulate into reduced (Avg divides
// at the end), Max/Min do an in-place fold.
match op {
ReduceOp::Sum | ReduceOp::Avg => {
for v in rest {
for (r, x) in reduced.iter_mut().zip(v.iter()) {
*r += *x;
}
if matches!(op, ReduceOp::Avg) {
s / n
} else {
s
}
if matches!(op, ReduceOp::Avg) {
let n = n_cards as f32;
for r in reduced.iter_mut() {
*r /= n;
}
}
ReduceOp::Max => {
let mut m = initial;
for v in per_card.iter().skip(1) {
if v[i] > m {
m = v[i];
}
ReduceOp::Max => {
for v in rest {
for (r, x) in reduced.iter_mut().zip(v.iter()) {
if *x > *r {
*r = *x;
}
}
m
}
ReduceOp::Min => {
let mut m = initial;
for v in per_card.iter().skip(1) {
if v[i] < m {
m = v[i];
}
ReduceOp::Min => {
for v in rest {
for (r, x) in reduced.iter_mut().zip(v.iter()) {
if *x < *r {
*r = *x;
}
}
m
}
};
reduced[i] = acc;
} // NB: `ReduceOp` is `#[non_exhaustive]`, but in-crate
// matches are still exhaustive — when adding a new
// variant (Product, LogSumExp, BitwiseOr…), extend
// this `match` instead of falling back to a `_` arm
// (which would silently drop unimplemented ops).
}

for v in per_card.iter_mut() {
Expand Down Expand Up @@ -210,6 +268,55 @@ mod tests {
));
}

/// Regression sentinel for the `#[non_exhaustive]` attribute
/// on `ReduceOp`. The `compile_fail` doctest on the type is
/// the load-bearing test (it actually runs `rustc` and
/// asserts compilation fails). This unit test is a
/// human-readable sentinel: if it ever fails to *compile*,
/// someone removed the `_ =>` catch-all from a downstream-
/// shaped match (we simulate "downstream" by including a
/// `_ =>` arm here, which the compiler accepts as redundant
/// in-crate but which is *required* downstream — see
/// `compile_fail` doctest).
#[test]
#[allow(unreachable_patterns)]
fn reduce_op_non_exhaustive_requires_catchall_downstream() {
// In-crate: this match without `_ =>` would compile,
// because `#[non_exhaustive]` only restricts downstream
// crates. Including a `_ =>` arm here mirrors what every
// downstream consumer MUST write.
fn classify(op: ReduceOp) -> &'static str {
match op {
ReduceOp::Sum => "sum",
ReduceOp::Max => "max",
ReduceOp::Min => "min",
ReduceOp::Avg => "avg",
_ => "unknown (future variant)",
}
}
assert_eq!(classify(ReduceOp::Sum), "sum");
assert_eq!(classify(ReduceOp::Avg), "avg");
}

/// Same shape as the `ReduceOp` regression sentinel above,
/// but for `crate::Error`. See the `compile_fail` doctest on
/// the `Error` enum for the load-bearing assertion.
#[test]
#[allow(unreachable_patterns)]
fn error_non_exhaustive_requires_catchall_downstream() {
fn label(e: &Error) -> &'static str {
match e {
Error::NoSails => "no sails",
Error::TopologyMismatch { .. } => "topology",
Error::ShapeMismatch { .. } => "shape",
Error::NotImplemented(_) => "not impl",
Error::Runtime(_) => "runtime",
_ => "unknown (future variant)",
}
}
assert_eq!(label(&Error::NoSails), "no sails");
}

#[test]
fn all_reduce_sum_shape_mismatch() {
let t = Topology::<MockSail>::with_mock(2);
Expand Down
30 changes: 30 additions & 0 deletions src/scheduler/src/intercard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@ pub const INTERCARD_BUS_WIDTH: usize = 128;

/// State of a single inter-card link, mirroring `link_state_t`
/// in MAST #14.
///
/// Marked `#[non_exhaustive]` because the inter-card protocol is
/// still TBD per ADR-014; new states (e.g. `Quiesced`,
/// `Recalibrating`) may land without a major-version semver bump.
/// Downstream `match` arms must include a `_ =>` catch-all.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum LinkState {
/// Link is down; no traffic.
Down,
Expand All @@ -42,7 +48,14 @@ pub enum LinkState {
/// `local_sail` and `remote_sail` are indices into
/// [`crate::Topology::sails`]; the protocol that flows over the
/// link is opaque to this crate and lands in ADR-014.
///
/// Marked `#[non_exhaustive]` because ADR-014 will add fields
/// such as `bandwidth_gbps` and `latency_ns`; downstream crates
/// must construct `Link` via a constructor (e.g. `Link::new`)
/// rather than the struct literal so we can grow the struct
/// without a major-version semver bump.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub struct Link {
/// Topology index of the originating sail.
pub local_sail: usize,
Expand All @@ -52,6 +65,23 @@ pub struct Link {
pub state: LinkState,
}

impl Link {
/// Construct a new `Link`.
///
/// Required because `Link` is `#[non_exhaustive]`, which prevents
/// downstream crates from constructing it via struct-literal syntax.
/// All future fields added to `Link` should remain optional via
/// further `with_*` builder methods or by extending this constructor
/// signature with a new minor version bump.
pub fn new(local_sail: usize, remote_sail: usize, state: LinkState) -> Self {
Self {
local_sail,
remote_sail,
state,
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
28 changes: 28 additions & 0 deletions src/scheduler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,35 @@ pub use intercard::{Link, LinkState, INTERCARD_BUS_WIDTH, INTERCARD_LANES, INTER
pub use topology::{MockSail, Topology};

/// Errors returned by this crate.
///
/// Marked `#[non_exhaustive]` because library `Error` enums
/// almost always grow new variants; downstream `match` arms
/// must include a `_ =>` catch-all so we can extend without a
/// major-version semver bump.
///
/// # Regression guard
///
/// The following doctest fails to compile *because* `Error` is
/// `#[non_exhaustive]`: a downstream exhaustive `match` is
/// rejected without a `_ =>` arm. If someone removes the
/// `#[non_exhaustive]` attribute the doctest will start to
/// compile, the `compile_fail` will fail, and CI will catch the
/// silent semver-evolution regression.
///
/// ```compile_fail
/// use spanker_scheduler::Error;
/// fn classify(e: Error) -> &'static str {
/// match e {
/// Error::NoSails => "no sails",
/// Error::TopologyMismatch { .. } => "topology",
/// Error::ShapeMismatch { .. } => "shape",
/// Error::NotImplemented(_) => "not impl",
/// Error::Runtime(_) => "runtime",
/// }
/// }
/// ```
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
/// `Topology::enumerate()` found no `/dev/spanker*` device
/// nodes (typical cause: `spanker.ko` is not loaded, or the
Expand Down
3 changes: 3 additions & 0 deletions src/scheduler/src/topology.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ impl Topology<MockSail> {
/// cards, all inter-card links in [`LinkState::Up`].
pub fn with_mock(n_sails: usize) -> Self {
let sails = (0..n_sails).map(MockSail::new).collect();
// n*(n-1) directed edges in a fully-meshed graph
// (each of n nodes has a directed link to each of the
// n-1 other nodes).
let mut links = Vec::with_capacity(n_sails.saturating_sub(1) * n_sails);
for local in 0..n_sails {
for remote in 0..n_sails {
Expand Down
Loading