diff --git a/miles/backends/experimental/fsdp_utils/checkpoint.py b/miles/backends/experimental/fsdp_utils/checkpoint.py index 6daf7f982c..8a0b5ff5d3 100644 --- a/miles/backends/experimental/fsdp_utils/checkpoint.py +++ b/miles/backends/experimental/fsdp_utils/checkpoint.py @@ -12,6 +12,11 @@ from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict from torch.distributed.checkpoint.stateful import Stateful +from miles.backends.training_utils.log_utils import ( + init_train_step_counter, + save_train_step_counter, +) + logger = logging.getLogger(__name__) @@ -183,6 +188,9 @@ def finalize_load(actor: Any, checkpoint_payload: dict[str, Any] | None) -> None if getattr(actor.args, "start_rollout_id", None) is None: actor.args.start_rollout_id = iteration + if dist.get_rank() == 0: + init_train_step_counter(actor.args.load, iteration) + torch.cuda.synchronize() dist.barrier() @@ -245,6 +253,7 @@ def save(actor: Any, iteration: int) -> None: tracker_file = base_dir / "latest_checkpointed_iteration.txt" tracker_file.write_text(str(step_id)) + save_train_step_counter(actor.args.save, step_id) logger.info(f"[FSDP] Saved checkpoint to {checkpoint_dir}") dist.barrier() diff --git a/miles/backends/megatron_utils/model.py b/miles/backends/megatron_utils/model.py index 56ec82ff51..3c50770aa0 100644 --- a/miles/backends/megatron_utils/model.py +++ b/miles/backends/megatron_utils/model.py @@ -26,7 +26,13 @@ from ..training_utils.ci_utils import check_grad_norm, check_kl from ..training_utils.data import DataIterator, get_batch -from ..training_utils.log_utils import aggregate_forward_results, aggregate_train_losses, log_train_step +from ..training_utils.log_utils import ( + aggregate_forward_results, + aggregate_train_losses, + init_train_step_counter, + log_train_step, + save_train_step_counter, +) from ..training_utils.loss import loss_function from ..training_utils.parallel import get_parallel_state from .checkpoint import load_checkpoint, save_checkpoint, save_checkpoint_with_lora @@ -726,6 +732,8 @@ def save( ) clear_memory() + if is_megatron_main_rank(): + save_train_step_counter(args.save, iteration) if hashes is not None: save_model_hashes(args, model, iteration, hashes) @@ -828,4 +836,7 @@ def initialize_model_and_optimizer( opt_param_scheduler.step(increment=iteration * args.global_batch_size) + if is_megatron_main_rank(): + init_train_step_counter(args.load, iteration) + return model, optimizer, opt_param_scheduler, iteration diff --git a/miles/backends/training_utils/data.py b/miles/backends/training_utils/data.py index 3bda887b26..2733c93442 100644 --- a/miles/backends/training_utils/data.py +++ b/miles/backends/training_utils/data.py @@ -115,11 +115,23 @@ def get_batch( parallel_state = get_parallel_state() assert "tokens" in keys + has_domains = "domains" in data_iterator.rollout_data + if has_domains and "domains" not in keys: + keys = list(keys) + ["domains"] batch = data_iterator.get_next(keys) if "dynamic_global_batch_size" in data_iterator.rollout_data: batch["dynamic_global_batch_size"] = data_iterator.rollout_data["dynamic_global_batch_size"] + # Canonical domain set, cached on the iterator so every microbatch emits the + # same list (aggregate_train_losses keys positionally on the first microbatch). + if has_domains: + if not hasattr(data_iterator, "_all_domains_cache"): + data_iterator._all_domains_cache = sorted( + {d for d in data_iterator.rollout_data["domains"] if d} + ) + batch["all_domains"] = data_iterator._all_domains_cache + tokens = batch["tokens"] # use 0 as the pad token id should be fine? pad_token_id = 0 diff --git a/miles/backends/training_utils/log_utils.py b/miles/backends/training_utils/log_utils.py index 4096a4f8ac..2b5e190e34 100644 --- a/miles/backends/training_utils/log_utils.py +++ b/miles/backends/training_utils/log_utils.py @@ -1,6 +1,7 @@ import logging from argparse import Namespace from math import isclose +from pathlib import Path import numpy as np import psutil @@ -51,6 +52,37 @@ "returns": "reward", } +# Cumulative train-step counter across all rollouts. The previous formula +# `rollout_id * num_steps_per_rollout + step_id` collides (and decreases) when +# `num_steps_per_rollout` shrinks across rollouts under dynamic batching, since +# each rollout uses its own current num_steps_per_rollout as a scaling factor. +# A simple monotone counter is invariant to that jitter. Persisted to a sidecar +# file next to the checkpoint so it survives process restart (otherwise train/step +# would dip to 0 in wandb on resume). +_TRAIN_STEP_COUNTER = 0 + + +def init_train_step_counter(checkpoint_dir: str | None, iteration: int | None) -> None: + """Restore the counter from a checkpoint sidecar; leaves it at 0 if absent or corrupt.""" + global _TRAIN_STEP_COUNTER + if checkpoint_dir is None or iteration is None: + return + path = Path(checkpoint_dir) / f"iter_{int(iteration):07d}" / "train_step_counter.txt" + try: + _TRAIN_STEP_COUNTER = int(path.read_text().strip()) + except (OSError, ValueError): + pass + + +def save_train_step_counter(checkpoint_dir: str | None, iteration: int | None) -> None: + if checkpoint_dir is None or iteration is None: + return + path = Path(checkpoint_dir) / f"iter_{int(iteration):07d}" / "train_step_counter.txt" + try: + path.write_text(str(_TRAIN_STEP_COUNTER)) + except OSError as e: + logger.warning(f"Failed to persist train-step counter: {e}") + def gather_log_data( metric_name: str, @@ -102,6 +134,8 @@ def gather_log_data( # Calculate step once to avoid duplication step = compute_rollout_step(args, rollout_id) reduced_log_dict["rollout/step"] = step + if metric_name == "rollout": + reduced_log_dict["rollout_step"] = step reduced_log_dict["train/rollout_id"] = rollout_id tracking_utils.log(args, reduced_log_dict, step_key="rollout/step") @@ -170,6 +204,7 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc "dynamic_global_batch_size", "weight_versions", "metadata", + "domains", ]: continue # Upload per sample mean for each rollout value @@ -492,7 +527,9 @@ def log_train_step( Returns: The formatted log_dict (for CI tests or other uses). """ - accumulated_step_id = rollout_id * num_steps_per_rollout + step_id + global _TRAIN_STEP_COUNTER + accumulated_step_id = _TRAIN_STEP_COUNTER + _TRAIN_STEP_COUNTER += 1 role_tag = "" if role == "actor" else f"{role}-" log_dict_out = { @@ -525,7 +562,7 @@ def log_train_step( # cross-plotted against rollout-side axes in the wandb UI. log_dict_out["train/rollout_id"] = rollout_id log_dict_out["train/step_in_rollout"] = step_id - log_dict_out["rollout/step"] = compute_rollout_step(args, rollout_id) + log_dict_out["train_step"] = accumulated_step_id # Emit top-level grouped copies for W&B panel organization (existing train/ keys unchanged) grouped_additions = {} @@ -534,9 +571,13 @@ def log_train_step( if not full_key.startswith(prefix): continue bare_key = full_key[len(prefix):] - if bare_key in _TRAIN_METRIC_GROUPS: - for group in _TRAIN_METRIC_GROUPS[bare_key]: - grouped_additions[f"{group}/{bare_key}"] = val + # Per-domain keys arrive as "/" — route to "//". + metric_name, sep, domain = bare_key.rpartition("/") + lookup = metric_name if (sep and metric_name in _TRAIN_METRIC_GROUPS) else bare_key + if lookup in _TRAIN_METRIC_GROUPS: + suffix = f"{domain}/{metric_name}" if lookup == metric_name else bare_key + for group in _TRAIN_METRIC_GROUPS[lookup]: + grouped_additions[f"{group}/{suffix}"] = val elif bare_key.startswith("lr-pg_"): grouped_additions[f"optimization/{bare_key}"] = val log_dict_out.update(grouped_additions) diff --git a/miles/backends/training_utils/loss.py b/miles/backends/training_utils/loss.py index bd0386e9f4..317286f6e0 100644 --- a/miles/backends/training_utils/loss.py +++ b/miles/backends/training_utils/loss.py @@ -653,6 +653,11 @@ def policy_loss_function( else: pg_loss_reducer = sum_of_sample_mean + # Saved for per-domain fan-out (reducers below overwrite these names with scalars). + _pg_loss_per_token = pg_loss + _pg_clipfrac_per_token = pg_clipfrac + _ppo_kl_per_token = ppo_kl + pg_loss = pg_loss_reducer(pg_loss) pg_clipfrac = sum_of_sample_mean(pg_clipfrac) ppo_kl = sum_of_sample_mean(ppo_kl) @@ -691,18 +696,24 @@ def policy_loss_function( # Train-inference mismatch: compare inference engine vs FSDP at rollout time train_rollout_logprob_abs_diff = None train_rollout_logprob_diff = None + _train_rollout_logprob_abs_per_token = None + _train_rollout_logprob_signed_per_token = None if "rollout_log_probs" in batch and batch["rollout_log_probs"]: rollout_log_probs_cat = torch.cat(batch["rollout_log_probs"], dim=0) log_probs_batch_cat = torch.cat(batch["log_probs"], dim=0) - train_rollout_logprob_abs_diff = sum_of_sample_mean((old_log_probs - rollout_log_probs_cat).abs()).clone().detach() + _train_rollout_logprob_abs_per_token = (old_log_probs - rollout_log_probs_cat).abs() # signed: log π(inf) − log π(fsdp rollout) - train_rollout_logprob_diff = sum_of_sample_mean(rollout_log_probs_cat - log_probs_batch_cat).clone().detach() + _train_rollout_logprob_signed_per_token = rollout_log_probs_cat - log_probs_batch_cat + train_rollout_logprob_abs_diff = sum_of_sample_mean(_train_rollout_logprob_abs_per_token).clone().detach() + train_rollout_logprob_diff = sum_of_sample_mean(_train_rollout_logprob_signed_per_token).clone().detach() # KL vs reference model — always log when ref present, regardless of use_kl_loss ref_kl_metric = None + _ref_kl_per_token = None if "ref_log_probs" in batch and batch["ref_log_probs"]: ref_log_probs_cat = torch.cat(batch["ref_log_probs"], dim=0) - ref_kl_metric = sum_of_sample_mean(log_probs - ref_log_probs_cat).clone().detach() + _ref_kl_per_token = log_probs - ref_log_probs_cat + ref_kl_metric = sum_of_sample_mean(_ref_kl_per_token).clone().detach() reported_loss = { "loss": loss.clone().detach(), @@ -736,6 +747,44 @@ def policy_loss_function( if args.use_opsm: reported_loss["opsm_clipfrac"] = opsm_clipfrac + # Per-domain fan-out: activated by batch["domains"] (set when samples carry + # metadata["domain"]). batch["all_domains"] is cached on DataIterator so every + # microbatch emits the same key set (aggregate_train_losses keys positionally). + # grad_norm isn't split: backward() has already mixed gradients. + if batch.get("domains") and batch.get("all_domains"): + per_token = { + "log_probs": log_probs, + "old_log_probs": old_log_probs, + "pg_loss": _pg_loss_per_token, + "pg_clipfrac": _pg_clipfrac_per_token, + "ppo_kl": _ppo_kl_per_token, + "entropy_loss": entropy, + } + if _ref_kl_per_token is not None: + per_token["ref_kl"] = _ref_kl_per_token + if _train_rollout_logprob_signed_per_token is not None: + per_token["train_rollout_logprob_diff"] = _train_rollout_logprob_signed_per_token + per_token["train_rollout_logprob_abs_diff"] = _train_rollout_logprob_abs_per_token + if args.get_mismatch_metrics or args.use_tis: + per_token["ois"] = ois + per_token.update(tis_metrics) + + for d in batch["all_domains"]: + masked = [ + lm if dd == d else torch.zeros_like(lm) + for dd, lm in zip(batch["domains"], batch["loss_masks"], strict=False) + ] + reducer = get_sum_of_sample_mean( + total_lengths, response_lengths, masked, + args.calculate_per_token_loss, args.qkv_format, max_seq_lens, + loss_agg_mode=getattr(args, "loss_agg_mode", None), + ) + for name, t in per_token.items(): + reported_loss[f"{name}/{d}"] = reducer(t).clone().detach() + reported_loss[f"loss/{d}"] = ( + reported_loss[f"pg_loss/{d}"] - args.entropy_coef * reported_loss[f"entropy_loss/{d}"] + ) + return loss, reported_loss diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index ebe44f8fb9..7d968f0ed8 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -38,7 +38,13 @@ from miles.utils.iter_utils import group_by from miles.utils.logging_utils import configure_logger from miles.utils.metric_checker import MetricChecker -from miles.utils.metric_utils import compute_pass_rate, compute_rollout_step, compute_statistics, dict_add_prefix +from miles.utils.metric_utils import ( + compute_pass_rate, + compute_rollout_step, + compute_samples_seen, + compute_statistics, + dict_add_prefix, +) from miles.utils.misc import load_function from miles.utils.ray_utils import Box from miles.utils.seqlen_balancing import get_seqlen_balanced_partitions @@ -752,6 +758,11 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl if samples[0].train_metadata is not None: train_data["metadata"] = [sample.train_metadata for sample in samples] + # Presence of metadata["domain"] activates per-domain metric fan-out in policy_loss_function. + domains = [s.metadata.get("domain") for s in samples] + if any(domains): + train_data["domains"] = domains + if any(sample.multimodal_train_inputs is not None for sample in samples): train_data["multimodal_train_inputs"] = [sample.multimodal_train_inputs for sample in samples] @@ -869,6 +880,7 @@ def _stat(xs): "prompt", "teacher_log_probs", "weight_versions", + "domains", ]: if key not in data: continue @@ -1371,10 +1383,16 @@ def _log_rollout_data(rollout_id, args, samples, rollout_extra_metrics, rollout_ log_dict = {**(rollout_extra_metrics or {})} log_dict |= dict_add_prefix(compute_metrics_from_samples(args, samples), "rollout/") log_dict |= dict_add_prefix(compute_perf_metrics_from_samples(args, samples, rollout_time), "perf/") + # Mirror reward/* and response_stats/* as top-level wandb panels. + for full_key, val in list(log_dict.items()): + if full_key.startswith(("rollout/reward/", "rollout/response_stats/")): + log_dict[full_key[len("rollout/"):]] = val logger.info(f"perf {rollout_id}: {log_dict}") step = compute_rollout_step(args, rollout_id) log_dict["rollout/step"] = step log_dict["train/rollout_id"] = rollout_id + log_dict["samples_seen"] = compute_samples_seen(args, rollout_id) + log_dict["rollout_step"] = step tracking_utils.log(args, log_dict, step_key="rollout/step") @@ -1422,11 +1440,11 @@ def compute_metrics_from_samples(args, samples): log_dict |= _compute_group_outcome_metrics(args, samples, prefix="reward") # per-correctness (no count_frac: for binary rewards = mean reward = already in reward/raw_reward) - correct = [s for s in samples if s.get_reward_value(args) > 0] - incorrect = [s for s in samples if s.get_reward_value(args) <= 0] + correct = [s for s in samples if _correctness(s, args)] + incorrect = [s for s in samples if not _correctness(s, args)] + log_dict["reward/correctness"] = len(correct) / n for label, grp in [("correct", correct), ("incorrect", incorrect)]: if grp: - log_dict |= _compute_grouped_reward_metrics(args, grp, f"reward/{label}", n, include_count_frac=False) log_dict |= _compute_grouped_response_metrics(args, grp, f"response_stats/{label}") # per-category and combined (only if category data present) @@ -1438,12 +1456,11 @@ def compute_metrics_from_samples(args, samples): log_dict |= _compute_grouped_reward_metrics(args, cat_grp, f"reward/{cat}", n) log_dict |= _compute_grouped_response_metrics(args, cat_grp, f"response_stats/{cat}") log_dict |= _compute_group_outcome_metrics(args, cat_grp, prefix=f"reward/{cat}") - for label, grp in [ - ("correct", [s for s in cat_grp if s.get_reward_value(args) > 0]), - ("incorrect", [s for s in cat_grp if s.get_reward_value(args) <= 0]), - ]: + cat_correct = [s for s in cat_grp if _correctness(s, args)] + cat_incorrect = [s for s in cat_grp if not _correctness(s, args)] + log_dict[f"reward/{cat}/correctness"] = len(cat_correct) / len(cat_grp) + for label, grp in [("correct", cat_correct), ("incorrect", cat_incorrect)]: if grp: - log_dict |= _compute_grouped_reward_metrics(args, grp, f"reward/{cat}/{label}", n) log_dict |= _compute_grouped_response_metrics(args, grp, f"response_stats/{cat}/{label}") return log_dict @@ -1486,6 +1503,9 @@ def _compute_zero_std_metrics(args, all_samples: list[Sample]): # only compute in GRPO-like algorithms where one prompt has multiple responses if args.advantage_estimator == "ppo": return {} + # Skip non-scalar rewards (round() and zero-std comparison are ill-defined on dicts). + if all_samples and not isinstance(all_samples[0].get_reward_value(args), (int, float)): + return {} def _is_zero_std(samples: list[Sample]): rewards = [sample.get_reward_value(args) for sample in samples] @@ -1530,8 +1550,11 @@ def _compute_reward_cat_metrics(args, all_samples: list[Sample]): return {f"error_cat/{reward_cat}": len(s) / len(all_samples) for reward_cat, s in samples_of_reward_cat.items()} -# Candidate metadata keys to auto-detect problem category (checked in order) -_CANDIDATE_CATEGORY_KEYS = ["category", "type", "subject", "domain", "problem_type"] +# Candidate metadata keys to auto-detect problem category (checked in order). +# `domain` is first because it's the routing key in multi-teacher setups; if a sample +# carries both `domain` and one of the legacy fields like `category`, group by domain +# so per-cat correctness panels match the per-domain loss panels. +_CANDIDATE_CATEGORY_KEYS = ["domain", "category", "type", "subject", "problem_type"] def _get_problem_category_key(args, all_samples: list[Sample]) -> str | None: @@ -1547,11 +1570,22 @@ def _get_problem_category_key(args, all_samples: list[Sample]) -> str | None: return None +def _correctness(sample: Sample, args) -> bool: + """Non-scalar reward fns set metadata["correctness_reward"]; scalars fall back to sign.""" + if "correctness_reward" in sample.metadata: + return sample.metadata["correctness_reward"] > 0 + val = sample.get_reward_value(args) + return isinstance(val, (int, float)) and val > 0 + + def _compute_grouped_reward_metrics( args, group: list[Sample], prefix: str, n_total: int, include_count_frac: bool = True ) -> dict: """Reward/outcome metrics for a split — emitted under reward/ sections.""" - result = {f"{prefix}/raw_reward": np.mean([s.get_reward_value(args) for s in group]).item()} + result = {} + # Skip raw_reward when reward is non-scalar (e.g. dict-valued OPD rewards). + if group and isinstance(group[0].get_reward_value(args), (int, float)): + result[f"{prefix}/raw_reward"] = np.mean([s.get_reward_value(args) for s in group]).item() if include_count_frac: result[f"{prefix}/count_frac"] = len(group) / n_total return result @@ -1576,8 +1610,8 @@ def _compute_group_outcome_metrics( n_groups = len(groups) if n_groups == 0: return {} - all_correct = sum(1 for g in groups if all(s.get_reward_value(args) > 0 for s in g)) - all_incorrect = sum(1 for g in groups if all(s.get_reward_value(args) <= 0 for s in g)) + all_correct = sum(1 for g in groups if all(_correctness(s, args) for s in g)) + all_incorrect = sum(1 for g in groups if all(not _correctness(s, args) for s in g)) return { f"{prefix}/all_correct_group_frac": all_correct / n_groups, f"{prefix}/all_incorrect_group_frac": all_incorrect / n_groups, diff --git a/miles/utils/metric_utils.py b/miles/utils/metric_utils.py index 66292c79e7..839b50e25f 100644 --- a/miles/utils/metric_utils.py +++ b/miles/utils/metric_utils.py @@ -118,3 +118,8 @@ def compute_rollout_step(args, rollout_id): if args.wandb_always_use_train_step: return rollout_id * args.rollout_batch_size * args.n_samples_per_prompt // args.global_batch_size return rollout_id + + +def compute_samples_seen(args, rollout_id: int) -> int: + """Cumulative samples through (and including) rollout `rollout_id` (0-indexed).""" + return args.rollout_batch_size * args.n_samples_per_prompt * (rollout_id + 1) diff --git a/miles/utils/wandb_utils.py b/miles/utils/wandb_utils.py index c29fcc7eaa..86062a9759 100644 --- a/miles/utils/wandb_utils.py +++ b/miles/utils/wandb_utils.py @@ -175,3 +175,12 @@ def _init_wandb_common(): # rollout counter. Declared after the "train/*" wildcard so the specific name # isn't inadvertently treated as step-metric'd against train/step. wandb.define_metric("train/rollout_id") + # Bare step counters — co-logged so one panel can plot all three as time series + # (useful for spotting non-monotone train/step jumps under dynamic batching). + wandb.define_metric("samples_seen", step_metric="rollout/step") + wandb.define_metric("train_step", step_metric="train/step") + wandb.define_metric("rollout_step", step_metric="rollout/step") + # Bare reward/response_stats mirrors — stripped of the rollout/ prefix by + # _log_rollout_data so the panels appear at the top level in W&B. + wandb.define_metric("reward/*", step_metric="rollout/step") + wandb.define_metric("response_stats/*", step_metric="rollout/step") diff --git a/scripts/models/xllm-8B.sh b/scripts/models/xllm-8B.sh deleted file mode 100644 index ba59296724..0000000000 --- a/scripts/models/xllm-8B.sh +++ /dev/null @@ -1,20 +0,0 @@ -# xLLM 8B dense GQA model arguments. -MODEL_ARGS=( - --swiglu - --num-layers 36 - --hidden-size 4096 - --ffn-hidden-size 12288 - --num-attention-heads 32 - --group-query-attention - --num-query-groups 8 - --kv-channels 128 - --disable-bias-linear - --normalization RMSNorm - --norm-epsilon 1e-6 - --layernorm-num-groups 4 - --position-embedding-type rope - --rotary-percent 1.0 - --rotary-base 10000000 - --untie-embeddings-and-output-weights - --vocab-size 250624 -)