Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions miles/backends/experimental/fsdp_utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful

from miles.backends.training_utils.log_utils import (
init_train_step_counter,
save_train_step_counter,
)

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -183,6 +188,9 @@ def finalize_load(actor: Any, checkpoint_payload: dict[str, Any] | None) -> None
if getattr(actor.args, "start_rollout_id", None) is None:
actor.args.start_rollout_id = iteration

if dist.get_rank() == 0:
init_train_step_counter(actor.args.load, iteration)

torch.cuda.synchronize()
dist.barrier()

Expand Down Expand Up @@ -245,6 +253,7 @@ def save(actor: Any, iteration: int) -> None:

tracker_file = base_dir / "latest_checkpointed_iteration.txt"
tracker_file.write_text(str(step_id))
save_train_step_counter(actor.args.save, step_id)
logger.info(f"[FSDP] Saved checkpoint to {checkpoint_dir}")

dist.barrier()
13 changes: 12 additions & 1 deletion miles/backends/megatron_utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@

from ..training_utils.ci_utils import check_grad_norm, check_kl
from ..training_utils.data import DataIterator, get_batch
from ..training_utils.log_utils import aggregate_forward_results, aggregate_train_losses, log_train_step
from ..training_utils.log_utils import (
aggregate_forward_results,
aggregate_train_losses,
init_train_step_counter,
log_train_step,
save_train_step_counter,
)
from ..training_utils.loss import loss_function
from ..training_utils.parallel import get_parallel_state
from .checkpoint import load_checkpoint, save_checkpoint, save_checkpoint_with_lora
Expand Down Expand Up @@ -726,6 +732,8 @@ def save(
)

clear_memory()
if is_megatron_main_rank():
save_train_step_counter(args.save, iteration)

