def validate_rollout_for_grpo_training_step(
    args,
    rollout_data,
    *,
    rollout_id=None,
    where="train_actor.begin",
    logger=None,
    require_log_probs=False,
):
    """Validate one rank's rollout batch before a GRPO training step.

    The check is purely local: it issues no collectives and does not mutate
    ``rollout_data``, so every rank can run it independently. Diagnostics are
    logged *before* raising so the first failing rank's log shows the root
    cause of a later desync.

    Args:
        args: Namespace-like config object; only read through ``getattr``.
        rollout_data: Mapping of field name -> per-sample sequence. The keys
            ``rewards``, ``response_lengths``, ``total_lengths`` and
            ``loss_masks`` are required; ``tokens``/``input_ids``,
            ``max_seq_lens`` and the per-response vector fields are checked
            only when present.
        rollout_id: Optional id added to every log line.
        where: Free-form call-site label used in log lines and the exception.
        logger: Object exposing ``warning``/``error`` methods. When ``None``
            the messages are printed to stdout instead.
        require_log_probs: When true, the active log-prob key
            (``rollout_log_probs`` if ``args.use_rollout_logprobs`` else
            ``log_probs``) must be present.

    Raises:
        ValueError: If at least one structural problem was found. The message
            carries the error count; details are in the log.
    """
    import math
    import socket
    import traceback

    import torch
    import torch.distributed as dist

    errors = []
    warnings = []
    _add_error = errors.append
    _add_warning = warnings.append

    def _safe_get_parallel_state():
        try:
            # get_parallel_state is expected to be a module-level import of the
            # enclosing file; any failure (including a missing name) only
            # degrades the diagnostics, it never aborts validation.
            return get_parallel_state()
        except Exception as e:
            warnings.append(f"get_parallel_state() failed: {type(e).__name__}: {e}")
            return None

    ps = _safe_get_parallel_state()

    def _maybe_attr(obj, *path, default=None):
        """Follow an attribute path, returning ``default`` on any failure."""
        cur = obj
        for p in path:
            try:
                cur = getattr(cur, p)
            except Exception:
                return default
        return cur

    def _rank_info():
        """Build a best-effort ``host=... global_rank=...`` tag for log lines."""
        parts = []
        try:
            parts.append(f"host={socket.gethostname()}")
        except Exception:
            pass

        try:
            if dist.is_available() and dist.is_initialized():
                parts.append(f"global_rank={dist.get_rank()}/{dist.get_world_size()}")
            else:
                parts.append("global_rank=dist_not_initialized")
        except Exception as e:
            parts.append(f"global_rank=unavailable:{type(e).__name__}")

        if ps is not None:
            for name in ("dp", "intra_dp", "cp", "tp", "pp", "ep"):
                group_state = getattr(ps, name, None)
                if group_state is not None:
                    r = getattr(group_state, "rank", "?")
                    s = getattr(group_state, "size", "?")
                    parts.append(f"{name}_rank={r}/{s}")

        # Fallbacks from args. These are less authoritative than ps.
        for arg_name in (
            "data_parallel_size",
            "context_parallel_size",
            "tensor_model_parallel_size",
            "pipeline_model_parallel_size",
            "expert_model_parallel_size",
        ):
            if hasattr(args, arg_name):
                parts.append(f"args.{arg_name}={getattr(args, arg_name)}")

        return " ".join(parts)

    rank_info = _rank_info()

    def _log(level, msg):
        prefix = f"ROLLOUT_VALIDATE {where}"
        if rollout_id is not None:
            prefix += f" rollout_id={rollout_id}"
        prefix += f" {rank_info} :: {msg}"

        if logger is not None:
            getattr(logger, level)(prefix)
        else:
            print(prefix, flush=True)

    def _is_seq(x):
        return isinstance(x, (list, tuple))

    def _present(key):
        return key in rollout_data and rollout_data[key] is not None

    def _numel(x):
        if torch.is_tensor(x):
            return int(x.numel())
        if isinstance(x, (list, tuple)):
            return len(x)
        return None

    def _ndim(x):
        if torch.is_tensor(x):
            return int(x.ndim)
        if isinstance(x, (list, tuple)):
            return 1
        return None

    def _finite_tensor_or_list(x):
        """Return True only if every value is a finite number."""
        try:
            if torch.is_tensor(x):
                return bool(torch.isfinite(x.float()).all().item())
            if isinstance(x, (list, tuple)):
                return all(math.isfinite(float(v)) for v in x)
            if isinstance(x, (float, int)):
                return math.isfinite(float(x))
            return False
        except Exception:
            return False

    def _sum_float(x):
        if torch.is_tensor(x):
            return float(x.float().sum().item())
        if isinstance(x, (list, tuple)):
            return float(sum(float(v) for v in x))
        return float(x)

    def _summarize_vector_list(key, limit=3):
        """Describe shape/dtype/device of the first ``limit`` entries of a field."""
        if not _present(key):
            return f"{key}=MISSING"
        xs = rollout_data[key]
        if not _is_seq(xs):
            return f"{key}=BAD_TYPE({type(xs).__name__})"

        shapes = []
        dtypes = []
        devices = []
        for x in xs[:limit]:
            if torch.is_tensor(x):
                shapes.append(tuple(x.shape))
                dtypes.append(str(x.dtype))
                devices.append(str(x.device))
            elif isinstance(x, (list, tuple)):
                shapes.append((len(x),))
                dtypes.append(type(x[0]).__name__ if len(x) else "empty")
                devices.append("python")
            else:
                shapes.append(type(x).__name__)
                dtypes.append(type(x).__name__)
                devices.append("python")
        return (
            f"{key}: len={len(xs)} first_shapes={shapes} "
            f"first_dtypes={dtypes} first_devices={devices}"
        )

    def _basic_batch_summary():
        """One-line overview of the batch used in success/failure log lines."""
        lines = ["keys=" + ",".join(sorted(rollout_data.keys()))]

        for key in (
            "rewards",
            "response_lengths",
            "total_lengths",
            "loss_masks",
            "tokens",
            "input_ids",
            "max_seq_lens",
            "log_probs",
            "rollout_log_probs",
            "ref_log_probs",
            "values",
            "advantages",
            "returns",
        ):
            lines.append(_summarize_vector_list(key))

        # Numeric aggregates: (field, label in output, per-item converter,
        # format spec for the sum). Each is isolated so one malformed field
        # cannot suppress the others.
        aggregates = (
            ("response_lengths", "response_lengths", int, "{}"),
            ("total_lengths", "total_lengths", int, "{}"),
            ("loss_masks", "loss_mask_sums", _sum_float, "{:.1f}"),
            ("rewards", "rewards", float, "{:.6g}"),
        )
        for key, label, convert, sum_fmt in aggregates:
            try:
                if _present(key) and _is_seq(rollout_data[key]):
                    vals = [convert(v) for v in rollout_data[key]]
                    lines.append(
                        f"{label}: count={len(vals)} sum={sum_fmt.format(sum(vals))} "
                        f"min={min(vals) if vals else None} max={max(vals) if vals else None}"
                    )
            except Exception as e:
                lines.append(f"{key} aggregate failed: {type(e).__name__}: {e}")

        return " | ".join(lines)

    def _flush_warnings(limit=50):
        """Log accumulated warnings, capped so a bad batch cannot flood the log."""
        for w in warnings[:limit]:
            _log("warning", w)
        if len(warnings) > limit:
            _log("warning", f"... {len(warnings) - limit} additional warnings omitted")

    def _raise_if_errors(*, with_trace=False, limit=100):
        """Log warnings, batch summary and errors, then raise; no-op without errors."""
        if not errors:
            return
        _flush_warnings()
        _log("error", "summary_before_failure " + _basic_batch_summary())
        for e in errors[:limit]:
            _log("error", "failure " + e)
        if len(errors) > limit:
            _log("error", f"... {len(errors) - limit} additional errors omitted")
        if with_trace:
            _log("error", "trace_at_validation_failure\n" + "".join(traceback.format_stack(limit=12)))
        raise ValueError(f"{where}: rollout validation failed with {len(errors)} error(s); see logs above")

    # ------------------------------------------------------------------
    # Start diagnostics. Logged at warning level so they stay visible when
    # the logger is configured above INFO.
    # ------------------------------------------------------------------

    _log(
        "warning",
        "start "
        f"advantage_estimator={getattr(args, 'advantage_estimator', None)} "
        f"normalize_advantages={getattr(args, 'normalize_advantages', None)} "
        f"use_rollout_logprobs={getattr(args, 'use_rollout_logprobs', None)} "
        f"use_critic={getattr(args, 'use_critic', None)} "
        f"qkv_format={getattr(args, 'qkv_format', None)} "
        f"compute_advantages_and_returns={getattr(args, 'compute_advantages_and_returns', None)} "
        f"n_samples_per_prompt={getattr(args, 'n_samples_per_prompt', None)} "
        f"generate_multi_samples={getattr(args, 'generate_multi_samples', None)}",
    )

    # ------------------------------------------------------------------
    # Required fields. Fail fast: nothing below is meaningful without them.
    # ------------------------------------------------------------------

    required = ("rewards", "response_lengths", "total_lengths", "loss_masks")
    for key in required:
        if not _present(key):
            _add_error(f"missing required key {key!r}")
        elif not _is_seq(rollout_data[key]):
            _add_error(f"{key!r} must be list/tuple, got {type(rollout_data[key]).__name__}")

    _raise_if_errors()

    n = len(rollout_data["rewards"])
    if n == 0:
        _add_error("empty rollout batch: len(rewards)=0")

    for key in ("response_lengths", "total_lengths", "loss_masks"):
        got = len(rollout_data[key])
        if got != n:
            _add_error(f"{key!r} length mismatch: got {got}, expected {n}")

    token_key = None
    if _present("tokens"):
        token_key = "tokens"
    elif _present("input_ids"):
        token_key = "input_ids"

    if token_key is None:
        _add_warning("neither 'tokens' nor 'input_ids' present; cannot check total_lengths against token tensors")
    elif not _is_seq(rollout_data[token_key]):
        _add_error(f"{token_key!r} must be list/tuple, got {type(rollout_data[token_key]).__name__}")
    elif len(rollout_data[token_key]) != n:
        _add_error(f"{token_key!r} length mismatch: got {len(rollout_data[token_key])}, expected {n}")

    # Per-sample indexing below relies on all fields having exactly n entries.
    _raise_if_errors()

    # Index-aligned with the batch; None marks a sample whose length could not
    # be parsed. Keeping the slot (instead of appending only on success) stops
    # one bad sample from shifting every later index in the checks below.
    response_lengths = [None] * n
    total_lengths = [None] * n

    for i in range(n):
        # reward
        try:
            r = float(rollout_data["rewards"][i])
            if not math.isfinite(r):
                _add_error(f"rewards[{i}] is not finite: {r}")
        except Exception as e:
            _add_error(f"rewards[{i}] is not float-like: {type(e).__name__}: {e}")

        # lengths
        try:
            resp = int(rollout_data["response_lengths"][i])
            response_lengths[i] = resp
            total = int(rollout_data["total_lengths"][i])
            total_lengths[i] = total
        except Exception as e:
            _add_error(f"bad lengths at sample {i}: {type(e).__name__}: {e}")
            continue

        if resp <= 0:
            _add_error(f"response_lengths[{i}] must be > 0, got {resp}")
        if total <= 0:
            _add_error(f"total_lengths[{i}] must be > 0, got {total}")
        if resp > total:
            _add_error(f"response_lengths[{i}]={resp} > total_lengths[{i}]={total}")

        # tokens/input_ids (type and length n were verified above)
        if token_key is not None:
            tok = rollout_data[token_key][i]
            tok_n = _numel(tok)
            if tok_n is None:
                _add_error(f"{token_key}[{i}] bad type: {type(tok).__name__}")
            elif tok_n != total:
                _add_error(f"{token_key}[{i}] length {tok_n} != total_lengths[{i}] {total}")
            if torch.is_tensor(tok) and tok.ndim != 1:
                _add_error(f"{token_key}[{i}] must be 1D, got shape={tuple(tok.shape)}")

        # masks
        mask = rollout_data["loss_masks"][i]
        mask_n = _numel(mask)
        mask_ndim = _ndim(mask)

        if mask_n is None:
            _add_error(f"loss_masks[{i}] bad type: {type(mask).__name__}")
            continue

        if mask_ndim != 1:
            _add_error(f"loss_masks[{i}] must be 1D, got ndim={mask_ndim}")

        if mask_n != resp:
            _add_error(f"loss_masks[{i}] length {mask_n} != response_lengths[{i}] {resp}")

        if not _finite_tensor_or_list(mask):
            _add_error(f"loss_masks[{i}] contains NaN/Inf or non-numeric values")
            continue

        mask_sum = _sum_float(mask)
        if mask_sum <= 0:
            _add_error(f"loss_masks[{i}] has no active tokens, sum={mask_sum}, response_len={resp}")
        if mask_sum > resp:
            # Warning-only: float/weighted masks can legitimately have sum > resp.
            _add_warning(f"loss_masks[{i}] sum={mask_sum} exceeds response_len={resp} (expected for float/weighted masks)")

        if torch.is_tensor(mask):
            # Binary check: warning, not fatal, because masks may be float.
            try:
                is_binary = bool(torch.all((mask == 0) | (mask == 1)).item())
                if not is_binary:
                    _add_warning(f"loss_masks[{i}] is not binary 0/1")
            except Exception as e:
                _add_warning(f"binary check failed for loss_masks[{i}]: {type(e).__name__}: {e}")

    # max_seq_lens if present.
    if _present("max_seq_lens"):
        xs = rollout_data["max_seq_lens"]
        if not _is_seq(xs):
            _add_error(f"max_seq_lens must be list/tuple, got {type(xs).__name__}")
        elif len(xs) != n:
            _add_error(f"max_seq_lens length {len(xs)} != expected {n}")
        else:
            for i, x in enumerate(xs):
                try:
                    msl = int(x)
                except Exception as e:
                    _add_error(f"max_seq_lens[{i}] is not int-like: {type(e).__name__}: {e}")
                    continue
                if msl <= 0:
                    _add_error(f"max_seq_lens[{i}] must be > 0, got {msl}")
                elif total_lengths[i] is not None and msl < total_lengths[i]:
                    _add_error(f"max_seq_lens[{i}]={msl} < total_lengths[{i}]={total_lengths[i]}")

    # Optional per-response vector fields.
    def check_vector_list(key, expected_lengths):
        """Check an optional per-sample field for type, rank, length and finiteness.

        ``expected_lengths`` is indexed by sample; a missing or ``None`` entry
        disables the length comparison for that sample.
        """
        if not _present(key):
            return
        xs = rollout_data[key]
        if not _is_seq(xs):
            _add_error(f"{key} must be list/tuple, got {type(xs).__name__}")
            return
        if len(xs) != n:
            _add_error(f"{key} length {len(xs)} != expected {n}")
            return

        for i, x in enumerate(xs):
            x_n = _numel(x)
            x_ndim = _ndim(x)
            if x_n is None:
                _add_error(f"{key}[{i}] bad type: {type(x).__name__}")
                continue
            if x_ndim != 1:
                _add_error(f"{key}[{i}] must be 1D, got ndim={x_ndim}")
            expected = expected_lengths[i] if i < len(expected_lengths) else None
            if expected is not None and x_n != expected:
                _add_error(f"{key}[{i}] length {x_n} != response_lengths[{i}] {expected}")
            if not _finite_tensor_or_list(x):
                _add_error(f"{key}[{i}] contains NaN/Inf or non-numeric values")

    logprob_key = "rollout_log_probs" if getattr(args, "use_rollout_logprobs", False) else "log_probs"

    if require_log_probs and not _present(logprob_key):
        _add_error(f"require_log_probs=True but {logprob_key!r} is missing/None")

    # Per the original author's note, rollout_log_probs are CP-sliced upstream
    # (by get_rollout_data) before this validator runs, so their per-sample
    # lengths differ from response_lengths[i] when cp_size > 1. Passing an
    # empty expectation list skips only the length comparison in that case.
    cp_size = int(_maybe_attr(ps, "cp", "size", default=1) or 1)
    for key in ("log_probs", "ref_log_probs", "values", "advantages", "returns"):
        check_vector_list(key, response_lengths)
    check_vector_list("rollout_log_probs", [] if cp_size > 1 else response_lengths)

    # GRPO grouping diagnostics. Warning only because dynamic filtering can alter counts.
    n_samples_per_prompt = int(getattr(args, "n_samples_per_prompt", 0) or 0)
    if n_samples_per_prompt > 0 and n % n_samples_per_prompt != 0:
        _add_warning(f"sample count {n} not divisible by n_samples_per_prompt={n_samples_per_prompt}")

    grpo_group_size = int(getattr(args, "grpo_group_size", 0) or 0)
    if grpo_group_size > 0 and n % grpo_group_size != 0:
        _add_warning(f"sample count {n} not divisible by grpo_group_size={grpo_group_size}")

    # Advantage normalization is a cross-rank operation: every rank that
    # reaches it must agree on whether log-probs are present. Record the local
    # state so a mismatch can be spotted by comparing rank logs.
    if getattr(args, "compute_advantages_and_returns", False) and getattr(args, "normalize_advantages", False):
        if _present(logprob_key):
            _add_warning(
                f"normalization path will enter distributed whitening with {logprob_key}; "
                "if another rank is missing this key, it can skip or fail differently"
            )

    _raise_if_errors(with_trace=True)

    _flush_warnings()
    _log("warning", "success " + _basic_batch_summary())
