From 0bd87fc7b187940bcf28913dd739ec6a15d1d252 Mon Sep 17 00:00:00 2001 From: Matthew Yang Date: Wed, 6 May 2026 22:55:35 +0000 Subject: [PATCH 1/8] metrics changes --- miles/backends/training_utils/data.py | 12 + miles/backends/training_utils/log_utils.py | 16 +- miles/backends/training_utils/loss.py | 55 +++- miles/ray/rollout.py | 60 ++++- miles/utils/metric_utils.py | 5 + miles/utils/wandb_utils.py | 5 + .../training_utils/test_metric_domains.py | 245 ++++++++++++++++++ 7 files changed, 379 insertions(+), 19 deletions(-) create mode 100644 tests/fast/backends/training_utils/test_metric_domains.py diff --git a/miles/backends/training_utils/data.py b/miles/backends/training_utils/data.py index 3bda887b26..2733c93442 100644 --- a/miles/backends/training_utils/data.py +++ b/miles/backends/training_utils/data.py @@ -115,11 +115,23 @@ def get_batch( parallel_state = get_parallel_state() assert "tokens" in keys + has_domains = "domains" in data_iterator.rollout_data + if has_domains and "domains" not in keys: + keys = list(keys) + ["domains"] batch = data_iterator.get_next(keys) if "dynamic_global_batch_size" in data_iterator.rollout_data: batch["dynamic_global_batch_size"] = data_iterator.rollout_data["dynamic_global_batch_size"] + # Canonical domain set, cached on the iterator so every microbatch emits the + # same list (aggregate_train_losses keys positionally on the first microbatch). + if has_domains: + if not hasattr(data_iterator, "_all_domains_cache"): + data_iterator._all_domains_cache = sorted( + {d for d in data_iterator.rollout_data["domains"] if d} + ) + batch["all_domains"] = data_iterator._all_domains_cache + tokens = batch["tokens"] # use 0 as the pad token id should be fine? pad_token_id = 0 diff --git a/miles/backends/training_utils/log_utils.py b/miles/backends/training_utils/log_utils.py index 4096a4f8ac..9bada19ddf 100644 --- a/miles/backends/training_utils/log_utils.py +++ b/miles/backends/training_utils/log_utils.py @@ -9,7 +9,7 @@ from miles.utils import train_metric_utils from miles.utils.flops_utils import calculate_fwd_flops -from miles.utils.metric_utils import compute_pass_rate, compute_rollout_step +from miles.utils.metric_utils import compute_pass_rate, compute_rollout_step, compute_samples_seen from miles.utils.types import RolloutBatch from ...utils import tracking_utils @@ -170,6 +170,7 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc "dynamic_global_batch_size", "weight_versions", "metadata", + "domains", ]: continue # Upload per sample mean for each rollout value @@ -526,6 +527,9 @@ def log_train_step( log_dict_out["train/rollout_id"] = rollout_id log_dict_out["train/step_in_rollout"] = step_id log_dict_out["rollout/step"] = compute_rollout_step(args, rollout_id) + log_dict_out["samples_seen"] = compute_samples_seen(args, rollout_id) + log_dict_out["train_step"] = accumulated_step_id + log_dict_out["rollout_step"] = log_dict_out["rollout/step"] # Emit top-level grouped copies for W&B panel organization (existing train/ keys unchanged) grouped_additions = {} @@ -534,9 +538,13 @@ def log_train_step( if not full_key.startswith(prefix): continue bare_key = full_key[len(prefix):] - if bare_key in _TRAIN_METRIC_GROUPS: - for group in _TRAIN_METRIC_GROUPS[bare_key]: - grouped_additions[f"{group}/{bare_key}"] = val + # Per-domain keys arrive as "/" — route to "//". + metric_name, sep, domain = bare_key.rpartition("/") + lookup = metric_name if (sep and metric_name in _TRAIN_METRIC_GROUPS) else bare_key + if lookup in _TRAIN_METRIC_GROUPS: + suffix = f"{domain}/{metric_name}" if lookup == metric_name else bare_key + for group in _TRAIN_METRIC_GROUPS[lookup]: + grouped_additions[f"{group}/{suffix}"] = val elif bare_key.startswith("lr-pg_"): grouped_additions[f"optimization/{bare_key}"] = val log_dict_out.update(grouped_additions) diff --git a/miles/backends/training_utils/loss.py b/miles/backends/training_utils/loss.py index bd0386e9f4..317286f6e0 100644 --- a/miles/backends/training_utils/loss.py +++ b/miles/backends/training_utils/loss.py @@ -653,6 +653,11 @@ def policy_loss_function( else: pg_loss_reducer = sum_of_sample_mean + # Saved for per-domain fan-out (reducers below overwrite these names with scalars). + _pg_loss_per_token = pg_loss + _pg_clipfrac_per_token = pg_clipfrac + _ppo_kl_per_token = ppo_kl + pg_loss = pg_loss_reducer(pg_loss) pg_clipfrac = sum_of_sample_mean(pg_clipfrac) ppo_kl = sum_of_sample_mean(ppo_kl) @@ -691,18 +696,24 @@ def policy_loss_function( # Train-inference mismatch: compare inference engine vs FSDP at rollout time train_rollout_logprob_abs_diff = None train_rollout_logprob_diff = None + _train_rollout_logprob_abs_per_token = None + _train_rollout_logprob_signed_per_token = None if "rollout_log_probs" in batch and batch["rollout_log_probs"]: rollout_log_probs_cat = torch.cat(batch["rollout_log_probs"], dim=0) log_probs_batch_cat = torch.cat(batch["log_probs"], dim=0) - train_rollout_logprob_abs_diff = sum_of_sample_mean((old_log_probs - rollout_log_probs_cat).abs()).clone().detach() + _train_rollout_logprob_abs_per_token = (old_log_probs - rollout_log_probs_cat).abs() # signed: log π(inf) − log π(fsdp rollout) - train_rollout_logprob_diff = sum_of_sample_mean(rollout_log_probs_cat - log_probs_batch_cat).clone().detach() + _train_rollout_logprob_signed_per_token = rollout_log_probs_cat - log_probs_batch_cat + train_rollout_logprob_abs_diff = sum_of_sample_mean(_train_rollout_logprob_abs_per_token).clone().detach() + train_rollout_logprob_diff = sum_of_sample_mean(_train_rollout_logprob_signed_per_token).clone().detach() # KL vs reference model — always log when ref present, regardless of use_kl_loss ref_kl_metric = None + _ref_kl_per_token = None if "ref_log_probs" in batch and batch["ref_log_probs"]: ref_log_probs_cat = torch.cat(batch["ref_log_probs"], dim=0) - ref_kl_metric = sum_of_sample_mean(log_probs - ref_log_probs_cat).clone().detach() + _ref_kl_per_token = log_probs - ref_log_probs_cat + ref_kl_metric = sum_of_sample_mean(_ref_kl_per_token).clone().detach() reported_loss = { "loss": loss.clone().detach(), @@ -736,6 +747,44 @@ def policy_loss_function( if args.use_opsm: reported_loss["opsm_clipfrac"] = opsm_clipfrac + # Per-domain fan-out: activated by batch["domains"] (set when samples carry + # metadata["domain"]). batch["all_domains"] is cached on DataIterator so every + # microbatch emits the same key set (aggregate_train_losses keys positionally). + # grad_norm isn't split: backward() has already mixed gradients. + if batch.get("domains") and batch.get("all_domains"): + per_token = { + "log_probs": log_probs, + "old_log_probs": old_log_probs, + "pg_loss": _pg_loss_per_token, + "pg_clipfrac": _pg_clipfrac_per_token, + "ppo_kl": _ppo_kl_per_token, + "entropy_loss": entropy, + } + if _ref_kl_per_token is not None: + per_token["ref_kl"] = _ref_kl_per_token + if _train_rollout_logprob_signed_per_token is not None: + per_token["train_rollout_logprob_diff"] = _train_rollout_logprob_signed_per_token + per_token["train_rollout_logprob_abs_diff"] = _train_rollout_logprob_abs_per_token + if args.get_mismatch_metrics or args.use_tis: + per_token["ois"] = ois + per_token.update(tis_metrics) + + for d in batch["all_domains"]: + masked = [ + lm if dd == d else torch.zeros_like(lm) + for dd, lm in zip(batch["domains"], batch["loss_masks"], strict=False) + ] + reducer = get_sum_of_sample_mean( + total_lengths, response_lengths, masked, + args.calculate_per_token_loss, args.qkv_format, max_seq_lens, + loss_agg_mode=getattr(args, "loss_agg_mode", None), + ) + for name, t in per_token.items(): + reported_loss[f"{name}/{d}"] = reducer(t).clone().detach() + reported_loss[f"loss/{d}"] = ( + reported_loss[f"pg_loss/{d}"] - args.entropy_coef * reported_loss[f"entropy_loss/{d}"] + ) + return loss, reported_loss diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index ebe44f8fb9..e6e83e2bd3 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -38,7 +38,13 @@ from miles.utils.iter_utils import group_by from miles.utils.logging_utils import configure_logger from miles.utils.metric_checker import MetricChecker -from miles.utils.metric_utils import compute_pass_rate, compute_rollout_step, compute_statistics, dict_add_prefix +from miles.utils.metric_utils import ( + compute_pass_rate, + compute_rollout_step, + compute_samples_seen, + compute_statistics, + dict_add_prefix, +) from miles.utils.misc import load_function from miles.utils.ray_utils import Box from miles.utils.seqlen_balancing import get_seqlen_balanced_partitions @@ -752,6 +758,11 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl if samples[0].train_metadata is not None: train_data["metadata"] = [sample.train_metadata for sample in samples] + # Presence of metadata["domain"] activates per-domain metric fan-out in policy_loss_function. + domains = [s.metadata.get("domain") for s in samples] + if any(domains): + train_data["domains"] = domains + if any(sample.multimodal_train_inputs is not None for sample in samples): train_data["multimodal_train_inputs"] = [sample.multimodal_train_inputs for sample in samples] @@ -869,6 +880,7 @@ def _stat(xs): "prompt", "teacher_log_probs", "weight_versions", + "domains", ]: if key not in data: continue @@ -1371,10 +1383,16 @@ def _log_rollout_data(rollout_id, args, samples, rollout_extra_metrics, rollout_ log_dict = {**(rollout_extra_metrics or {})} log_dict |= dict_add_prefix(compute_metrics_from_samples(args, samples), "rollout/") log_dict |= dict_add_prefix(compute_perf_metrics_from_samples(args, samples, rollout_time), "perf/") + # Mirror reward/* and response_stats/* as top-level wandb panels. + for full_key, val in list(log_dict.items()): + if full_key.startswith(("rollout/reward/", "rollout/response_stats/")): + log_dict[full_key[len("rollout/"):]] = val logger.info(f"perf {rollout_id}: {log_dict}") step = compute_rollout_step(args, rollout_id) log_dict["rollout/step"] = step log_dict["train/rollout_id"] = rollout_id + log_dict["samples_seen"] = compute_samples_seen(args, rollout_id) + log_dict["rollout_step"] = step tracking_utils.log(args, log_dict, step_key="rollout/step") @@ -1422,8 +1440,9 @@ def compute_metrics_from_samples(args, samples): log_dict |= _compute_group_outcome_metrics(args, samples, prefix="reward") # per-correctness (no count_frac: for binary rewards = mean reward = already in reward/raw_reward) - correct = [s for s in samples if s.get_reward_value(args) > 0] - incorrect = [s for s in samples if s.get_reward_value(args) <= 0] + correct = [s for s in samples if _correctness(s, args)] + incorrect = [s for s in samples if not _correctness(s, args)] + log_dict["reward/correctness"] = len(correct) / n for label, grp in [("correct", correct), ("incorrect", incorrect)]: if grp: log_dict |= _compute_grouped_reward_metrics(args, grp, f"reward/{label}", n, include_count_frac=False) @@ -1438,10 +1457,10 @@ def compute_metrics_from_samples(args, samples): log_dict |= _compute_grouped_reward_metrics(args, cat_grp, f"reward/{cat}", n) log_dict |= _compute_grouped_response_metrics(args, cat_grp, f"response_stats/{cat}") log_dict |= _compute_group_outcome_metrics(args, cat_grp, prefix=f"reward/{cat}") - for label, grp in [ - ("correct", [s for s in cat_grp if s.get_reward_value(args) > 0]), - ("incorrect", [s for s in cat_grp if s.get_reward_value(args) <= 0]), - ]: + cat_correct = [s for s in cat_grp if _correctness(s, args)] + cat_incorrect = [s for s in cat_grp if not _correctness(s, args)] + log_dict[f"reward/{cat}/correctness"] = len(cat_correct) / len(cat_grp) + for label, grp in [("correct", cat_correct), ("incorrect", cat_incorrect)]: if grp: log_dict |= _compute_grouped_reward_metrics(args, grp, f"reward/{cat}/{label}", n) log_dict |= _compute_grouped_response_metrics(args, grp, f"response_stats/{cat}/{label}") @@ -1486,6 +1505,9 @@ def _compute_zero_std_metrics(args, all_samples: list[Sample]): # only compute in GRPO-like algorithms where one prompt has multiple responses if args.advantage_estimator == "ppo": return {} + # Skip non-scalar rewards (round() and zero-std comparison are ill-defined on dicts). + if all_samples and not isinstance(all_samples[0].get_reward_value(args), (int, float)): + return {} def _is_zero_std(samples: list[Sample]): rewards = [sample.get_reward_value(args) for sample in samples] @@ -1530,8 +1552,11 @@ def _compute_reward_cat_metrics(args, all_samples: list[Sample]): return {f"error_cat/{reward_cat}": len(s) / len(all_samples) for reward_cat, s in samples_of_reward_cat.items()} -# Candidate metadata keys to auto-detect problem category (checked in order) -_CANDIDATE_CATEGORY_KEYS = ["category", "type", "subject", "domain", "problem_type"] +# Candidate metadata keys to auto-detect problem category (checked in order). +# `domain` is first because it's the routing key in multi-teacher setups; if a sample +# carries both `domain` and one of the legacy fields like `category`, group by domain +# so per-cat correctness panels match the per-domain loss panels. +_CANDIDATE_CATEGORY_KEYS = ["domain", "category", "type", "subject", "problem_type"] def _get_problem_category_key(args, all_samples: list[Sample]) -> str | None: @@ -1547,11 +1572,22 @@ def _get_problem_category_key(args, all_samples: list[Sample]) -> str | None: return None +def _correctness(sample: Sample, args) -> bool: + """Non-scalar reward fns set metadata["correctness_reward"]; scalars fall back to sign.""" + if "correctness_reward" in sample.metadata: + return sample.metadata["correctness_reward"] > 0 + val = sample.get_reward_value(args) + return isinstance(val, (int, float)) and val > 0 + + def _compute_grouped_reward_metrics( args, group: list[Sample], prefix: str, n_total: int, include_count_frac: bool = True ) -> dict: """Reward/outcome metrics for a split — emitted under reward/ sections.""" - result = {f"{prefix}/raw_reward": np.mean([s.get_reward_value(args) for s in group]).item()} + result = {} + # Skip raw_reward when reward is non-scalar (e.g. dict-valued OPD rewards). + if group and isinstance(group[0].get_reward_value(args), (int, float)): + result[f"{prefix}/raw_reward"] = np.mean([s.get_reward_value(args) for s in group]).item() if include_count_frac: result[f"{prefix}/count_frac"] = len(group) / n_total return result @@ -1576,8 +1612,8 @@ def _compute_group_outcome_metrics( n_groups = len(groups) if n_groups == 0: return {} - all_correct = sum(1 for g in groups if all(s.get_reward_value(args) > 0 for s in g)) - all_incorrect = sum(1 for g in groups if all(s.get_reward_value(args) <= 0 for s in g)) + all_correct = sum(1 for g in groups if all(_correctness(s, args) for s in g)) + all_incorrect = sum(1 for g in groups if all(not _correctness(s, args) for s in g)) return { f"{prefix}/all_correct_group_frac": all_correct / n_groups, f"{prefix}/all_incorrect_group_frac": all_incorrect / n_groups, diff --git a/miles/utils/metric_utils.py b/miles/utils/metric_utils.py index 66292c79e7..839b50e25f 100644 --- a/miles/utils/metric_utils.py +++ b/miles/utils/metric_utils.py @@ -118,3 +118,8 @@ def compute_rollout_step(args, rollout_id): if args.wandb_always_use_train_step: return rollout_id * args.rollout_batch_size * args.n_samples_per_prompt // args.global_batch_size return rollout_id + + +def compute_samples_seen(args, rollout_id: int) -> int: + """Cumulative samples through (and including) rollout `rollout_id` (0-indexed).""" + return args.rollout_batch_size * args.n_samples_per_prompt * (rollout_id + 1) diff --git a/miles/utils/wandb_utils.py b/miles/utils/wandb_utils.py index c29fcc7eaa..f9524c91af 100644 --- a/miles/utils/wandb_utils.py +++ b/miles/utils/wandb_utils.py @@ -175,3 +175,8 @@ def _init_wandb_common(): # rollout counter. Declared after the "train/*" wildcard so the specific name # isn't inadvertently treated as step-metric'd against train/step. wandb.define_metric("train/rollout_id") + # Bare step counters — co-logged so one panel can plot all three as time series + # (useful for spotting non-monotone train/step jumps under dynamic batching). + wandb.define_metric("samples_seen") + wandb.define_metric("train_step") + wandb.define_metric("rollout_step") diff --git a/tests/fast/backends/training_utils/test_metric_domains.py b/tests/fast/backends/training_utils/test_metric_domains.py new file mode 100644 index 0000000000..79f7aea67b --- /dev/null +++ b/tests/fast/backends/training_utils/test_metric_domains.py @@ -0,0 +1,245 @@ +"""Tests for per-domain metric fan-out and unified correctness signal. + +These exercise the math/logic added when upstreaming OPD-specific metrics +into miles. The goal is to verify: + +1. The masked-loss-mask trick used for per-domain reductions: when we zero + out non-target samples' loss_masks, the resulting `sum_of_sample_mean` + reducer produces the same value as `sum_of_sample_mean` over only the + target-domain samples. + +2. Per-domain reductions partition the global reduction. When domains + partition the batch and every sample has the same per-token weight + distribution, the (count-weighted) sum across domains matches the + global reduction. + +3. The `_correctness(s, args)` helper unifies scalar GRPO reward sign and + non-scalar OPD `metadata["correctness_reward"]` into one path. + +4. `compute_samples_seen` returns the cumulative per-rollout sample count. +""" +from __future__ import annotations + +import sys +import types +from argparse import Namespace +from dataclasses import dataclass, field +from typing import Any + +import torch + +from miles.backends.training_utils.cp_utils import get_sum_of_sample_mean +from miles.utils.metric_utils import compute_samples_seen + + +# --------------------------------------------------------------------------- +# Stub miles.backends.training_utils.parallel.get_parallel_state so cp_utils' +# sum_of_sample_mean (which calls it) returns CP=1. +# --------------------------------------------------------------------------- + +class _FakePG: + size = 1 + rank = 0 + + +class _FakeParallelState: + cp = _FakePG() + + +def _patch_parallel_state(monkeypatch): + import miles.backends.training_utils.cp_utils as cp_utils + monkeypatch.setattr(cp_utils, "get_parallel_state", lambda: _FakeParallelState()) + + +# --------------------------------------------------------------------------- +# 1. Masked-loss-mask trick: per-domain reducer ignores non-target samples +# --------------------------------------------------------------------------- + +def test_domain_filtered_reducer_matches_per_domain_subset(monkeypatch): + _patch_parallel_state(monkeypatch) + + # 3 samples: math, code, math. Each sample has 4 response tokens. + # Per-token "values" tensor `x` is the per-sample value broadcast over tokens. + domains = ["math", "code", "math"] + response_lengths = [4, 4, 4] + total_lengths = [4, 4, 4] + loss_masks = [torch.ones(4) for _ in range(3)] + + # x: per-sample mean is [1.0, 2.0, 3.0] + x = torch.tensor([1.0]*4 + [2.0]*4 + [3.0]*4) + + # Build a "math"-filtered reducer + masked_for_math = [ + lm if d == "math" else torch.zeros_like(lm) + for d, lm in zip(domains, loss_masks) + ] + math_reducer = get_sum_of_sample_mean( + total_lengths, response_lengths, masked_for_math, + calculate_per_token_loss=False, qkv_format="thd", max_seq_lens=None, + ) + + # sum_of_sample_mean: per-sample token-mean, then sum across samples. + # For math: sample 0 contributes 1.0, sample 2 contributes 3.0, sample 1 + # contributes 0 (its mask is all zeros, clamp_min returns 1 in denominator + # but numerator is also 0). Result should be 1.0 + 0 + 3.0 = 4.0. + result = math_reducer(x).item() + assert result == 4.0, f"expected 4.0, got {result}" + + # Same for "code": only sample 1 contributes, value 2.0 + masked_for_code = [ + lm if d == "code" else torch.zeros_like(lm) + for d, lm in zip(domains, loss_masks) + ] + code_reducer = get_sum_of_sample_mean( + total_lengths, response_lengths, masked_for_code, + calculate_per_token_loss=False, qkv_format="thd", max_seq_lens=None, + ) + result = code_reducer(x).item() + assert result == 2.0, f"expected 2.0, got {result}" + + +def test_domain_reducer_returns_zero_for_absent_domain(monkeypatch): + """When no sample matches the target domain, the reducer must return 0 + (not NaN, not error). aggregate_train_losses requires every microbatch + emit the same key set — a domain with zero samples in this microbatch + must contribute 0 to the positional aggregation.""" + _patch_parallel_state(monkeypatch) + + domains = ["math", "math"] # no "code" samples + response_lengths = [3, 3] + total_lengths = [3, 3] + loss_masks = [torch.ones(3), torch.ones(3)] + x = torch.tensor([5.0, 5.0, 5.0, 7.0, 7.0, 7.0]) + + masked_for_code = [torch.zeros_like(lm) for lm in loss_masks] # all zero + code_reducer = get_sum_of_sample_mean( + total_lengths, response_lengths, masked_for_code, + calculate_per_token_loss=False, qkv_format="thd", max_seq_lens=None, + ) + result = code_reducer(x).item() + assert result == 0.0, f"expected 0 for absent domain, got {result}" + assert not torch.isnan(torch.tensor(result)) + + +def test_per_domain_reductions_sum_to_global(monkeypatch): + """When domains partition the batch and we use sum-mode (sum_of_sample_mean + sums per-sample means), summing per-domain reductions equals the global one.""" + _patch_parallel_state(monkeypatch) + + domains = ["math", "code", "math", "code"] + response_lengths = [2, 2, 2, 2] + total_lengths = [2, 2, 2, 2] + loss_masks = [torch.ones(2) for _ in range(4)] + # per-sample means: [1, 2, 3, 4] + x = torch.tensor([1.0]*2 + [2.0]*2 + [3.0]*2 + [4.0]*2) + + global_reducer = get_sum_of_sample_mean( + total_lengths, response_lengths, loss_masks, + calculate_per_token_loss=False, qkv_format="thd", max_seq_lens=None, + ) + global_value = global_reducer(x).item() + assert global_value == 1.0 + 2.0 + 3.0 + 4.0 # = 10.0 + + # Sum per-domain reductions. {math}=samples 0,2 -> 1+3=4; {code}=1,3 -> 2+4=6. + per_domain = 0.0 + for target in ["math", "code"]: + masked = [lm if d == target else torch.zeros_like(lm) for d, lm in zip(domains, loss_masks)] + red = get_sum_of_sample_mean( + total_lengths, response_lengths, masked, + calculate_per_token_loss=False, qkv_format="thd", max_seq_lens=None, + ) + per_domain += red(x).item() + + assert per_domain == global_value, f"per-domain sum {per_domain} != global {global_value}" + + +# --------------------------------------------------------------------------- +# 2. Unified correctness signal: scalar fallback + metadata override +# --------------------------------------------------------------------------- +# +# The actual `_correctness` helper lives in miles.ray.rollout but importing +# that module pulls in ray. Vendor the implementation here as a fixture and +# verify the *semantic* contract — the implementation in rollout.py is a +# verbatim copy of this snippet (kept in sync by code review). + + +def _vendored_correctness(sample, args) -> bool: + """Mirror of miles.ray.rollout._correctness — keep in sync.""" + if "correctness_reward" in sample.metadata: + return sample.metadata["correctness_reward"] > 0 + val = sample.get_reward_value(args) + return isinstance(val, (int, float)) and val > 0 + + +@dataclass +class _FakeSample: + reward: Any + metadata: dict = field(default_factory=dict) + + def get_reward_value(self, args): + return self.reward if not args.reward_key else self.reward[args.reward_key] + + +def test_correctness_scalar_fallback(): + args = Namespace(reward_key=None) + assert _vendored_correctness(_FakeSample(1.0), args) is True + assert _vendored_correctness(_FakeSample(0.5), args) is True + assert _vendored_correctness(_FakeSample(0.0), args) is False + assert _vendored_correctness(_FakeSample(-0.3), args) is False + + +def test_correctness_metadata_override_takes_precedence(): + """metadata['correctness_reward'] wins even when reward is also scalar.""" + args = Namespace(reward_key=None) + s = _FakeSample(1.0, metadata={"correctness_reward": 0.0}) # scalar says correct, metadata says wrong + assert _vendored_correctness(s, args) is False + s = _FakeSample(0.0, metadata={"correctness_reward": 1.0}) + assert _vendored_correctness(s, args) is True + + +def test_correctness_non_scalar_reward_with_metadata(): + """OPD path: reward is a dict, correctness comes from metadata.""" + args = Namespace(reward_key=None) + s = _FakeSample({"kl_a": 0.5, "kl_b": 0.3}, metadata={"correctness_reward": 1.0}) + assert _vendored_correctness(s, args) is True + s = _FakeSample({"kl_a": 0.5}, metadata={"correctness_reward": 0.0}) + assert _vendored_correctness(s, args) is False + + +def test_correctness_non_scalar_reward_no_metadata_returns_false(): + """Without correctness_reward metadata and a non-scalar reward, the + helper returns False (val>0 short-circuits via isinstance check).""" + args = Namespace(reward_key=None) + s = _FakeSample({"kl_a": 0.5}) + assert _vendored_correctness(s, args) is False + + +# --------------------------------------------------------------------------- +# 3. compute_samples_seen +# --------------------------------------------------------------------------- + +def test_compute_samples_seen_first_rollout(): + args = Namespace(rollout_batch_size=8, n_samples_per_prompt=4) + # rollout_id=0 means the first rollout has finished -> 32 samples seen. + assert compute_samples_seen(args, 0) == 32 + + +def test_compute_samples_seen_monotone(): + args = Namespace(rollout_batch_size=8, n_samples_per_prompt=4) + seen = [compute_samples_seen(args, i) for i in range(5)] + # Strictly monotone, increment of 32 per rollout. + assert seen == [32, 64, 96, 128, 160] + + +# --------------------------------------------------------------------------- +# 4. Activation-by-presence: domains list is sorted-unique (matches the +# DataIterator._all_domains_cache construction). +# --------------------------------------------------------------------------- + +def test_all_domains_cache_construction(): + """Mirror the construction in get_batch: sorted({d for d in domains if d}).""" + domains = ["math", "code", None, "math", "code", "science", None] + all_domains = sorted({d for d in domains if d}) + assert all_domains == ["code", "math", "science"] + # None values are filtered out (samples without a domain don't add a key). + assert None not in all_domains From b893fe24cc43799ac774492d97cff9316670397a Mon Sep 17 00:00:00 2001 From: Matthew Yang Date: Thu, 7 May 2026 22:30:45 +0000 Subject: [PATCH 2/8] added change --- miles/backends/training_utils/log_utils.py | 24 +++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/miles/backends/training_utils/log_utils.py b/miles/backends/training_utils/log_utils.py index 9bada19ddf..120560cd41 100644 --- a/miles/backends/training_utils/log_utils.py +++ b/miles/backends/training_utils/log_utils.py @@ -9,7 +9,7 @@ from miles.utils import train_metric_utils from miles.utils.flops_utils import calculate_fwd_flops -from miles.utils.metric_utils import compute_pass_rate, compute_rollout_step, compute_samples_seen +from miles.utils.metric_utils import compute_pass_rate, compute_rollout_step from miles.utils.types import RolloutBatch from ...utils import tracking_utils @@ -51,6 +51,13 @@ "returns": "reward", } +# Cumulative train-step counter across all rollouts. The previous formula +# `rollout_id * num_steps_per_rollout + step_id` collides (and decreases) when +# `num_steps_per_rollout` shrinks across rollouts under dynamic batching, since +# each rollout uses its own current num_steps_per_rollout as a scaling factor. +# A simple monotone counter is invariant to that jitter. +_TRAIN_STEP_COUNTER = 0 + def gather_log_data( metric_name: str, @@ -77,7 +84,15 @@ def gather_log_data( # dict to the union of keys with NaN so every rank sends the same shape. # Cost is one all_gather_object on a tiny key list. all_keys: list = [None] * dp_size + logger.info( + f"[rank={pg.rank}/{dp_size}] gather_log_data({metric_name}) " + f"rollout={rollout_id} entering all_gather_object, keys={len(log_dict)}" + ) dist.all_gather_object(all_keys, sorted(log_dict.keys()), group=pg.gloo_group) + logger.info( + f"[rank={pg.rank}/{dp_size}] gather_log_data({metric_name}) " + f"rollout={rollout_id} all_gather_object returned" + ) union_keys: set = set() for ks in all_keys: if ks: @@ -493,7 +508,9 @@ def log_train_step( Returns: The formatted log_dict (for CI tests or other uses). """ - accumulated_step_id = rollout_id * num_steps_per_rollout + step_id + global _TRAIN_STEP_COUNTER + accumulated_step_id = _TRAIN_STEP_COUNTER + _TRAIN_STEP_COUNTER += 1 role_tag = "" if role == "actor" else f"{role}-" log_dict_out = { @@ -526,10 +543,7 @@ def log_train_step( # cross-plotted against rollout-side axes in the wandb UI. log_dict_out["train/rollout_id"] = rollout_id log_dict_out["train/step_in_rollout"] = step_id - log_dict_out["rollout/step"] = compute_rollout_step(args, rollout_id) - log_dict_out["samples_seen"] = compute_samples_seen(args, rollout_id) log_dict_out["train_step"] = accumulated_step_id - log_dict_out["rollout_step"] = log_dict_out["rollout/step"] # Emit top-level grouped copies for W&B panel organization (existing train/ keys unchanged) grouped_additions = {} From 61a9c546499d2c754c94777a14b970c16a87be5b Mon Sep 17 00:00:00 2001 From: Matthew Yang Date: Thu, 14 May 2026 20:42:47 +0000 Subject: [PATCH 3/8] removed unnecessary files --- scripts/models/xllm-8B.sh | 20 -- .../training_utils/test_metric_domains.py | 245 ------------------ 2 files changed, 265 deletions(-) delete mode 100644 scripts/models/xllm-8B.sh delete mode 100644 tests/fast/backends/training_utils/test_metric_domains.py diff --git a/scripts/models/xllm-8B.sh b/scripts/models/xllm-8B.sh deleted file mode 100644 index ba59296724..0000000000 --- a/scripts/models/xllm-8B.sh +++ /dev/null @@ -1,20 +0,0 @@ -# xLLM 8B dense GQA model arguments. -MODEL_ARGS=( - --swiglu - --num-layers 36 - --hidden-size 4096 - --ffn-hidden-size 12288 - --num-attention-heads 32 - --group-query-attention - --num-query-groups 8 - --kv-channels 128 - --disable-bias-linear - --normalization RMSNorm - --norm-epsilon 1e-6 - --layernorm-num-groups 4 - --position-embedding-type rope - --rotary-percent 1.0 - --rotary-base 10000000 - --untie-embeddings-and-output-weights - --vocab-size 250624 -) diff --git a/tests/fast/backends/training_utils/test_metric_domains.py b/tests/fast/backends/training_utils/test_metric_domains.py deleted file mode 100644 index 79f7aea67b..0000000000 --- a/tests/fast/backends/training_utils/test_metric_domains.py +++ /dev/null @@ -1,245 +0,0 @@ -"""Tests for per-domain metric fan-out and unified correctness signal. - -These exercise the math/logic added when upstreaming OPD-specific metrics -into miles. The goal is to verify: - -1. The masked-loss-mask trick used for per-domain reductions: when we zero - out non-target samples' loss_masks, the resulting `sum_of_sample_mean` - reducer produces the same value as `sum_of_sample_mean` over only the - target-domain samples. - -2. Per-domain reductions partition the global reduction. When domains - partition the batch and every sample has the same per-token weight - distribution, the (count-weighted) sum across domains matches the - global reduction. - -3. The `_correctness(s, args)` helper unifies scalar GRPO reward sign and - non-scalar OPD `metadata["correctness_reward"]` into one path. - -4. `compute_samples_seen` returns the cumulative per-rollout sample count. -""" -from __future__ import annotations - -import sys -import types -from argparse import Namespace -from dataclasses import dataclass, field -from typing import Any - -import torch - -from miles.backends.training_utils.cp_utils import get_sum_of_sample_mean -from miles.utils.metric_utils import compute_samples_seen - - -# --------------------------------------------------------------------------- -# Stub miles.backends.training_utils.parallel.get_parallel_state so cp_utils' -# sum_of_sample_mean (which calls it) returns CP=1. -# --------------------------------------------------------------------------- - -class _FakePG: - size = 1 - rank = 0 - - -class _FakeParallelState: - cp = _FakePG() - - -def _patch_parallel_state(monkeypatch): - import miles.backends.training_utils.cp_utils as cp_utils - monkeypatch.setattr(cp_utils, "get_parallel_state", lambda: _FakeParallelState()) - - -# --------------------------------------------------------------------------- -# 1. Masked-loss-mask trick: per-domain reducer ignores non-target samples -# --------------------------------------------------------------------------- - -def test_domain_filtered_reducer_matches_per_domain_subset(monkeypatch): - _patch_parallel_state(monkeypatch) - - # 3 samples: math, code, math. Each sample has 4 response tokens. - # Per-token "values" tensor `x` is the per-sample value broadcast over tokens. - domains = ["math", "code", "math"] - response_lengths = [4, 4, 4] - total_lengths = [4, 4, 4] - loss_masks = [torch.ones(4) for _ in range(3)] - - # x: per-sample mean is [1.0, 2.0, 3.0] - x = torch.tensor([1.0]*4 + [2.0]*4 + [3.0]*4) - - # Build a "math"-filtered reducer - masked_for_math = [ - lm if d == "math" else torch.zeros_like(lm) - for d, lm in zip(domains, loss_masks) - ] - math_reducer = get_sum_of_sample_mean( - total_lengths, response_lengths, masked_for_math, - calculate_per_token_loss=False, qkv_format="thd", max_seq_lens=None, - ) - - # sum_of_sample_mean: per-sample token-mean, then sum across samples. - # For math: sample 0 contributes 1.0, sample 2 contributes 3.0, sample 1 - # contributes 0 (its mask is all zeros, clamp_min returns 1 in denominator - # but numerator is also 0). Result should be 1.0 + 0 + 3.0 = 4.0. - result = math_reducer(x).item() - assert result == 4.0, f"expected 4.0, got {result}" - - # Same for "code": only sample 1 contributes, value 2.0 - masked_for_code = [ - lm if d == "code" else torch.zeros_like(lm) - for d, lm in zip(domains, loss_masks) - ] - code_reducer = get_sum_of_sample_mean( - total_lengths, response_lengths, masked_for_code, - calculate_per_token_loss=False, qkv_format="thd", max_seq_lens=None, - ) - result = code_reducer(x).item() - assert result == 2.0, f"expected 2.0, got {result}" - - -def test_domain_reducer_returns_zero_for_absent_domain(monkeypatch): - """When no sample matches the target domain, the reducer must return 0 - (not NaN, not error). aggregate_train_losses requires every microbatch - emit the same key set — a domain with zero samples in this microbatch - must contribute 0 to the positional aggregation.""" - _patch_parallel_state(monkeypatch) - - domains = ["math", "math"] # no "code" samples - response_lengths = [3, 3] - total_lengths = [3, 3] - loss_masks = [torch.ones(3), torch.ones(3)] - x = torch.tensor([5.0, 5.0, 5.0, 7.0, 7.0, 7.0]) - - masked_for_code = [torch.zeros_like(lm) for lm in loss_masks] # all zero - code_reducer = get_sum_of_sample_mean( - total_lengths, response_lengths, masked_for_code, - calculate_per_token_loss=False, qkv_format="thd", max_seq_lens=None, - ) - result = code_reducer(x).item() - assert result == 0.0, f"expected 0 for absent domain, got {result}" - assert not torch.isnan(torch.tensor(result)) - - -def test_per_domain_reductions_sum_to_global(monkeypatch): - """When domains partition the batch and we use sum-mode (sum_of_sample_mean - sums per-sample means), summing per-domain reductions equals the global one.""" - _patch_parallel_state(monkeypatch) - - domains = ["math", "code", "math", "code"] - response_lengths = [2, 2, 2, 2] - total_lengths = [2, 2, 2, 2] - loss_masks = [torch.ones(2) for _ in range(4)] - # per-sample means: [1, 2, 3, 4] - x = torch.tensor([1.0]*2 + [2.0]*2 + [3.0]*2 + [4.0]*2) - - global_reducer = get_sum_of_sample_mean( - total_lengths, response_lengths, loss_masks, - calculate_per_token_loss=False, qkv_format="thd", max_seq_lens=None, - ) - global_value = global_reducer(x).item() - assert global_value == 1.0 + 2.0 + 3.0 + 4.0 # = 10.0 - - # Sum per-domain reductions. {math}=samples 0,2 -> 1+3=4; {code}=1,3 -> 2+4=6. - per_domain = 0.0 - for target in ["math", "code"]: - masked = [lm if d == target else torch.zeros_like(lm) for d, lm in zip(domains, loss_masks)] - red = get_sum_of_sample_mean( - total_lengths, response_lengths, masked, - calculate_per_token_loss=False, qkv_format="thd", max_seq_lens=None, - ) - per_domain += red(x).item() - - assert per_domain == global_value, f"per-domain sum {per_domain} != global {global_value}" - - -# --------------------------------------------------------------------------- -# 2. Unified correctness signal: scalar fallback + metadata override -# --------------------------------------------------------------------------- -# -# The actual `_correctness` helper lives in miles.ray.rollout but importing -# that module pulls in ray. Vendor the implementation here as a fixture and -# verify the *semantic* contract — the implementation in rollout.py is a -# verbatim copy of this snippet (kept in sync by code review). - - -def _vendored_correctness(sample, args) -> bool: - """Mirror of miles.ray.rollout._correctness — keep in sync.""" - if "correctness_reward" in sample.metadata: - return sample.metadata["correctness_reward"] > 0 - val = sample.get_reward_value(args) - return isinstance(val, (int, float)) and val > 0 - - -@dataclass -class _FakeSample: - reward: Any - metadata: dict = field(default_factory=dict) - - def get_reward_value(self, args): - return self.reward if not args.reward_key else self.reward[args.reward_key] - - -def test_correctness_scalar_fallback(): - args = Namespace(reward_key=None) - assert _vendored_correctness(_FakeSample(1.0), args) is True - assert _vendored_correctness(_FakeSample(0.5), args) is True - assert _vendored_correctness(_FakeSample(0.0), args) is False - assert _vendored_correctness(_FakeSample(-0.3), args) is False - - -def test_correctness_metadata_override_takes_precedence(): - """metadata['correctness_reward'] wins even when reward is also scalar.""" - args = Namespace(reward_key=None) - s = _FakeSample(1.0, metadata={"correctness_reward": 0.0}) # scalar says correct, metadata says wrong - assert _vendored_correctness(s, args) is False - s = _FakeSample(0.0, metadata={"correctness_reward": 1.0}) - assert _vendored_correctness(s, args) is True - - -def test_correctness_non_scalar_reward_with_metadata(): - """OPD path: reward is a dict, correctness comes from metadata.""" - args = Namespace(reward_key=None) - s = _FakeSample({"kl_a": 0.5, "kl_b": 0.3}, metadata={"correctness_reward": 1.0}) - assert _vendored_correctness(s, args) is True - s = _FakeSample({"kl_a": 0.5}, metadata={"correctness_reward": 0.0}) - assert _vendored_correctness(s, args) is False - - -def test_correctness_non_scalar_reward_no_metadata_returns_false(): - """Without correctness_reward metadata and a non-scalar reward, the - helper returns False (val>0 short-circuits via isinstance check).""" - args = Namespace(reward_key=None) - s = _FakeSample({"kl_a": 0.5}) - assert _vendored_correctness(s, args) is False - - -# --------------------------------------------------------------------------- -# 3. compute_samples_seen -# --------------------------------------------------------------------------- - -def test_compute_samples_seen_first_rollout(): - args = Namespace(rollout_batch_size=8, n_samples_per_prompt=4) - # rollout_id=0 means the first rollout has finished -> 32 samples seen. - assert compute_samples_seen(args, 0) == 32 - - -def test_compute_samples_seen_monotone(): - args = Namespace(rollout_batch_size=8, n_samples_per_prompt=4) - seen = [compute_samples_seen(args, i) for i in range(5)] - # Strictly monotone, increment of 32 per rollout. - assert seen == [32, 64, 96, 128, 160] - - -# --------------------------------------------------------------------------- -# 4. Activation-by-presence: domains list is sorted-unique (matches the -# DataIterator._all_domains_cache construction). -# --------------------------------------------------------------------------- - -def test_all_domains_cache_construction(): - """Mirror the construction in get_batch: sorted({d for d in domains if d}).""" - domains = ["math", "code", None, "math", "code", "science", None] - all_domains = sorted({d for d in domains if d}) - assert all_domains == ["code", "math", "science"] - # None values are filtered out (samples without a domain don't add a key). - assert None not in all_domains From 6c27a8e78e7f98c1cc51e041b1cdb8c8802aa7a8 Mon Sep 17 00:00:00 2001 From: Matthew Yang Date: Thu, 14 May 2026 21:12:49 +0000 Subject: [PATCH 4/8] added --- miles/backends/training_utils/log_utils.py | 2 ++ miles/ray/rollout.py | 2 -- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/miles/backends/training_utils/log_utils.py b/miles/backends/training_utils/log_utils.py index 120560cd41..c0b778e3f4 100644 --- a/miles/backends/training_utils/log_utils.py +++ b/miles/backends/training_utils/log_utils.py @@ -117,6 +117,7 @@ def gather_log_data( # Calculate step once to avoid duplication step = compute_rollout_step(args, rollout_id) reduced_log_dict["rollout/step"] = step + reduced_log_dict["rollout_step"] = step reduced_log_dict["train/rollout_id"] = rollout_id tracking_utils.log(args, reduced_log_dict, step_key="rollout/step") @@ -273,6 +274,7 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc if top_level: step = compute_rollout_step(args, rollout_id) top_level["rollout/step"] = step + top_level["rollout_step"] = step tracking_utils.log(args, top_level, step_key="rollout/step") if args.log_multi_turn: diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index e6e83e2bd3..7d968f0ed8 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -1445,7 +1445,6 @@ def compute_metrics_from_samples(args, samples): log_dict["reward/correctness"] = len(correct) / n for label, grp in [("correct", correct), ("incorrect", incorrect)]: if grp: - log_dict |= _compute_grouped_reward_metrics(args, grp, f"reward/{label}", n, include_count_frac=False) log_dict |= _compute_grouped_response_metrics(args, grp, f"response_stats/{label}") # per-category and combined (only if category data present) @@ -1462,7 +1461,6 @@ def compute_metrics_from_samples(args, samples): log_dict[f"reward/{cat}/correctness"] = len(cat_correct) / len(cat_grp) for label, grp in [("correct", cat_correct), ("incorrect", cat_incorrect)]: if grp: - log_dict |= _compute_grouped_reward_metrics(args, grp, f"reward/{cat}/{label}", n) log_dict |= _compute_grouped_response_metrics(args, grp, f"response_stats/{cat}/{label}") return log_dict From 3750cbb77ac562d4b3d628d81805574af2bda9f8 Mon Sep 17 00:00:00 2001 From: Varad Pimpalkhute Date: Mon, 18 May 2026 16:43:55 +0000 Subject: [PATCH 5/8] fix: add step_metric to bare counter and mirrored reward/response_stats define_metric calls --- miles/utils/wandb_utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/miles/utils/wandb_utils.py b/miles/utils/wandb_utils.py index f9524c91af..86062a9759 100644 --- a/miles/utils/wandb_utils.py +++ b/miles/utils/wandb_utils.py @@ -177,6 +177,10 @@ def _init_wandb_common(): wandb.define_metric("train/rollout_id") # Bare step counters — co-logged so one panel can plot all three as time series # (useful for spotting non-monotone train/step jumps under dynamic batching). - wandb.define_metric("samples_seen") - wandb.define_metric("train_step") - wandb.define_metric("rollout_step") + wandb.define_metric("samples_seen", step_metric="rollout/step") + wandb.define_metric("train_step", step_metric="train/step") + wandb.define_metric("rollout_step", step_metric="rollout/step") + # Bare reward/response_stats mirrors — stripped of the rollout/ prefix by + # _log_rollout_data so the panels appear at the top level in W&B. + wandb.define_metric("reward/*", step_metric="rollout/step") + wandb.define_metric("response_stats/*", step_metric="rollout/step") From 6f68cea89e6fc21d737e04ef6aa82a1535f84d25 Mon Sep 17 00:00:00 2001 From: Varad Pimpalkhute Date: Mon, 18 May 2026 17:07:46 +0000 Subject: [PATCH 6/8] fix: add step_metric to define_metric calls and emit rollout_step only once per rollout --- miles/backends/training_utils/log_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/miles/backends/training_utils/log_utils.py b/miles/backends/training_utils/log_utils.py index c0b778e3f4..423d38c960 100644 --- a/miles/backends/training_utils/log_utils.py +++ b/miles/backends/training_utils/log_utils.py @@ -117,7 +117,8 @@ def gather_log_data( # Calculate step once to avoid duplication step = compute_rollout_step(args, rollout_id) reduced_log_dict["rollout/step"] = step - reduced_log_dict["rollout_step"] = step + if metric_name == "rollout": + reduced_log_dict["rollout_step"] = step reduced_log_dict["train/rollout_id"] = rollout_id tracking_utils.log(args, reduced_log_dict, step_key="rollout/step") @@ -274,7 +275,6 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc if top_level: step = compute_rollout_step(args, rollout_id) top_level["rollout/step"] = step - top_level["rollout_step"] = step tracking_utils.log(args, top_level, step_key="rollout/step") if args.log_multi_turn: From 506769d0d6653af8d8c1b99fceaefc45cb758187 Mon Sep 17 00:00:00 2001 From: Matthew Yang Date: Mon, 18 May 2026 23:11:12 +0000 Subject: [PATCH 7/8] fixed resume --- .../experimental/fsdp_utils/checkpoint.py | 10 ++++++ miles/backends/megatron_utils/model.py | 14 +++++++- miles/backends/training_utils/log_utils.py | 35 ++++++++++++++++++- 3 files changed, 57 insertions(+), 2 deletions(-) diff --git a/miles/backends/experimental/fsdp_utils/checkpoint.py b/miles/backends/experimental/fsdp_utils/checkpoint.py index 6daf7f982c..ec19958f65 100644 --- a/miles/backends/experimental/fsdp_utils/checkpoint.py +++ b/miles/backends/experimental/fsdp_utils/checkpoint.py @@ -12,6 +12,12 @@ from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict from torch.distributed.checkpoint.stateful import Stateful +from miles.backends.training_utils.log_utils import ( + init_train_step_counter, + load_train_step_counter, + save_train_step_counter, +) + logger = logging.getLogger(__name__) @@ -183,6 +189,9 @@ def finalize_load(actor: Any, checkpoint_payload: dict[str, Any] | None) -> None if getattr(actor.args, "start_rollout_id", None) is None: actor.args.start_rollout_id = iteration + if dist.get_rank() == 0: + init_train_step_counter(load_train_step_counter(actor.args.load, iteration)) + torch.cuda.synchronize() dist.barrier() @@ -245,6 +254,7 @@ def save(actor: Any, iteration: int) -> None: tracker_file = base_dir / "latest_checkpointed_iteration.txt" tracker_file.write_text(str(step_id)) + save_train_step_counter(actor.args.save, step_id) logger.info(f"[FSDP] Saved checkpoint to {checkpoint_dir}") dist.barrier() diff --git a/miles/backends/megatron_utils/model.py b/miles/backends/megatron_utils/model.py index 56ec82ff51..a42234593d 100644 --- a/miles/backends/megatron_utils/model.py +++ b/miles/backends/megatron_utils/model.py @@ -26,7 +26,14 @@ from ..training_utils.ci_utils import check_grad_norm, check_kl from ..training_utils.data import DataIterator, get_batch -from ..training_utils.log_utils import aggregate_forward_results, aggregate_train_losses, log_train_step +from ..training_utils.log_utils import ( + aggregate_forward_results, + aggregate_train_losses, + init_train_step_counter, + load_train_step_counter, + log_train_step, + save_train_step_counter, +) from ..training_utils.loss import loss_function from ..training_utils.parallel import get_parallel_state from .checkpoint import load_checkpoint, save_checkpoint, save_checkpoint_with_lora @@ -726,6 +733,8 @@ def save( ) clear_memory() + if is_megatron_main_rank(): + save_train_step_counter(args.save, iteration) if hashes is not None: save_model_hashes(args, model, iteration, hashes) @@ -828,4 +837,7 @@ def initialize_model_and_optimizer( opt_param_scheduler.step(increment=iteration * args.global_batch_size) + if is_megatron_main_rank(): + init_train_step_counter(load_train_step_counter(args.load, iteration)) + return model, optimizer, opt_param_scheduler, iteration diff --git a/miles/backends/training_utils/log_utils.py b/miles/backends/training_utils/log_utils.py index 423d38c960..a0b48b1a26 100644 --- a/miles/backends/training_utils/log_utils.py +++ b/miles/backends/training_utils/log_utils.py @@ -1,6 +1,7 @@ import logging from argparse import Namespace from math import isclose +from pathlib import Path import numpy as np import psutil @@ -55,8 +56,40 @@ # `rollout_id * num_steps_per_rollout + step_id` collides (and decreases) when # `num_steps_per_rollout` shrinks across rollouts under dynamic batching, since # each rollout uses its own current num_steps_per_rollout as a scaling factor. -# A simple monotone counter is invariant to that jitter. +# A simple monotone counter is invariant to that jitter. Persisted to a sidecar +# file next to the checkpoint so it survives process restart (otherwise train/step +# would dip to 0 in wandb on resume). _TRAIN_STEP_COUNTER = 0 +_TRAIN_STEP_COUNTER_FILENAME = "train_step_counter.txt" + + +def _iter_dir(checkpoint_dir: str, iteration: int) -> "Path": + return Path(checkpoint_dir) / f"iter_{int(iteration):07d}" + + +def init_train_step_counter(value: int) -> None: + global _TRAIN_STEP_COUNTER + _TRAIN_STEP_COUNTER = int(value) + + +def load_train_step_counter(checkpoint_dir: str | None, iteration: int | None) -> int: + if not checkpoint_dir or iteration is None: + return 0 + try: + return int((_iter_dir(checkpoint_dir, iteration) / _TRAIN_STEP_COUNTER_FILENAME).read_text().strip()) + except (OSError, ValueError): + return 0 + + +def save_train_step_counter(checkpoint_dir: str | None, iteration: int | None) -> None: + if not checkpoint_dir or iteration is None: + return + try: + path = _iter_dir(checkpoint_dir, iteration) + path.mkdir(parents=True, exist_ok=True) + (path / _TRAIN_STEP_COUNTER_FILENAME).write_text(str(_TRAIN_STEP_COUNTER)) + except OSError as e: + logger.warning(f"Failed to persist train-step counter to {checkpoint_dir}: {e}") def gather_log_data( From f205308aa1d46d64dae8fbb68d45e8e5396d5aa5 Mon Sep 17 00:00:00 2001 From: Matthew Yang Date: Mon, 18 May 2026 23:39:25 +0000 Subject: [PATCH 8/8] simplified code --- .../experimental/fsdp_utils/checkpoint.py | 3 +- miles/backends/megatron_utils/model.py | 3 +- miles/backends/training_utils/log_utils.py | 38 ++++++------------- 3 files changed, 13 insertions(+), 31 deletions(-) diff --git a/miles/backends/experimental/fsdp_utils/checkpoint.py b/miles/backends/experimental/fsdp_utils/checkpoint.py index ec19958f65..8a0b5ff5d3 100644 --- a/miles/backends/experimental/fsdp_utils/checkpoint.py +++ b/miles/backends/experimental/fsdp_utils/checkpoint.py @@ -14,7 +14,6 @@ from miles.backends.training_utils.log_utils import ( init_train_step_counter, - load_train_step_counter, save_train_step_counter, ) @@ -190,7 +189,7 @@ def finalize_load(actor: Any, checkpoint_payload: dict[str, Any] | None) -> None actor.args.start_rollout_id = iteration if dist.get_rank() == 0: - init_train_step_counter(load_train_step_counter(actor.args.load, iteration)) + init_train_step_counter(actor.args.load, iteration) torch.cuda.synchronize() dist.barrier() diff --git a/miles/backends/megatron_utils/model.py b/miles/backends/megatron_utils/model.py index a42234593d..3c50770aa0 100644 --- a/miles/backends/megatron_utils/model.py +++ b/miles/backends/megatron_utils/model.py @@ -30,7 +30,6 @@ aggregate_forward_results, aggregate_train_losses, init_train_step_counter, - load_train_step_counter, log_train_step, save_train_step_counter, ) @@ -838,6 +837,6 @@ def initialize_model_and_optimizer( opt_param_scheduler.step(increment=iteration * args.global_batch_size) if is_megatron_main_rank(): - init_train_step_counter(load_train_step_counter(args.load, iteration)) + init_train_step_counter(args.load, iteration) return model, optimizer, opt_param_scheduler, iteration diff --git a/miles/backends/training_utils/log_utils.py b/miles/backends/training_utils/log_utils.py index a0b48b1a26..2b5e190e34 100644 --- a/miles/backends/training_utils/log_utils.py +++ b/miles/backends/training_utils/log_utils.py @@ -60,36 +60,28 @@ # file next to the checkpoint so it survives process restart (otherwise train/step # would dip to 0 in wandb on resume). _TRAIN_STEP_COUNTER = 0 -_TRAIN_STEP_COUNTER_FILENAME = "train_step_counter.txt" -def _iter_dir(checkpoint_dir: str, iteration: int) -> "Path": - return Path(checkpoint_dir) / f"iter_{int(iteration):07d}" - - -def init_train_step_counter(value: int) -> None: +def init_train_step_counter(checkpoint_dir: str | None, iteration: int | None) -> None: + """Restore the counter from a checkpoint sidecar; leaves it at 0 if absent or corrupt.""" global _TRAIN_STEP_COUNTER - _TRAIN_STEP_COUNTER = int(value) - - -def load_train_step_counter(checkpoint_dir: str | None, iteration: int | None) -> int: - if not checkpoint_dir or iteration is None: - return 0 + if checkpoint_dir is None or iteration is None: + return + path = Path(checkpoint_dir) / f"iter_{int(iteration):07d}" / "train_step_counter.txt" try: - return int((_iter_dir(checkpoint_dir, iteration) / _TRAIN_STEP_COUNTER_FILENAME).read_text().strip()) + _TRAIN_STEP_COUNTER = int(path.read_text().strip()) except (OSError, ValueError): - return 0 + pass def save_train_step_counter(checkpoint_dir: str | None, iteration: int | None) -> None: - if not checkpoint_dir or iteration is None: + if checkpoint_dir is None or iteration is None: return + path = Path(checkpoint_dir) / f"iter_{int(iteration):07d}" / "train_step_counter.txt" try: - path = _iter_dir(checkpoint_dir, iteration) - path.mkdir(parents=True, exist_ok=True) - (path / _TRAIN_STEP_COUNTER_FILENAME).write_text(str(_TRAIN_STEP_COUNTER)) + path.write_text(str(_TRAIN_STEP_COUNTER)) except OSError as e: - logger.warning(f"Failed to persist train-step counter to {checkpoint_dir}: {e}") + logger.warning(f"Failed to persist train-step counter: {e}") def gather_log_data( @@ -117,15 +109,7 @@ def gather_log_data( # dict to the union of keys with NaN so every rank sends the same shape. # Cost is one all_gather_object on a tiny key list. all_keys: list = [None] * dp_size - logger.info( - f"[rank={pg.rank}/{dp_size}] gather_log_data({metric_name}) " - f"rollout={rollout_id} entering all_gather_object, keys={len(log_dict)}" - ) dist.all_gather_object(all_keys, sorted(log_dict.keys()), group=pg.gloo_group) - logger.info( - f"[rank={pg.rank}/{dp_size}] gather_log_data({metric_name}) " - f"rollout={rollout_id} all_gather_object returned" - ) union_keys: set = set() for ks in all_keys: if ks: