Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion miles/backends/megatron_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,13 @@ def init(
else:
for m in all_replay_managers:
m.enabled = getattr(self.args, f"use_{m.name}_replay")
m.enable_check_replay_result = m.enabled and self.args.ci_test
# Let --enable-r3-correctness-check turn on the replay-check
# without dragging in the other --ci-test invariants (which
# include a strict log_probs vs ref_log_probs equality that
# trips on routine floating-point precision differences).
m.enable_check_replay_result = m.enabled and (
self.args.ci_test or getattr(self.args, "enable_r3_correctness_check", False)
)

(self.model, self.optimizer, self.opt_param_scheduler, loaded_rollout_id) = initialize_model_and_optimizer(
args, role
Expand Down
44 changes: 22 additions & 22 deletions miles/backends/training_utils/log_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,33 +22,33 @@
# Maps bare metric names to their W&B top-level section(s).
# Keys appearing in multiple sections (e.g. pg_loss) are emitted under each.
_TRAIN_METRIC_GROUPS: dict[str, list[str]] = {
"ppo_kl": ["policy_shift"],
"ois": ["policy_shift"],
"pg_clipfrac": ["policy_shift"],
"pg_loss": ["policy_shift", "optimization"],
"log_probs": ["policy_shift"], # current policy (training forward pass)
"old_log_probs": ["policy_shift"], # old policy (rollout or FSDP rollout)
"ref_kl": ["policy_shift"],
"ppo_kl": ["policy_shift"],
"ois": ["policy_shift"],
"pg_clipfrac": ["policy_shift"],
"pg_loss": ["policy_shift", "optimization"],
"log_probs": ["policy_shift"], # current policy (training forward pass)
"old_log_probs": ["policy_shift"], # old policy (rollout or FSDP rollout)
"ref_kl": ["policy_shift"],
"train_rollout_logprob_abs_diff": ["train_inference_mismatch"],
"train_rollout_logprob_diff": ["train_inference_mismatch"],
"tis": ["train_inference_mismatch"],
"tis_abs": ["train_inference_mismatch"],
"tis_clipfrac": ["train_inference_mismatch"],
"loss": ["optimization"],
"entropy_loss": ["optimization"],
"kl_loss": ["optimization"],
"grad_norm": ["optimization"],
"train_rollout_logprob_diff": ["train_inference_mismatch"],
"tis": ["train_inference_mismatch"],
"tis_abs": ["train_inference_mismatch"],
"tis_clipfrac": ["train_inference_mismatch"],
"loss": ["optimization"],
"entropy_loss": ["optimization"],
"kl_loss": ["optimization"],
"grad_norm": ["optimization"],
}

# Maps rollout batch field names to their W&B top-level section.
_ROLLOUT_DATA_METRIC_GROUPS: dict[str, str] = {
"log_probs": "train_inference_mismatch", # FSDP log probs at rollout time
"log_probs": "train_inference_mismatch", # FSDP log probs at rollout time
"rollout_log_probs": "train_inference_mismatch", # inference engine log probs
"ref_log_probs": "policy_shift", # reference model log probs
"rewards": "reward",
"raw_reward": "reward",
"advantages": "reward",
"returns": "reward",
"ref_log_probs": "policy_shift", # reference model log probs
"rewards": "reward",
"raw_reward": "reward",
"advantages": "reward",
"returns": "reward",
}


Expand Down Expand Up @@ -533,7 +533,7 @@ def log_train_step(
for full_key, val in log_dict_out.items():
if not full_key.startswith(prefix):
continue
bare_key = full_key[len(prefix):]
bare_key = full_key[len(prefix) :]
if bare_key in _TRAIN_METRIC_GROUPS:
for group in _TRAIN_METRIC_GROUPS[bare_key]:
grouped_additions[f"{group}/{bare_key}"] = val
Expand Down
4 changes: 3 additions & 1 deletion miles/backends/training_utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,9 @@ def policy_loss_function(
if "rollout_log_probs" in batch and batch["rollout_log_probs"]:
rollout_log_probs_cat = torch.cat(batch["rollout_log_probs"], dim=0)
log_probs_batch_cat = torch.cat(batch["log_probs"], dim=0)
train_rollout_logprob_abs_diff = sum_of_sample_mean((old_log_probs - rollout_log_probs_cat).abs()).clone().detach()
train_rollout_logprob_abs_diff = (
sum_of_sample_mean((old_log_probs - rollout_log_probs_cat).abs()).clone().detach()
)
# signed: log π(inf) − log π(fsdp rollout)
train_rollout_logprob_diff = sum_of_sample_mean(rollout_log_probs_cat - log_probs_batch_cat).clone().detach()

Expand Down
4 changes: 1 addition & 3 deletions miles/ray/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -1400,9 +1400,7 @@ def _compute_grouped_response_metrics(args, group: list[Sample], prefix: str) ->
}


def _compute_group_outcome_metrics(
args, all_samples: list[Sample], prefix: str = "reward"
) -> dict:
def _compute_group_outcome_metrics(args, all_samples: list[Sample], prefix: str = "reward") -> dict:
"""Fraction of prompt groups that are unanimously correct or incorrect. GRPO only."""
if args.advantage_estimator == "ppo":
return {}
Expand Down
2 changes: 1 addition & 1 deletion miles/rollout/generate_utils/openai_endpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,5 +247,5 @@ def _truncate_sample_output(sample: Sample, keep_tokens: int, tokenizer) -> None
if sample.loss_mask is not None:
sample.loss_mask = sample.loss_mask[:keep_tokens]
if sample.rollout_routed_experts is not None:
sample.rollout_routed_experts = sample.rollout_routed_experts[:len(sample.tokens) - 1]
sample.rollout_routed_experts = sample.rollout_routed_experts[: len(sample.tokens) - 1]
sample.status = Sample.Status.TRUNCATED
5 changes: 1 addition & 4 deletions miles/rollout/session/linear_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,7 @@ def _evict_stale_sessions(self) -> None:
if not self._session_last_access:
return
now = time.monotonic()
stale = [
sid for sid, ts in self._session_last_access.items()
if now - ts > self._SESSION_TTL_SECS
]
stale = [sid for sid, ts in self._session_last_access.items() if now - ts > self._SESSION_TTL_SECS]
for sid in stale:
self.sessions.pop(sid, None)
self._session_last_access.pop(sid, None)
Expand Down
13 changes: 13 additions & 0 deletions miles/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,19 @@ def add_algo_arguments(parser):
default=False,
help="The rollout routing replay technique from https://arxiv.org/abs/2510.11370",
)
parser.add_argument(
"--enable-r3-correctness-check",
action="store_true",
default=False,
help=(
"Run RoutingReplayManager's per-step overlap check that "
"recomputes the training-side topk on the same scores and "
"asserts overlap with the rollout indices. Roughly 2x routing "
"cost; off by default. Intended for the R3 regression E2E "
"(LLM360/RL360 scripts/r3-e2e/). No effect unless "
"--use-rollout-routing-replay is also set."
),
)
parser.add_argument(
"--use-opsm",
action="store_true",
Expand Down
39 changes: 34 additions & 5 deletions miles/utils/replay_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,7 @@ def _get_replay_result(top_indices, scores, topk, *args, **kwargs):
_, sorted_free = masked_scores.sort(dim=1, descending=True)
# The k-th -1 slot in each row gets sorted_free[row, k].
pad_cumsum = torch.cumsum(padding_mask.long(), dim=1) - 1
fill_values = torch.gather(sorted_free, 1, pad_cumsum.clamp(min=0)).to(
top_indices.dtype
)
fill_values = torch.gather(sorted_free, 1, pad_cumsum.clamp(min=0)).to(top_indices.dtype)
top_indices = torch.where(padding_mask, fill_values, top_indices)

if return_probs:
Expand Down Expand Up @@ -153,10 +151,32 @@ def new_topk_fn(scores, topk, *args, **kwargs):
return result

elif stage == "replay_forward":
return _get_replay_result(replay.pop_forward(), scores, topk, *args, **kwargs)
replay_idx = replay.pop_forward()
if manager.enable_check_replay_result:
# Direct evidence the wrapper's replay_forward branch ran in
# the megatron MoE forward path (vs falling through to
# old_topk_fn). Gated on the correctness-check toggle so
# production training stays quiet. replay_idx_sum is a cheap
# fingerprint of the indices for cross-rank / cross-step
# comparison.
logger.info(
f"R3 wrapper: replay_forward branch taken "
f"(rank {_get_rank()}, n_tokens={replay_idx.shape[0]}, "
f"topk={replay_idx.shape[1]}, "
f"replay_idx_sum={int(replay_idx.sum().item())})"
)
return _get_replay_result(replay_idx, scores, topk, *args, **kwargs)

elif stage == "replay_backward":
return _get_replay_result(replay.pop_backward(), scores, topk, *args, **kwargs)
replay_idx = replay.pop_backward()
if manager.enable_check_replay_result:
logger.info(
f"R3 wrapper: replay_backward branch taken "
f"(rank {_get_rank()}, n_tokens={replay_idx.shape[0]}, "
f"topk={replay_idx.shape[1]}, "
f"replay_idx_sum={int(replay_idx.sum().item())})"
)
return _get_replay_result(replay_idx, scores, topk, *args, **kwargs)

else:
return old_topk_fn(scores, topk, *args, **kwargs)
Expand Down Expand Up @@ -198,6 +218,15 @@ def check_replay_result(self, old_topk_fn, scores, topk, top_indices, *args, **k
is_mismatch = ~has_overlap & ~is_padding

mismatch_count = is_mismatch.sum().item()
n_tokens = orig_flat.shape[0]
# Unconditional log so we have direct evidence the check actually
# ran (its silent return on mismatch_count==0 is otherwise
# indistinguishable from never being called).
logger.info(
f"R3 check (rank {_get_rank()}, stage {self.stage}): "
f"n_tokens={n_tokens} mismatch={mismatch_count} "
f"({100 * mismatch_count / max(n_tokens, 1):.2f}%)"
)
if mismatch_count == 0:
return

Expand Down
4 changes: 4 additions & 0 deletions train_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,8 @@ async def train(args):

if __name__ == "__main__":
    # Script entry point: parse CLI arguments, optionally switch on the R3
    # replay correctness check, then drive the async training loop.
    args = parse_args()

    # getattr with a False fallback keeps this working even if the parsed
    # namespace does not carry the flag.
    r3_check_requested = getattr(args, "enable_r3_correctness_check", False)
    if r3_check_requested:
        # Import is deferred to this branch so it only happens when requested.
        from miles.utils.replay_base import RoutingReplayManager

        # Sets the attribute on the class itself (not on an instance).
        # NOTE(review): presumably meant as a process-wide default for
        # managers that never get a per-instance value -- confirm.
        RoutingReplayManager.enable_check_replay_result = True

    asyncio.run(train(args))
Loading