data to {path}") path.parent.mkdir(parents=True, exist_ok=True) - # TODO may improve the format if evaluation: - dump_data = dict( - samples=[sample.to_dict() for dataset_name, info in data.items() for sample in info["samples"]] - ) + samples = [sample.to_dict() for dataset_name, info in data.items() for sample in info["samples"]] else: - dump_data = dict( - samples=[sample.to_dict() for sample in data], - ) + samples = [sample.to_dict() for sample in data] - torch.save(dict(rollout_id=rollout_id, **dump_data), path) + save_format = getattr(self.args, "save_rollout_format", "pt") + + if save_format == "parquet": + import pyarrow as pa + import pyarrow.parquet as pq + path = path.with_suffix(".parquet") + table = pa.Table.from_pylist(samples) + table = table.replace_schema_metadata({b"rollout_id": str(rollout_id).encode()}) + pq.write_table(table, path, compression="snappy") + else: + torch.save(dict(rollout_id=rollout_id, samples=samples), path) + + # Rolling retention: delete files that aged out of the window (training rollouts only). + # Walk backward from the oldest allowed id so that a restart with a smaller N + # cleans up all accumulated stale files, not just one. 
def _safe_len(x):
    """Return ``len(x)``, or 0 when *x* does not support ``len``."""
    try:
        return len(x)
    except Exception:
        return 0


