diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py index 2cfd373f49..4776abe91c 100644 --- a/miles/backends/megatron_utils/actor.py +++ b/miles/backends/megatron_utils/actor.py @@ -109,7 +109,13 @@ def init( else: for m in all_replay_managers: m.enabled = getattr(self.args, f"use_{m.name}_replay") - m.enable_check_replay_result = m.enabled and self.args.ci_test + # Let --enable-r3-correctness-check turn on the replay-check + # without dragging in the other --ci-test invariants (which + # include a strict log_probs vs ref_log_probs equality that + # trips on routine floating-point precision differences). + m.enable_check_replay_result = m.enabled and ( + self.args.ci_test or getattr(self.args, "enable_r3_correctness_check", False) + ) (self.model, self.optimizer, self.opt_param_scheduler, loaded_rollout_id) = initialize_model_and_optimizer( args, role diff --git a/miles/backends/training_utils/log_utils.py b/miles/backends/training_utils/log_utils.py index 4096a4f8ac..a53f2afd54 100644 --- a/miles/backends/training_utils/log_utils.py +++ b/miles/backends/training_utils/log_utils.py @@ -22,33 +22,33 @@ # Maps bare metric names to their W&B top-level section(s). # Keys appearing in multiple sections (e.g. pg_loss) are emitted under each. _TRAIN_METRIC_GROUPS: dict[str, list[str]] = { - "ppo_kl": ["policy_shift"], - "ois": ["policy_shift"], - "pg_clipfrac": ["policy_shift"], - "pg_loss": ["policy_shift", "optimization"], - "log_probs": ["policy_shift"], # current policy (training forward pass) - "old_log_probs": ["policy_shift"], # old policy (rollout or FSDP rollout) - "ref_kl": ["policy_shift"], + "ppo_kl": ["policy_shift"], + "ois": ["policy_shift"], + "pg_clipfrac": ["policy_shift"], + "pg_loss": ["policy_shift", "optimization"], + "log_probs": ["policy_shift"], # current policy (training forward pass) + "old_log_probs": ["policy_shift"], # old policy (rollout or FSDP rollout) + "ref_kl": ["policy_shift"], "train_rollout_logprob_abs_diff": ["train_inference_mismatch"], - "train_rollout_logprob_diff": ["train_inference_mismatch"], - "tis": ["train_inference_mismatch"], - "tis_abs": ["train_inference_mismatch"], - "tis_clipfrac": ["train_inference_mismatch"], - "loss": ["optimization"], - "entropy_loss": ["optimization"], - "kl_loss": ["optimization"], - "grad_norm": ["optimization"], + "train_rollout_logprob_diff": ["train_inference_mismatch"], + "tis": ["train_inference_mismatch"], + "tis_abs": ["train_inference_mismatch"], + "tis_clipfrac": ["train_inference_mismatch"], + "loss": ["optimization"], + "entropy_loss": ["optimization"], + "kl_loss": ["optimization"], + "grad_norm": ["optimization"], } # Maps rollout batch field names to their W&B top-level section. _ROLLOUT_DATA_METRIC_GROUPS: dict[str, str] = { - "log_probs": "train_inference_mismatch", # FSDP log probs at rollout time + "log_probs": "train_inference_mismatch", # FSDP log probs at rollout time "rollout_log_probs": "train_inference_mismatch", # inference engine log probs - "ref_log_probs": "policy_shift", # reference model log probs - "rewards": "reward", - "raw_reward": "reward", - "advantages": "reward", - "returns": "reward", + "ref_log_probs": "policy_shift", # reference model log probs + "rewards": "reward", + "raw_reward": "reward", + "advantages": "reward", + "returns": "reward", } @@ -533,7 +533,7 @@ def log_train_step( for full_key, val in log_dict_out.items(): if not full_key.startswith(prefix): continue - bare_key = full_key[len(prefix):] + bare_key = full_key[len(prefix) :] if bare_key in _TRAIN_METRIC_GROUPS: for group in _TRAIN_METRIC_GROUPS[bare_key]: grouped_additions[f"{group}/{bare_key}"] = val diff --git a/miles/backends/training_utils/loss.py b/miles/backends/training_utils/loss.py index e0eccde8b4..a62e4f2678 100644 --- a/miles/backends/training_utils/loss.py +++ b/miles/backends/training_utils/loss.py @@ -693,7 +693,9 @@ def policy_loss_function( if "rollout_log_probs" in batch and batch["rollout_log_probs"]: rollout_log_probs_cat = torch.cat(batch["rollout_log_probs"], dim=0) log_probs_batch_cat = torch.cat(batch["log_probs"], dim=0) - train_rollout_logprob_abs_diff = sum_of_sample_mean((old_log_probs - rollout_log_probs_cat).abs()).clone().detach() + train_rollout_logprob_abs_diff = ( + sum_of_sample_mean((old_log_probs - rollout_log_probs_cat).abs()).clone().detach() + ) # signed: log π(inf) − log π(fsdp rollout) train_rollout_logprob_diff = sum_of_sample_mean(rollout_log_probs_cat - log_probs_batch_cat).clone().detach() diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 17905500a3..4a781c7491 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -1400,9 +1400,7 @@ def _compute_grouped_response_metrics(args, group: list[Sample], prefix: str) -> } -def _compute_group_outcome_metrics( - args, all_samples: list[Sample], prefix: str = "reward" -) -> dict: +def _compute_group_outcome_metrics(args, all_samples: list[Sample], prefix: str = "reward") -> dict: """Fraction of prompt groups that are unanimously correct or incorrect. GRPO only.""" if args.advantage_estimator == "ppo": return {} diff --git a/miles/rollout/generate_utils/openai_endpoint_utils.py b/miles/rollout/generate_utils/openai_endpoint_utils.py index 6c328719a1..7ba101ac7c 100644 --- a/miles/rollout/generate_utils/openai_endpoint_utils.py +++ b/miles/rollout/generate_utils/openai_endpoint_utils.py @@ -247,5 +247,5 @@ def _truncate_sample_output(sample: Sample, keep_tokens: int, tokenizer) -> None if sample.loss_mask is not None: sample.loss_mask = sample.loss_mask[:keep_tokens] if sample.rollout_routed_experts is not None: - sample.rollout_routed_experts = sample.rollout_routed_experts[:len(sample.tokens) - 1] + sample.rollout_routed_experts = sample.rollout_routed_experts[: len(sample.tokens) - 1] sample.status = Sample.Status.TRUNCATED diff --git a/miles/rollout/session/linear_trajectory.py b/miles/rollout/session/linear_trajectory.py index 31acbeec3c..fb34b6a0c7 100644 --- a/miles/rollout/session/linear_trajectory.py +++ b/miles/rollout/session/linear_trajectory.py @@ -340,10 +340,7 @@ def _evict_stale_sessions(self) -> None: if not self._session_last_access: return now = time.monotonic() - stale = [ - sid for sid, ts in self._session_last_access.items() - if now - ts > self._SESSION_TTL_SECS - ] + stale = [sid for sid, ts in self._session_last_access.items() if now - ts > self._SESSION_TTL_SECS] for sid in stale: self.sessions.pop(sid, None) self._session_last_access.pop(sid, None) diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 717c2b6980..b6d516716c 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -1024,6 +1024,19 @@ def add_algo_arguments(parser): default=False, help="The rollout routing replay technique from https://arxiv.org/abs/2510.11370", ) + parser.add_argument( + "--enable-r3-correctness-check", + action="store_true", + default=False, + help=( + "Run RoutingReplayManager's per-step overlap check that " + "recomputes the training-side topk on the same scores and " + "asserts overlap with the rollout indices. Roughly 2x routing " + "cost; off by default. Intended for the R3 regression E2E " + "(LLM360/RL360 scripts/r3-e2e/). No effect unless " + "--use-rollout-routing-replay is also set." + ), + ) parser.add_argument( "--use-opsm", action="store_true", diff --git a/miles/utils/replay_base.py b/miles/utils/replay_base.py index 8e19b1ba67..d74350c7ff 100644 --- a/miles/utils/replay_base.py +++ b/miles/utils/replay_base.py @@ -123,9 +123,7 @@ def _get_replay_result(top_indices, scores, topk, *args, **kwargs): _, sorted_free = masked_scores.sort(dim=1, descending=True) # The k-th -1 slot in each row gets sorted_free[row, k]. pad_cumsum = torch.cumsum(padding_mask.long(), dim=1) - 1 - fill_values = torch.gather(sorted_free, 1, pad_cumsum.clamp(min=0)).to( - top_indices.dtype - ) + fill_values = torch.gather(sorted_free, 1, pad_cumsum.clamp(min=0)).to(top_indices.dtype) top_indices = torch.where(padding_mask, fill_values, top_indices) if return_probs: @@ -153,10 +151,32 @@ def new_topk_fn(scores, topk, *args, **kwargs): return result elif stage == "replay_forward": - return _get_replay_result(replay.pop_forward(), scores, topk, *args, **kwargs) + replay_idx = replay.pop_forward() + if manager.enable_check_replay_result: + # Direct evidence the wrapper's replay_forward branch ran in + # the megatron MoE forward path (vs falling through to + # old_topk_fn). Gated on the correctness-check toggle so + # production training stays quiet. replay_idx_sum is a cheap + # fingerprint of the indices for cross-rank / cross-step + # comparison. + logger.info( + f"R3 wrapper: replay_forward branch taken " + f"(rank {_get_rank()}, n_tokens={replay_idx.shape[0]}, " + f"topk={replay_idx.shape[1]}, " + f"replay_idx_sum={int(replay_idx.sum().item())})" + ) + return _get_replay_result(replay_idx, scores, topk, *args, **kwargs) elif stage == "replay_backward": - return _get_replay_result(replay.pop_backward(), scores, topk, *args, **kwargs) + replay_idx = replay.pop_backward() + if manager.enable_check_replay_result: + logger.info( + f"R3 wrapper: replay_backward branch taken " + f"(rank {_get_rank()}, n_tokens={replay_idx.shape[0]}, " + f"topk={replay_idx.shape[1]}, " + f"replay_idx_sum={int(replay_idx.sum().item())})" + ) + return _get_replay_result(replay_idx, scores, topk, *args, **kwargs) else: return old_topk_fn(scores, topk, *args, **kwargs) @@ -198,6 +218,15 @@ def check_replay_result(self, old_topk_fn, scores, topk, top_indices, *args, **k is_mismatch = ~has_overlap & ~is_padding mismatch_count = is_mismatch.sum().item() + n_tokens = orig_flat.shape[0] + # Unconditional log so we have direct evidence the check actually + # ran (its silent return on mismatch_count==0 is otherwise + # indistinguishable from never being called). + logger.info( + f"R3 check (rank {_get_rank()}, stage {self.stage}): " + f"n_tokens={n_tokens} mismatch={mismatch_count} " + f"({100 * mismatch_count / max(n_tokens, 1):.2f}%)" + ) if mismatch_count == 0: return diff --git a/train_async.py b/train_async.py index e9e05a4062..78a5d0ff32 100644 --- a/train_async.py +++ b/train_async.py @@ -75,4 +75,8 @@ async def train(args): if __name__ == "__main__": args = parse_args() + if getattr(args, "enable_r3_correctness_check", False): + from miles.utils.replay_base import RoutingReplayManager + + RoutingReplayManager.enable_check_replay_result = True asyncio.run(train(args))