Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 47 additions & 90 deletions src/scheduler/src/collective.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,52 +157,48 @@ impl AllReduce for Topology<MockSail> {
// honours the per-card-contiguous buffer contract that
// the real-device DMA path (PR #6b) will rely on.
//
// Initialise `reduced` from card 0, then fold cards
// 1..n into it according to a per-op combinator chosen
// once up-front.
let n_cards = per_card.len();
let mut reduced = per_card[0].clone();
let rest = per_card.iter().skip(1);

// Bit-exact-identical to the previous per-element match:
// Sum and Avg both accumulate into reduced (Avg divides
// at the end), Max/Min do an in-place fold.
match op {
ReduceOp::Sum | ReduceOp::Avg => {
for v in rest {
for (r, x) in reduced.iter_mut().zip(v.iter()) {
*r += *x;
}
// Pick the per-op combinator once up-front as a `fn`
// pointer, then run a single generic outer-card /
// inner-element fold. This both eliminates the previous
// `let rest = ...` fragility (a single `iter().skip(1)`
// was consumed inside one match arm; future variants
// would have surfaced a "value used after move" error)
// AND keeps the dispatch a single decision per call.
// NB: `ReduceOp` is `#[non_exhaustive]`, but in-crate
// matches are still exhaustive — when adding a new
// variant (Product, LogSumExp, BitwiseOr…), extend
// this `match` instead of falling back to a `_` arm
// (which would silently drop unimplemented ops).
let combine: fn(&mut f32, f32) = match op {
ReduceOp::Sum | ReduceOp::Avg => |acc, x| *acc += x,
ReduceOp::Max => |acc, x| {
if x > *acc {
*acc = x;
}
if matches!(op, ReduceOp::Avg) {
let n = n_cards as f32;
for r in reduced.iter_mut() {
*r /= n;
}
},
ReduceOp::Min => |acc, x| {
if x < *acc {
*acc = x;
}
},
};

// Bit-exact-identical to the prior per-op-arm impl:
// initialise `reduced` from card 0, then fold cards
// 1..n into it. Avg divides after the fold so the
// accumulation order matches the previous Sum path.
let n_cards = per_card.len();
let mut reduced = per_card[0].clone();
for v in per_card.iter().skip(1) {
for (r, x) in reduced.iter_mut().zip(v.iter()) {
combine(r, *x);
}
ReduceOp::Max => {
for v in rest {
for (r, x) in reduced.iter_mut().zip(v.iter()) {
if *x > *r {
*r = *x;
}
}
}
}
if matches!(op, ReduceOp::Avg) {
let n = n_cards as f32;
for r in reduced.iter_mut() {
*r /= n;
}
ReduceOp::Min => {
for v in rest {
for (r, x) in reduced.iter_mut().zip(v.iter()) {
if *x < *r {
*r = *x;
}
}
}
} // NB: `ReduceOp` is `#[non_exhaustive]`, but in-crate
// matches are still exhaustive — when adding a new
// variant (Product, LogSumExp, BitwiseOr…), extend
// this `match` instead of falling back to a `_` arm
// (which would silently drop unimplemented ops).
}

for v in per_card.iter_mut() {
Expand Down Expand Up @@ -268,54 +264,15 @@ mod tests {
));
}

/// Human-readable companion to the `compile_fail` doctest on
/// `ReduceOp`.
///
/// The doctest is the load-bearing check: it runs `rustc` and
/// asserts that a downstream-shaped match lacking a `_ =>` arm
/// is rejected. This unit test only *illustrates* the shape
/// every downstream consumer has to write, by matching on all
/// current variants and still carrying a catch-all arm.
///
/// Within this crate `#[non_exhaustive]` places no restriction
/// on matches, so the catch-all below is redundant here and the
/// match would also compile without it.
#[test]
// The `_ =>` arm can never be reached in-crate (every variant is
// listed), so the lint is silenced on purpose.
#[allow(unreachable_patterns)]
fn reduce_op_non_exhaustive_requires_catchall_downstream() {
    // Maps each operation to a short label; the trailing `_` arm
    // stands in for variants that may be added in the future.
    let describe = |which: ReduceOp| -> &'static str {
        match which {
            ReduceOp::Sum => "sum",
            ReduceOp::Avg => "avg",
            ReduceOp::Max => "max",
            ReduceOp::Min => "min",
            _ => "unknown (future variant)",
        }
    };

    let sum_label = describe(ReduceOp::Sum);
    assert_eq!(sum_label, "sum");

    let avg_label = describe(ReduceOp::Avg);
    assert_eq!(avg_label, "avg");
}

/// Human-readable companion to the `compile_fail` doctest on
/// the `Error` enum, which carries the load-bearing assertion.
///
/// This test matches on every current `Error` variant and still
/// includes a `_ =>` arm, mirroring the shape of a match written
/// against a `#[non_exhaustive]` enum from a downstream crate.
#[test]
// Every variant is listed explicitly, so the trailing catch-all is
// unreachable in-crate; the lint is silenced on purpose.
#[allow(unreachable_patterns)]
fn error_non_exhaustive_requires_catchall_downstream() {
    // Short tag per error variant; `_` stands in for variants that
    // may be added in the future.
    let tag = |err: &Error| -> &'static str {
        match err {
            Error::NoSails => "no sails",
            Error::TopologyMismatch { .. } => "topology",
            Error::ShapeMismatch { .. } => "shape",
            Error::NotImplemented(_) => "not impl",
            Error::Runtime(_) => "runtime",
            _ => "unknown (future variant)",
        }
    };

    let no_sails = Error::NoSails;
    assert_eq!(tag(&no_sails), "no sails");
}
// NB: the `*_non_exhaustive_requires_catchall_downstream`
// unit tests previously lived here. They were removed in
// PR #11 because `#[allow(unreachable_patterns)]` made them
// pass even if `#[non_exhaustive]` was removed from the
// enum, so they did not actually guard the attribute. The
// load-bearing guards are the `compile_fail` doctests on
// [`ReduceOp`] and [`Error`] (in `lib.rs`), which run
// `rustc` and assert downstream-shape matches without a
// `_ =>` arm fail to compile.

#[test]
fn all_reduce_sum_shape_mismatch() {
Expand Down
2 changes: 1 addition & 1 deletion src/scheduler/src/topology.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl Topology<MockSail> {
// n*(n-1) directed edges in a fully-meshed graph
// (each of n nodes has a directed link to each of the
// n-1 other nodes).
let mut links = Vec::with_capacity(n_sails.saturating_sub(1) * n_sails);
let mut links = Vec::with_capacity(n_sails * n_sails.saturating_sub(1));
for local in 0..n_sails {
for remote in 0..n_sails {
if local == remote {
Expand Down
Loading