def _estimate_payload_bytes(obj):
    """Cheaply estimate the in-memory payload of a nested structure, in bytes.

    Walks dicts, lists, tuples and ranges recursively without copying data.
    Objects exposing ``numel()``/``element_size()`` or ``nbytes`` are sized by
    their buffer; ``bytes``/``bytearray``/``str`` by their length; everything
    else by ``sys.getsizeof``. Container overhead is not included, so the
    result is only meant for comparing shards against each other.
    """
    seen = set()

    def walk(o):
        if o is None:
            return 0

        # Dedup by identity only for non-scalar objects (shared buffers,
        # cycles). Plain Python numbers are excluded: CPython reuses the same
        # object for equal small ints, so deduping them would count a list of
        # repeated values once and make the estimate track the number of
        # distinct values instead of the payload length.
        if not isinstance(o, (bool, int, float)):
            oid = id(o)
            if oid in seen:
                return 0
            seen.add(oid)

        # Common tensor/array cases.
        if hasattr(o, "numel") and hasattr(o, "element_size"):
            try:
                return int(o.numel()) * int(o.element_size())
            except Exception:
                pass

        if hasattr(o, "nbytes"):
            try:
                return int(o.nbytes)
            except Exception:
                pass

        if isinstance(o, (bytes, bytearray)):
            return len(o)

        if isinstance(o, str):
            # Character count, used as a cheap proxy for the encoded size.
            return len(o)

        if isinstance(o, dict):
            return sum(walk(k) + walk(v) for k, v in o.items())

        if isinstance(o, (list, tuple, range)):
            return sum(walk(v) for v in o)

        try:
            return sys.getsizeof(o)
        except Exception:
            return 0

    return walk(obj)