if hashes is not None:
save_model_hashes(args, model, iteration, hashes)
Expand Down Expand Up @@ -828,4 +836,7 @@ def initialize_model_and_optimizer(

opt_param_scheduler.step(increment=iteration * args.global_batch_size)

if is_megatron_main_rank():
init_train_step_counter(args.load, iteration)

return model, optimizer, opt_param_scheduler, iteration
12 changes: 12 additions & 0 deletions miles/backends/training_utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,23 @@ def get_batch(
parallel_state = get_parallel_state()

assert "tokens" in keys
has_domains = "domains" in data_iterator.rollout_data
if has_domains and "domains" not in keys:
keys = list(keys) + ["domains"]
batch = data_iterator.get_next(keys)

if "dynamic_global_batch_size" in data_iterator.rollout_data:
batch["dynamic_global_batch_size"] = data_iterator.rollout_data["dynamic_global_batch_size"]

# Canonical domain set, cached on the iterator so every microbatch emits the
# same list (aggregate_train_losses keys positionally on the first microbatch).
if has_domains:
if not hasattr(data_iterator, "_all_domains_cache"):
data_iterator._all_domains_cache = sorted(
{d for d in data_iterator.rollout_data["domains"] if d}
)
batch["all_domains"] = data_iterator._all_domains_cache

tokens = batch["tokens"]
# use 0 as the pad token id should be fine?
pad_token_id = 0
Expand Down
51 changes: 46 additions & 5 deletions miles/backends/training_utils/log_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from argparse import Namespace
from math import isclose
from pathlib import Path

import numpy as np
import psutil
Expand Down Expand Up @@ -51,6 +52,37 @@
"returns": "reward",
}

# Cumulative train-step counter across all rollouts. The previous formula
# `rollout_id * num_steps_per_rollout + step_id` collides (and decreases) when
# `num_steps_per_rollout` shrinks across rollouts under dynamic batching, since
# each rollout uses its own current num_steps_per_rollout as a scaling factor.
# A simple monotone counter is invariant to that jitter. Persisted to a sidecar
# file next to the checkpoint so it survives process restart (otherwise train/step
# would dip to 0 in wandb on resume).
_TRAIN_STEP_COUNTER = 0


def init_train_step_counter(checkpoint_dir: str | None, iteration: int | None) -> None:
"""Restore the counter from a checkpoint sidecar; leaves it at 0 if absent or corrupt."""
global _TRAIN_STEP_COUNTER
if checkpoint_dir is None or iteration is None:
return
path = Path(checkpoint_dir) / f"iter_{int(iteration):07d}" / "train_step_counter.txt"
try:
_TRAIN_STEP_COUNTER = int(path.read_text().strip())
except (OSError, ValueError):
pass


def save_train_step_counter(checkpoint_dir: str | None, iteration: int | None) -> None:
if checkpoint_dir is None or iteration is None:
return
path = Path(checkpoint_dir) / f"iter_{int(iteration):07d}" / "train_step_counter.txt"
try:
path.write_text(str(_TRAIN_STEP_COUNTER))
except OSError as e:
logger.warning(f"Failed to persist train-step counter: {e}")


def gather_log_data(
metric_name: str,
Expand Down Expand Up @@ -102,6 +134,8 @@ def gather_log_data(
# Calculate step once to avoid duplication
step = compute_rollout_step(args, rollout_id)
reduced_log_dict["rollout/step"] = step
if metric_name == "rollout":
reduced_log_dict["rollout_step"] = step
reduced_log_dict["train/rollout_id"] = rollout_id
tracking_utils.log(args, reduced_log_dict, step_key="rollout/step")

Expand Down Expand Up @@ -170,6 +204,7 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc
"dynamic_global_batch_size",
"weight_versions",
"metadata",
"domains",
]:
continue
# Upload per sample mean for each rollout value
Expand Down Expand Up @@ -492,7 +527,9 @@ def log_train_step(
Returns:
The formatted log_dict (for CI tests or other uses).
"""
accumulated_step_id = rollout_id * num_steps_per_rollout + step_id
global _TRAIN_STEP_COUNTER
accumulated_step_id = _TRAIN_STEP_COUNTER
_TRAIN_STEP_COUNTER += 1
role_tag = "" if role == "actor" else f"{role}-"

log_dict_out = {
Expand Down Expand Up @@ -525,7 +562,7 @@ def log_train_step(
# cross-plotted against rollout-side axes in the wandb UI.
log_dict_out["train/rollout_id"] = rollout_id
log_dict_out["train/step_in_rollout"] = step_id
log_dict_out["rollout/step"] = compute_rollout_step(args, rollout_id)
log_dict_out["train_step"] = accumulated_step_id

# Emit top-level grouped copies for W&B panel organization (existing train/ keys unchanged)
grouped_additions = {}
Expand All @@ -534,9 +571,13 @@ def log_train_step(
if not full_key.startswith(prefix):
continue
bare_key = full_key[len(prefix):]
if bare_key in _TRAIN_METRIC_GROUPS:
for group in _TRAIN_METRIC_GROUPS[bare_key]:
grouped_additions[f"{group}/{bare_key}"] = val
# Per-domain keys arrive as "<metric>/<domain>" — route to "<group>/<domain>/<metric>".
metric_name, sep, domain = bare_key.rpartition("/")
lookup = metric_name if (sep and metric_name in _TRAIN_METRIC_GROUPS) else bare_key
if lookup in _TRAIN_METRIC_GROUPS:
suffix = f"{domain}/{metric_name}" if lookup == metric_name else bare_key
for group in _TRAIN_METRIC_GROUPS[lookup]:
grouped_additions[f"{group}/{suffix}"] = val
elif bare_key.startswith("lr-pg_"):
grouped_additions[f"optimization/{bare_key}"] = val
log_dict_out.update(grouped_additions)
Expand Down
55 changes: 52 additions & 3 deletions miles/backends/training_utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,11 @@ def policy_loss_function(
else:
pg_loss_reducer = sum_of_sample_mean

# Saved for per-domain fan-out (reducers below overwrite these names with scalars).
_pg_loss_per_token = pg_loss
_pg_clipfrac_per_token = pg_clipfrac
_ppo_kl_per_token = ppo_kl

pg_loss = pg_loss_reducer(pg_loss)
pg_clipfrac = sum_of_sample_mean(pg_clipfrac)
ppo_kl = sum_of_sample_mean(ppo_kl)
Expand Down Expand Up @@ -691,18 +696,24 @@ def policy_loss_function(
# Train-inference mismatch: compare inference engine vs FSDP at rollout time
train_rollout_logprob_abs_diff = None
train_rollout_logprob_diff = None
_train_rollout_logprob_abs_per_token = None
_train_rollout_logprob_signed_per_token = None
if "rollout_log_probs" in batch and batch["rollout_log_probs"]:
rollout_log_probs_cat = torch.cat(batch["rollout_log_probs"], dim=0)
log_probs_batch_cat = torch.cat(batch["log_probs"], dim=0)
train_rollout_logprob_abs_diff = sum_of_sample_mean((old_log_probs - rollout_log_probs_cat).abs()).clone().detach()
_train_rollout_logprob_abs_per_token = (old_log_probs - rollout_log_probs_cat).abs()
# signed: log π(inf) − log π(fsdp rollout)
train_rollout_logprob_diff = sum_of_sample_mean(rollout_log_probs_cat - log_probs_batch_cat).clone().detach()
_train_rollout_logprob_signed_per_token = rollout_log_probs_cat - log_probs_batch_cat
train_rollout_logprob_abs_diff = sum_of_sample_mean(_train_rollout_logprob_abs_per_token).clone().detach()
train_rollout_logprob_diff = sum_of_sample_mean(_train_rollout_logprob_signed_per_token).clone().detach()

# KL vs reference model — always log when ref present, regardless of use_kl_loss
ref_kl_metric = None
_ref_kl_per_token = None
if "ref_log_probs" in batch and batch["ref_log_probs"]:
ref_log_probs_cat = torch.cat(batch["ref_log_probs"], dim=0)
ref_kl_metric = sum_of_sample_mean(log_probs - ref_log_probs_cat).clone().detach()
_ref_kl_per_token = log_probs - ref_log_probs_cat
ref_kl_metric = sum_of_sample_mean(_ref_kl_per_token).clone().detach()

reported_loss = {
"loss": loss.clone().detach(),
Expand Down Expand Up @@ -736,6 +747,44 @@ def policy_loss_function(
if args.use_opsm:
reported_loss["opsm_clipfrac"] = opsm_clipfrac

# Per-domain fan-out: activated by batch["domains"] (set when samples carry
# metadata["domain"]). batch["all_domains"] is cached on DataIterator so every
# microbatch emits the same key set (aggregate_train_losses keys positionally).
# grad_norm isn't split: backward() has already mixed gradients.
if batch.get("domains") and batch.get("all_domains"):
per_token = {
"log_probs": log_probs,
"old_log_probs": old_log_probs,
"pg_loss": _pg_loss_per_token,
"pg_clipfrac": _pg_clipfrac_per_token,
"ppo_kl": _ppo_kl_per_token,
"entropy_loss": entropy,
}
if _ref_kl_per_token is not None:
per_token["ref_kl"] = _ref_kl_per_token
if _train_rollout_logprob_signed_per_token is not None:
per_token["train_rollout_logprob_diff"] = _train_rollout_logprob_signed_per_token
per_token["train_rollout_logprob_abs_diff"] = _train_rollout_logprob_abs_per_token
if args.get_mismatch_metrics or args.use_tis:
per_token["ois"] = ois
per_token.update(tis_metrics)

for d in batch["all_domains"]:
masked = [
lm if dd == d else torch.zeros_like(lm)
for dd, lm in zip(batch["domains"], batch["loss_masks"], strict=False)
]
reducer = get_sum_of_sample_mean(
total_lengths, response_lengths, masked,
args.calculate_per_token_loss, args.qkv_format, max_seq_lens,
loss_agg_mode=getattr(args, "loss_agg_mode", None),
)
for name, t in per_token.items():
reported_loss[f"{name}/{d}"] = reducer(t).clone().detach()
reported_loss[f"loss/{d}"] = (
reported_loss[f"pg_loss/{d}"] - args.entropy_coef * reported_loss[f"entropy_loss/{d}"]
)

return loss, reported_loss


Expand Down
Loading
Loading