From 0431dbf50caff838e9d82a0d973d87fe87c996f6 Mon Sep 17 00:00:00 2001 From: David <12414531+DavidBellamy@users.noreply.github.com> Date: Mon, 18 May 2026 15:35:54 -0700 Subject: [PATCH 1/5] arguments: add --enable-r3-correctness-check CLI flag When set, flips RoutingReplayManager.enable_check_replay_result = True so the per-step overlap check (replay_base.py:178-219) fires for every training step. Off by default because the check roughly doubles the cost of routing. Intended for the R3 regression E2E on LLM360/RL360, which runs a small GPU sbatch on M2 every time a submodule-pin bump PR opens. With this flag, miles will raise AssertionError("R3 mismatch tokens ...") if the overlap drops below MILES_TEST_R3_THRESHOLD (default 1e-2), giving the E2E a hard pass/fail signal. The R3 master switch (--use-rollout-routing-replay) is still required; this flag has no effect without it. --- miles/utils/arguments.py | 13 +++++++++++++ train_async.py | 3 +++ 2 files changed, 16 insertions(+) diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 717c2b6980..b6d516716c 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -1024,6 +1024,19 @@ def add_algo_arguments(parser): default=False, help="The rollout routing replay technique from https://arxiv.org/abs/2510.11370", ) + parser.add_argument( + "--enable-r3-correctness-check", + action="store_true", + default=False, + help=( + "Run RoutingReplayManager's per-step overlap check that " + "recomputes the training-side topk on the same scores and " + "asserts overlap with the rollout indices. Roughly 2x routing " + "cost; off by default. Intended for the R3 regression E2E " + "(LLM360/RL360 scripts/r3-e2e/). No effect unless " + "--use-rollout-routing-replay is also set." + ), + ) parser.add_argument( "--use-opsm", action="store_true", diff --git a/train_async.py b/train_async.py index e9e05a4062..dd1e8f6a03 100644 --- a/train_async.py +++ b/train_async.py @@ -75,4 +75,7 @@ async def train(args): if __name__ == "__main__": args = parse_args() + if getattr(args, "enable_r3_correctness_check", False): + from miles.utils.replay_base import RoutingReplayManager + RoutingReplayManager.enable_check_replay_result = True asyncio.run(train(args)) From db437d22658c17745e43b789c0f4746e1eab44f3 Mon Sep 17 00:00:00 2001 From: David <12414531+DavidBellamy@users.noreply.github.com> Date: Mon, 18 May 2026 15:54:47 -0700 Subject: [PATCH 2/5] prod: apply black drift cleanup Six files on the prod base had black-non-compliant formatting that pre-commit on PR #25 flagged as failures. Applying `black==24.3.0` (matches .pre-commit-config.yaml) brings them in line so CI passes. Also fixes the single line in train_async.py from this PR that black wants (blank line after the import). No behavioral changes; pure whitespace + line breaks. --- miles/backends/training_utils/log_utils.py | 44 +++++++++---------- miles/backends/training_utils/loss.py | 4 +- miles/ray/rollout.py | 4 +- .../generate_utils/openai_endpoint_utils.py | 2 +- miles/rollout/session/linear_trajectory.py | 5 +-- miles/utils/replay_base.py | 4 +- train_async.py | 1 + 7 files changed, 30 insertions(+), 34 deletions(-) diff --git a/miles/backends/training_utils/log_utils.py b/miles/backends/training_utils/log_utils.py index 4096a4f8ac..a53f2afd54 100644 --- a/miles/backends/training_utils/log_utils.py +++ b/miles/backends/training_utils/log_utils.py @@ -22,33 +22,33 @@ # Maps bare metric names to their W&B top-level section(s). # Keys appearing in multiple sections (e.g. pg_loss) are emitted under each. _TRAIN_METRIC_GROUPS: dict[str, list[str]] = { - "ppo_kl": ["policy_shift"], - "ois": ["policy_shift"], - "pg_clipfrac": ["policy_shift"], - "pg_loss": ["policy_shift", "optimization"], - "log_probs": ["policy_shift"], # current policy (training forward pass) - "old_log_probs": ["policy_shift"], # old policy (rollout or FSDP rollout) - "ref_kl": ["policy_shift"], + "ppo_kl": ["policy_shift"], + "ois": ["policy_shift"], + "pg_clipfrac": ["policy_shift"], + "pg_loss": ["policy_shift", "optimization"], + "log_probs": ["policy_shift"], # current policy (training forward pass) + "old_log_probs": ["policy_shift"], # old policy (rollout or FSDP rollout) + "ref_kl": ["policy_shift"], "train_rollout_logprob_abs_diff": ["train_inference_mismatch"], - "train_rollout_logprob_diff": ["train_inference_mismatch"], - "tis": ["train_inference_mismatch"], - "tis_abs": ["train_inference_mismatch"], - "tis_clipfrac": ["train_inference_mismatch"], - "loss": ["optimization"], - "entropy_loss": ["optimization"], - "kl_loss": ["optimization"], - "grad_norm": ["optimization"], + "train_rollout_logprob_diff": ["train_inference_mismatch"], + "tis": ["train_inference_mismatch"], + "tis_abs": ["train_inference_mismatch"], + "tis_clipfrac": ["train_inference_mismatch"], + "loss": ["optimization"], + "entropy_loss": ["optimization"], + "kl_loss": ["optimization"], + "grad_norm": ["optimization"], } # Maps rollout batch field names to their W&B top-level section. _ROLLOUT_DATA_METRIC_GROUPS: dict[str, str] = { - "log_probs": "train_inference_mismatch", # FSDP log probs at rollout time + "log_probs": "train_inference_mismatch", # FSDP log probs at rollout time "rollout_log_probs": "train_inference_mismatch", # inference engine log probs - "ref_log_probs": "policy_shift", # reference model log probs - "rewards": "reward", - "raw_reward": "reward", - "advantages": "reward", - "returns": "reward", + "ref_log_probs": "policy_shift", # reference model log probs + "rewards": "reward", + "raw_reward": "reward", + "advantages": "reward", + "returns": "reward", } @@ -533,7 +533,7 @@ def log_train_step( for full_key, val in log_dict_out.items(): if not full_key.startswith(prefix): continue - bare_key = full_key[len(prefix):] + bare_key = full_key[len(prefix) :] if bare_key in _TRAIN_METRIC_GROUPS: for group in _TRAIN_METRIC_GROUPS[bare_key]: grouped_additions[f"{group}/{bare_key}"] = val diff --git a/miles/backends/training_utils/loss.py b/miles/backends/training_utils/loss.py index e0eccde8b4..a62e4f2678 100644 --- a/miles/backends/training_utils/loss.py +++ b/miles/backends/training_utils/loss.py @@ -693,7 +693,9 @@ def policy_loss_function( if "rollout_log_probs" in batch and batch["rollout_log_probs"]: rollout_log_probs_cat = torch.cat(batch["rollout_log_probs"], dim=0) log_probs_batch_cat = torch.cat(batch["log_probs"], dim=0) - train_rollout_logprob_abs_diff = sum_of_sample_mean((old_log_probs - rollout_log_probs_cat).abs()).clone().detach() + train_rollout_logprob_abs_diff = ( + sum_of_sample_mean((old_log_probs - rollout_log_probs_cat).abs()).clone().detach() + ) # signed: log π(inf) − log π(fsdp rollout) train_rollout_logprob_diff = sum_of_sample_mean(rollout_log_probs_cat - log_probs_batch_cat).clone().detach() diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 17905500a3..4a781c7491 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -1400,9 +1400,7 @@ def _compute_grouped_response_metrics(args, group: list[Sample], prefix: str) -> } -def _compute_group_outcome_metrics( - args, all_samples: list[Sample], prefix: str = "reward" -) -> dict: +def _compute_group_outcome_metrics(args, all_samples: list[Sample], prefix: str = "reward") -> dict: """Fraction of prompt groups that are unanimously correct or incorrect. GRPO only.""" if args.advantage_estimator == "ppo": return {} diff --git a/miles/rollout/generate_utils/openai_endpoint_utils.py b/miles/rollout/generate_utils/openai_endpoint_utils.py index 6c328719a1..7ba101ac7c 100644 --- a/miles/rollout/generate_utils/openai_endpoint_utils.py +++ b/miles/rollout/generate_utils/openai_endpoint_utils.py @@ -247,5 +247,5 @@ def _truncate_sample_output(sample: Sample, keep_tokens: int, tokenizer) -> None if sample.loss_mask is not None: sample.loss_mask = sample.loss_mask[:keep_tokens] if sample.rollout_routed_experts is not None: - sample.rollout_routed_experts = sample.rollout_routed_experts[:len(sample.tokens) - 1] + sample.rollout_routed_experts = sample.rollout_routed_experts[: len(sample.tokens) - 1] sample.status = Sample.Status.TRUNCATED diff --git a/miles/rollout/session/linear_trajectory.py b/miles/rollout/session/linear_trajectory.py index 31acbeec3c..fb34b6a0c7 100644 --- a/miles/rollout/session/linear_trajectory.py +++ b/miles/rollout/session/linear_trajectory.py @@ -340,10 +340,7 @@ def _evict_stale_sessions(self) -> None: if not self._session_last_access: return now = time.monotonic() - stale = [ - sid for sid, ts in self._session_last_access.items() - if now - ts > self._SESSION_TTL_SECS - ] + stale = [sid for sid, ts in self._session_last_access.items() if now - ts > self._SESSION_TTL_SECS] for sid in stale: self.sessions.pop(sid, None) self._session_last_access.pop(sid, None) diff --git a/miles/utils/replay_base.py b/miles/utils/replay_base.py index 8e19b1ba67..e4a003a730 100644 --- a/miles/utils/replay_base.py +++ b/miles/utils/replay_base.py @@ -123,9 +123,7 @@ def _get_replay_result(top_indices, scores, topk, *args, **kwargs): _, sorted_free = masked_scores.sort(dim=1, descending=True) # The k-th -1 slot in each row gets sorted_free[row, k]. pad_cumsum = torch.cumsum(padding_mask.long(), dim=1) - 1 - fill_values = torch.gather(sorted_free, 1, pad_cumsum.clamp(min=0)).to( - top_indices.dtype - ) + fill_values = torch.gather(sorted_free, 1, pad_cumsum.clamp(min=0)).to(top_indices.dtype) top_indices = torch.where(padding_mask, fill_values, top_indices) if return_probs: diff --git a/train_async.py b/train_async.py index dd1e8f6a03..78a5d0ff32 100644 --- a/train_async.py +++ b/train_async.py @@ -77,5 +77,6 @@ async def train(args): args = parse_args() if getattr(args, "enable_r3_correctness_check", False): from miles.utils.replay_base import RoutingReplayManager + RoutingReplayManager.enable_check_replay_result = True asyncio.run(train(args)) From 0854adccff947f9b920275ea5c84f824fbc7fd8d Mon Sep 17 00:00:00 2001 From: David <12414531+DavidBellamy@users.noreply.github.com> Date: Mon, 18 May 2026 22:48:32 -0700 Subject: [PATCH 3/5] replay_base: direct-evidence logs for R3 wrapper + overlap check The previous --enable-r3-correctness-check flag turned on the overlap check but produced no log output unless an actual mismatch happened, making it impossible to distinguish "check passed" from "check never ran." Add two unconditional logs gated on enable_check_replay_result: 1. get_topk_fn / new_topk_fn replay_forward + replay_backward branches: log when the wrapper actually returns replay indices rather than falling through to old_topk_fn. Direct evidence megatron's MoE forward used the rollout indices (vs recomputing them). 2. check_replay_result: log n_tokens and mismatch_count on every call, including the mismatch_count==0 case (which previously returned silently). Direct evidence the check ran, plus the actual overlap number for cross-step / cross-rank comparison. Both logs gated on enable_check_replay_result so production training runs (which leave it False) stay quiet. Adds no overhead when off. Intended to make the LLM360/RL360 R3 regression E2E able to assert directly that R3 worked end-to-end, rather than inferring from absence of failure messages. --- miles/utils/replay_base.py | 35 +++++++++++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/miles/utils/replay_base.py b/miles/utils/replay_base.py index e4a003a730..d74350c7ff 100644 --- a/miles/utils/replay_base.py +++ b/miles/utils/replay_base.py @@ -151,10 +151,32 @@ def new_topk_fn(scores, topk, *args, **kwargs): return result elif stage == "replay_forward": - return _get_replay_result(replay.pop_forward(), scores, topk, *args, **kwargs) + replay_idx = replay.pop_forward() + if manager.enable_check_replay_result: + # Direct evidence the wrapper's replay_forward branch ran in + # the megatron MoE forward path (vs falling through to + # old_topk_fn). Gated on the correctness-check toggle so + # production training stays quiet. replay_idx_sum is a cheap + # fingerprint of the indices for cross-rank / cross-step + # comparison. + logger.info( + f"R3 wrapper: replay_forward branch taken " + f"(rank {_get_rank()}, n_tokens={replay_idx.shape[0]}, " + f"topk={replay_idx.shape[1]}, " + f"replay_idx_sum={int(replay_idx.sum().item())})" + ) + return _get_replay_result(replay_idx, scores, topk, *args, **kwargs) elif stage == "replay_backward": - return _get_replay_result(replay.pop_backward(), scores, topk, *args, **kwargs) + replay_idx = replay.pop_backward() + if manager.enable_check_replay_result: + logger.info( + f"R3 wrapper: replay_backward branch taken " + f"(rank {_get_rank()}, n_tokens={replay_idx.shape[0]}, " + f"topk={replay_idx.shape[1]}, " + f"replay_idx_sum={int(replay_idx.sum().item())})" + ) + return _get_replay_result(replay_idx, scores, topk, *args, **kwargs) else: return old_topk_fn(scores, topk, *args, **kwargs) @@ -196,6 +218,15 @@ def check_replay_result(self, old_topk_fn, scores, topk, top_indices, *args, **k is_mismatch = ~has_overlap & ~is_padding mismatch_count = is_mismatch.sum().item() + n_tokens = orig_flat.shape[0] + # Unconditional log so we have direct evidence the check actually + # ran (its silent return on mismatch_count==0 is otherwise + # indistinguishable from never being called). + logger.info( + f"R3 check (rank {_get_rank()}, stage {self.stage}): " + f"n_tokens={n_tokens} mismatch={mismatch_count} " + f"({100 * mismatch_count / max(n_tokens, 1):.2f}%)" + ) if mismatch_count == 0: return From f019a625a7cb920185e6321ea6f8b60e4ec0ea84 Mon Sep 17 00:00:00 2001 From: David <12414531+DavidBellamy@users.noreply.github.com> Date: Tue, 19 May 2026 09:08:04 -0700 Subject: [PATCH 4/5] actor: make --enable-r3-correctness-check independent of --ci-test actor.py:111-112 unconditionally set m.enable_check_replay_result = m.enabled and self.args.ci_test which overrode the value we set in train_async.py from --enable-r3-correctness-check. The flag was effectively a no-op. This change keeps backward-compat for --ci-test and ALSO honors --enable-r3-correctness-check on its own, so callers can enable the R3 overlap check without enabling the rest of --ci-test's invariants. In particular --ci-test also enables a strict log_probs == ref_log_probs equality check that trips on routine floating-point precision differences (~1e-3 gap), so R3 callers need a way to opt into ONLY the replay check. Found during R3 E2E pre-merge validation: with --ci-test on, the R3 overlap check fired cleanly (1976+ checks, all mismatch=0%) but the job then failed at the unrelated log_probs assertion before the backward pass. With --enable-r3-correctness-check now wired through properly, the same run reaches backward and can show replay_backward branch evidence too. --- miles/backends/megatron_utils/actor.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py index 2cfd373f49..d307ced112 100644 --- a/miles/backends/megatron_utils/actor.py +++ b/miles/backends/megatron_utils/actor.py @@ -109,7 +109,14 @@ def init( else: for m in all_replay_managers: m.enabled = getattr(self.args, f"use_{m.name}_replay") - m.enable_check_replay_result = m.enabled and self.args.ci_test + # Let --enable-r3-correctness-check turn on the replay-check + # without dragging in the other --ci-test invariants (which + # include a strict log_probs vs ref_log_probs equality that + # trips on routine floating-point precision differences). + m.enable_check_replay_result = m.enabled and ( + self.args.ci_test + or getattr(self.args, "enable_r3_correctness_check", False) + ) (self.model, self.optimizer, self.opt_param_scheduler, loaded_rollout_id) = initialize_model_and_optimizer( args, role From e06a6b3f71ae5c9860aa92cba34f4e41c2d0ccfc Mon Sep 17 00:00:00 2001 From: David <12414531+DavidBellamy@users.noreply.github.com> Date: Tue, 19 May 2026 14:29:30 -0700 Subject: [PATCH 5/5] actor: apply black to the new condition (CI fix, no logic change) --- miles/backends/megatron_utils/actor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py index d307ced112..4776abe91c 100644 --- a/miles/backends/megatron_utils/actor.py +++ b/miles/backends/megatron_utils/actor.py @@ -114,8 +114,7 @@ def init( # include a strict log_probs vs ref_log_probs equality that # trips on routine floating-point precision differences). m.enable_check_replay_result = m.enabled and ( - self.args.ci_test - or getattr(self.args, "enable_r3_correctness_check", False) + self.args.ci_test or getattr(self.args, "enable_r3_correctness_check", False) ) (self.model, self.optimizer, self.opt_param_scheduler, loaded_rollout_id) = initialize_model_and_optimizer(