def _stat(xs):
    """Summarize a list of numbers as count/sum/min/max/avg (avg to 1 decimal)."""
    if not xs:
        return {"n": 0, "sum": 0, "min": 0, "max": 0, "avg": 0.0}
    s = sum(xs)
    return {
        "n": len(xs),
        "sum": s,
        "min": min(xs),
        "max": max(xs),
        "avg": round(s / len(xs), 1),
    }
def _ratio(xs):
    """Return the max/min ratio of the non-None entries of *xs*.

    The ratio is rounded to three decimals. An input with no usable entries
    yields 0.0; a falsy minimum (e.g. 0) yields infinity.
    """
    values = [v for v in xs if v is not None]
    if not values:
        return 0.0
    smallest, largest = min(values), max(values)
    if not smallest:
        return float("inf")
    return round(largest / smallest, 3)
"ROLLOUT_DP_IMBALANCE " + "dp_size=%s total_samples=%s total_tokens=%s " + "sample_counts=%s sample_ratio=%s " + "token_sums=%s token_ratio=%s " + "payload_mbs=%s payload_ratio=%s", + dp_size, + len(total_lengths), + sum(total_lengths), + sample_counts, + _ratio(sample_counts), + token_sums, + _ratio(token_sums), + payload_mbs, + _ratio(payload_mbs), + ) + return rollout_data_refs diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 717c2b6980..1d1ada930a 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -1260,6 +1260,19 @@ def add_debug_arguments(parser): default=None, help="Subsample a portion of the debug rollout data for faster debugging.", ) + parser.add_argument( + "--save-rollout-format", + type=str, + choices=["pt", "parquet"], + default="pt", + help="Serialization format for rollout debug files. 'parquet' replaces the .pt extension and writes a snappy-compressed parquet file readable by polars/pandas.", + ) + parser.add_argument( + "--save-rollout-retain-last-n", + type=int, + default=0, + help="Keep only the N most recent rollout files in the save directory, deleting the one that aged out after each step. 0 (default) keeps all files.", + ) parser.add_argument( "--debug-rollout-only", action="store_true",