diff --git a/cosmos_framework/inference/_test.py b/cosmos_framework/inference/_test.py index d1a61a9..569545b 100644 --- a/cosmos_framework/inference/_test.py +++ b/cosmos_framework/inference/_test.py @@ -13,6 +13,7 @@ from typing_extensions import Self from cosmos_framework.data.vfm.utils import VIDEO_RES_SIZE_INFO +from cosmos_framework.inference.metrics import compute_action_mse, compute_psnr from cosmos_framework.inference.args import ( IMAGE_ONLY_RESOLUTIONS, AspectRatio, @@ -94,6 +95,63 @@ def _get_video_dims(path: Path) -> tuple[int, int, int]: return int(stream.width), int(stream.height), frame_count +def _compute_video_metrics(gt_video_cthw_uint8: torch.Tensor, pred_path: Path, mode: str) -> dict[str, float]: + from cosmos_framework.inference.vision import read_media_frames + from cosmos_framework.utils import log + + pred, _ = read_media_frames(pred_path, max_frames=gt_video_cthw_uint8.shape[1] + 1) + pred = pred[..., : gt_video_cthw_uint8.shape[-2], : gt_video_cthw_uint8.shape[-1]] + gt = gt_video_cthw_uint8 + + if pred.shape != gt.shape: + if pred.shape[-2:] != gt.shape[-2:]: + raise ValueError(f"video spatial mismatch: gt {tuple(gt.shape)} vs pred {tuple(pred.shape)} ({pred_path})") + if mode == "vision": + min_t = min(gt.shape[1], pred.shape[1]) + log.info( + f"vision frame-count mismatch trimmed to {min_t} (gt T={gt.shape[1]}, pred T={pred.shape[1]}, " + f"{pred_path}); likely due to generation aligned to 4k+1 frames " + "(latent temporal factor 4 + 1 conditioning frame), while the on-disk GT " + "keeps the raw clip length." + ) + gt = gt[:, :min_t] + pred = pred[:, :min_t] + else: + raise ValueError(f"video shape mismatch: gt {tuple(gt.shape)} vs pred {tuple(pred.shape)} ({pred_path})") + + return {"psnr": compute_psnr(gt, pred)} + + +def _compute_action_metrics(gt_action_td: torch.Tensor, pred_action_list: list) -> dict[str, float]: + pred = torch.tensor(pred_action_list, dtype=torch.float32) + if pred.shape != gt_action_td.shape: + raise ValueError(f"action shape mismatch: gt {tuple(gt_action_td.shape)} vs pred {tuple(pred.shape)}") + return {"action_mse": compute_action_mse(gt_action_td, pred)} + + +def _compute_sample_metrics( + mode: ModelMode, + gt_video_cthw: torch.Tensor | None, + gt_action_td: torch.Tensor | None, + sample_output: SampleOutputs, + sample_dir: Path, + vision_extension: str, +) -> dict[str, float]: + out: dict[str, float] = {"mode": mode.value, "name": sample_dir.name} + if mode in (ModelMode.FORWARD_DYNAMICS, ModelMode.POLICY): + if gt_video_cthw is None: + raise ValueError(f"mode={mode.value!r} requires GT video but data_batch had none") + out.update(_compute_video_metrics(gt_video_cthw, sample_dir / f"vision{vision_extension}", mode.value)) + if mode in (ModelMode.INVERSE_DYNAMICS, ModelMode.POLICY): + pred_action = sample_output.outputs[0].content.get("action") if sample_output.outputs else None + if pred_action is None: + raise ValueError(f"mode={mode.value!r} requires predicted action but content has none") + if gt_action_td is None: + raise ValueError(f"mode={mode.value!r} requires GT action but data_batch had none") + out.update(_compute_action_metrics(gt_action_td, pred_action)) + return out + + def _omni_after_script(runner: ScriptRunner, cfg: ScriptConfig) -> None: inference_dir = runner.output_dir / "inference" sample_outputs_list = _check_inference_output([runner.input_dir / "omni/*json"], inference_dir) @@ -218,15 +276,10 @@ def _check_action_golden(sample_outputs: SampleOutputs, sample_dir: Path) -> lis if psnr_min is None and mse_max is None: return [] - from cosmos_framework.scripts.eval_utils import compute_sample_metrics - mode = ModelMode(sample_outputs.args["model_mode"]) gt_video = _load_canonical_gt_video(sample_dir) if mode in (ModelMode.FORWARD_DYNAMICS, ModelMode.POLICY) else None gt_action = _load_gt_action(extra["golden_action_path"], sample_dir) if mse_max is not None else None - # `compute_sample_metrics` parses mode from `name.split("/")[-2]`. - metrics = compute_sample_metrics( - f"{mode.value}/{sample_dir.name}", gt_video, gt_action, sample_outputs, sample_dir, ".mp4" - ) + metrics = _compute_sample_metrics(mode, gt_video, gt_action, sample_outputs, sample_dir, ".mp4") failures: list[str] = [] if psnr_min is not None and "psnr" in metrics: diff --git a/cosmos_framework/scripts/eval_utils.py b/cosmos_framework/scripts/eval_utils.py deleted file mode 100644 index bd6f043..0000000 --- a/cosmos_framework/scripts/eval_utils.py +++ /dev/null @@ -1,201 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -"""Helpers for `eval.py`: per-sample metric computation and aggregation.""" - -import json -from collections import defaultdict -from pathlib import Path -from typing import Any - -import numpy as np -import torch - -from cosmos_framework.inference.common.args import SampleOutputs -from cosmos_framework.inference.vision import read_media_frames -from cosmos_framework.utils import log -from cosmos_framework.data.vfm.action.transforms import remove_reflection_padding -from cosmos_framework.inference.metrics import compute_action_mse, compute_psnr - -VIDEO_MODES = {"forward_dynamics"} -ACTION_MODES = {"inverse_dynamics"} -BOTH_MODES = {"policy"} -ALL_MODES = VIDEO_MODES | ACTION_MODES | BOTH_MODES - - -def extract_gt_video(data_batch: dict) -> torch.Tensor | None: - """Snapshot the GT video as (C, T, H, W) uint8, trimmed to its content region if padded. - - Must be called BEFORE the inference pipeline runs — the model normalizes - `data_batch["video"]` in place from uint8 [0, 255] to float [-1, 1]. - """ - video = data_batch.get("video") - if video is None: - return None - gt_video = video[0].detach().clone() - image_size = data_batch.get("image_size") - if image_size is not None: - gt_video = remove_reflection_padding(gt_video, image_size[0]) - return gt_video - - -def extract_gt_action(data_batch: dict) -> torch.Tensor | None: - """Snapshot the GT action as a (T, D) float32 tensor, or None when absent.""" - action = data_batch.get("action", [None])[0] - if action is None: - return None - - raw_action_dim = data_batch.get("raw_action_dim", [None])[0] - if raw_action_dim is not None: - # If raw_action_dim is provided, it indicates that the GT action has been padded to a larger size. - # We trim the action to its original dimension before returning it. - raw_action_dim = int(raw_action_dim.item()) # remove batch dim and convert to int - assert action.shape[-1] >= raw_action_dim, ( - f"invalid raw_action_dim={raw_action_dim} for action with shape {action.shape}" - ) - action = action[..., :raw_action_dim] - - return action.detach().clone().float() - - -def _parse_mode_from_name(name: str) -> str: - parts = name.split("/") - if len(parts) < 2: - raise ValueError(f"unexpected sample name: {name!r}") - mode = parts[-2] - if mode not in ALL_MODES: - raise ValueError(f"unexpected mode {mode!r} in sample name {name!r}; expected one of {sorted(ALL_MODES)}") - return mode - - -def derive_match_key_and_group(pred_path: Path, predictions_dir: Path) -> tuple[str, str]: - """Path → ``(match_key, group)``. Used by vision eval to pair predictions with GT. - - For ``inference.py``-style outputs (basename ``vision.*``), ``match_key`` is the - parent directory name and ``group`` is the path between *predictions_dir* and - that directory. Otherwise ``match_key`` is the filename stem. - - Examples (with ``predictions_dir=/root``): - ``/root/t2v/episode_0/vision.mp4`` → ``("episode_0", "t2v")`` - ``/root/sub/foo.mp4`` → ``("foo", "sub")`` - """ - pred_path = pred_path.resolve() - predictions_dir = predictions_dir.resolve() - if not pred_path.is_relative_to(predictions_dir): - raise ValueError(f"pred_path {pred_path} is not under predictions_dir {predictions_dir}") - rel = pred_path.relative_to(predictions_dir) - parts = rel.parts - if pred_path.name.startswith("vision."): - if len(parts) < 2: - raise ValueError(f"expected //vision.* under predictions_dir, got rel={rel}") - match_key = parts[-2] - group_parts = parts[:-2] - else: - match_key = pred_path.stem - group_parts = parts[:-1] - group = "/".join(group_parts) - return match_key, group - - -def compute_video_metrics( - gt_video_cthw_uint8: torch.Tensor, - pred_path: Path, - mode: str, -) -> dict[str, float]: - """Compute per-clip PSNR. Temporal-mismatch policy depends on *mode*: - - - ``"vision"``: lenient. If ``T_gt != T_pred`` (after the H/W top-left crop), trim both to - ``min(T_gt, T_pred)`` from the start (VFM generation is aligned to ``4k+1`` frames — - latent temporal factor 4 + 1 conditioning frame — while the on-disk GT keeps the raw - clip length, so small T deltas (e.g. GT=96, pred=93) are expected and treating them as - hard errors loses the entire eval) and log an info line. - - any other mode (``forward_dynamics``, ``policy``, ...): strict. Pred T is fixed by - the action chunk size; a mismatch indicates a real bug. - - Spatial (H/W) mismatch always errors — the existing top-left crop of pred to GT's - H/W stays in place; mismatches that survive the crop indicate a config bug rather - than an SFT-style trim. - """ - # +1 so an over-long prediction surfaces as a shape mismatch instead of silent truncation. - pred, _ = read_media_frames(pred_path, max_frames=gt_video_cthw_uint8.shape[1] + 1) - # Match GT's spatial dims (top-left crop, mirroring remove_reflection_padding's convention) - # so a reflection-padded GT trimmed to its content region can be compared against the - # padded mp4 saved to disk. - pred = pred[..., : gt_video_cthw_uint8.shape[-2], : gt_video_cthw_uint8.shape[-1]] - gt = gt_video_cthw_uint8 - - if pred.shape != gt.shape: - # Spatial mismatch (after the top-left crop above) is always a hard error. - if pred.shape[-2:] != gt.shape[-2:]: - raise ValueError(f"video spatial mismatch: gt {tuple(gt.shape)} vs pred {tuple(pred.shape)} ({pred_path})") - # Temporal mismatch: lenient for vision eval, strict otherwise. - if mode == "vision": - min_t = min(gt.shape[1], pred.shape[1]) - log.info( - f"vision frame-count mismatch trimmed to {min_t} (gt T={gt.shape[1]}, pred T={pred.shape[1]}, " - f"{pred_path}); likely due to generation aligned to 4k+1 frames " - "(latent temporal factor 4 + 1 conditioning frame), while the on-disk GT " - "keeps the raw clip length." - ) - gt = gt[:, :min_t] - pred = pred[:, :min_t] - else: - raise ValueError(f"video shape mismatch: gt {tuple(gt.shape)} vs pred {tuple(pred.shape)} ({pred_path})") - - return {"psnr": compute_psnr(gt, pred)} - - -def _compute_action_metrics(gt_action_td: torch.Tensor, pred_action_list: list) -> dict[str, Any]: - pred = torch.tensor(pred_action_list, dtype=torch.float32) - if pred.shape != gt_action_td.shape: - raise ValueError(f"action shape mismatch: gt {tuple(gt_action_td.shape)} vs pred {tuple(pred.shape)}") - return {"action_mse": compute_action_mse(gt_action_td, pred)} - - -def compute_sample_metrics( - name: str, - gt_video_cthw: torch.Tensor | None, - gt_action_td: torch.Tensor | None, - sample_output: SampleOutputs, - sample_dir: Path, - vision_extension: str, -) -> dict[str, Any]: - """Compute metrics for a single sample, dispatched by the mode parsed from `name`.""" - mode = _parse_mode_from_name(name) - out: dict[str, Any] = {"mode": mode, "name": sample_dir.name} - if mode in VIDEO_MODES | BOTH_MODES: - if gt_video_cthw is None: - raise ValueError(f"mode={mode!r} requires GT video but data_batch had none") - out.update(compute_video_metrics(gt_video_cthw, sample_dir / f"vision{vision_extension}", mode)) - if mode in ACTION_MODES | BOTH_MODES: - pred_action = sample_output.outputs[0].content.get("action") if sample_output.outputs else None - if pred_action is None: - raise ValueError(f"mode={mode!r} requires predicted action but content has none") - if gt_action_td is None: - raise ValueError(f"mode={mode!r} requires GT action but data_batch had none") - out.update(_compute_action_metrics(gt_action_td, pred_action)) - return out - - -def aggregate_metrics(output_dir: Path) -> dict[str, Any]: - """Walk `output_dir` for per-sample `metrics.json` files; emit per-mode/metric summary. - - Each scalar metric is summarised as ``{mean, count}``. - """ - totals: dict[str, dict[str, list[float]]] = defaultdict(lambda: defaultdict(list)) - for f in output_dir.rglob("metrics.json"): - m = json.loads(f.read_text()) - mode = m.pop("mode", None) - m.pop("name", None) - if mode is None: - continue - for k, v in m.items(): - if isinstance(v, dict): - for sub_k, sub_v in v.items(): - totals[mode][f"{k}/{sub_k}"].append(float(sub_v)) - else: - totals[mode][k].append(float(v)) - return { - mode: {metric: {"mean": float(np.mean(vals)), "count": len(vals)} for metric, vals in metrics.items()} - for mode, metrics in totals.items() - } diff --git a/cosmos_framework/scripts/eval_utils_test.py b/cosmos_framework/scripts/eval_utils_test.py deleted file mode 100644 index fbf0112..0000000 --- a/cosmos_framework/scripts/eval_utils_test.py +++ /dev/null @@ -1,386 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -"""Unit tests for :mod:`cosmos_framework.scripts.eval_utils` aggregation and score-only semantics.""" - -from __future__ import annotations - -import json -import math -import sys -from pathlib import Path -from unittest.mock import patch - -import pytest -import torch - -from cosmos_framework.scripts.eval_utils import ( - aggregate_metrics, - compute_video_metrics, - derive_match_key_and_group, -) - -# ``cosmos_framework.scripts.eval`` calls ``init_script()`` at import time, which raises -# if ``imaginaire`` is already loaded. The package-level ``conftest.py`` loads -# ``imaginaire.lazy_config`` during collection, so by the time this test module -# is imported, the strict check would fire. We patch the underlying -# ``_init_script`` to a no-op for the rest of this test module — the real -# init work (env-var setup, error handlers, grad disable) is irrelevant for -# unit tests that don't actually run inference. -with patch("cosmos_framework.inference.common.init._init_script", lambda **kwargs: None): - if "cosmos_framework.scripts.eval" in sys.modules: - del sys.modules["cosmos_framework.scripts.eval"] - from cosmos_framework.inference.dataset import DatasetArgs # noqa: E402 - from cosmos_framework.scripts.eval import ( - EvalArgs, # noqa: E402 - eval_vision, # noqa: E402 - ) - -pytestmark = [pytest.mark.L0, pytest.mark.CPU] - - -# --------------------------------------------------------------------------- -# aggregate_metrics — mean / count for every scalar metric -# --------------------------------------------------------------------------- - - -def _write_metrics(tmp_path: Path, name: str, mode: str, values: dict) -> None: - d = tmp_path / mode / name - d.mkdir(parents=True, exist_ok=True) - (d / "metrics.json").write_text(json.dumps({"mode": mode, "name": name, **values})) - - -def test_aggregate_metrics_empty_dir_returns_empty(tmp_path): - assert aggregate_metrics(tmp_path) == {} - - -def test_aggregate_metrics_single_sample(tmp_path): - _write_metrics(tmp_path, "s0", "vision", {"psnr": 20.0}) - out = aggregate_metrics(tmp_path) - assert out == {"vision": {"psnr": {"mean": 20.0, "count": 1}}} - - -def test_aggregate_metrics_mean_correct(tmp_path): - for i, v in enumerate([10.0, 20.0, 30.0]): - _write_metrics(tmp_path, f"s{i}", "vision", {"psnr": v}) - out = aggregate_metrics(tmp_path)["vision"]["psnr"] - assert out["count"] == 3 - assert math.isclose(out["mean"], 20.0) - - -def test_aggregate_metrics_skips_files_without_mode(tmp_path): - d = tmp_path / "orphan" - d.mkdir() - (d / "metrics.json").write_text(json.dumps({"name": "x", "psnr": 99.0})) - _write_metrics(tmp_path, "s0", "vision", {"psnr": 20.0}) - out = aggregate_metrics(tmp_path) - assert set(out.keys()) == {"vision"} - assert out["vision"]["psnr"]["count"] == 1 - - -def test_aggregate_metrics_separates_modes(tmp_path): - _write_metrics(tmp_path, "s0", "vision", {"psnr": 20.0}) - _write_metrics(tmp_path, "s1", "forward_dynamics", {"psnr": 24.0}) - out = aggregate_metrics(tmp_path) - assert set(out.keys()) == {"vision", "forward_dynamics"} - assert out["vision"]["psnr"]["mean"] == 20.0 - assert out["forward_dynamics"]["psnr"]["mean"] == 24.0 - - -def test_aggregate_metrics_flattens_nested_dicts(tmp_path): - # nested dicts (e.g. grouped action_mse) flatten to "k/sub_k" - _write_metrics(tmp_path, "s0", "policy", {"group_mse": {"arm": 0.1, "gripper": 0.2}}) - _write_metrics(tmp_path, "s1", "policy", {"group_mse": {"arm": 0.3, "gripper": 0.4}}) - out = aggregate_metrics(tmp_path)["policy"] - assert "group_mse/arm" in out and "group_mse/gripper" in out - assert math.isclose(out["group_mse/arm"]["mean"], 0.2) - assert math.isclose(out["group_mse/gripper"]["mean"], 0.3) - - -# --------------------------------------------------------------------------- -# compute_video_metrics — vision lenient T-trim vs strict (action) on mismatch -# --------------------------------------------------------------------------- - - -def _write_synthetic_mp4(path: Path, frames_cthw_uint8: torch.Tensor, fps: int = 5) -> None: - """Write a (C, T, H, W) uint8 tensor as an mp4 via torchvision. - - Lossy encoding will shift pixel values slightly; tests assert structural - properties (shape, presence of metrics) rather than exact PSNR. - """ - import torchvision.io as tvio - - # write_video expects (T, H, W, C) uint8 - thwc = frames_cthw_uint8.permute(1, 2, 3, 0).contiguous() - tvio.write_video(str(path), thwc, fps=fps) - - -def test_compute_video_metrics_vision_lenient_trims_to_min_t(tmp_path, caplog): - """VFM mode: pred has fewer frames than GT → trim both to min(T), warn, return metrics.""" - g = torch.Generator().manual_seed(0) - gt = torch.randint(0, 256, (3, 8, 32, 32), generator=g, dtype=torch.int64).to(torch.uint8) - pred_frames = torch.randint(0, 256, (3, 5, 32, 32), generator=g, dtype=torch.int64).to(torch.uint8) - pred_path = tmp_path / "vision.mp4" - _write_synthetic_mp4(pred_path, pred_frames) - - with caplog.at_level("WARNING"): - metrics = compute_video_metrics(gt, pred_path, mode="vision") - - assert "psnr" in metrics - - -def test_compute_video_metrics_vision_no_warning_on_matching_shapes(tmp_path, caplog): - """VFM mode: matching shapes → no warning, full metrics.""" - g = torch.Generator().manual_seed(1) - gt = torch.randint(0, 256, (3, 5, 32, 32), generator=g, dtype=torch.int64).to(torch.uint8) - pred_frames = torch.randint(0, 256, (3, 5, 32, 32), generator=g, dtype=torch.int64).to(torch.uint8) - pred_path = tmp_path / "vision.mp4" - _write_synthetic_mp4(pred_path, pred_frames) - - with caplog.at_level("WARNING"): - metrics = compute_video_metrics(gt, pred_path, mode="vision") - - assert "psnr" in metrics - assert "trimmed to" not in caplog.text.lower() - - -def test_compute_video_metrics_action_strict_on_t_mismatch(tmp_path): - """forward_dynamics: T mismatch still raises ValueError (the action chunk size is fixed).""" - g = torch.Generator().manual_seed(2) - gt = torch.randint(0, 256, (3, 8, 32, 32), generator=g, dtype=torch.int64).to(torch.uint8) - pred_frames = torch.randint(0, 256, (3, 5, 32, 32), generator=g, dtype=torch.int64).to(torch.uint8) - pred_path = tmp_path / "vision.mp4" - _write_synthetic_mp4(pred_path, pred_frames) - - with pytest.raises(ValueError, match="shape mismatch"): - compute_video_metrics(gt, pred_path, mode="forward_dynamics") - - -def test_compute_video_metrics_vision_spatial_mismatch_still_errors(tmp_path): - """Even in vision mode, an H mismatch that survives the top-left crop is a hard error.""" - g = torch.Generator().manual_seed(3) - gt = torch.randint(0, 256, (3, 5, 16, 32), generator=g, dtype=torch.int64).to(torch.uint8) - pred_frames = torch.randint(0, 256, (3, 5, 8, 32), generator=g, dtype=torch.int64).to(torch.uint8) - pred_path = tmp_path / "vision.mp4" - _write_synthetic_mp4(pred_path, pred_frames) - - with pytest.raises(ValueError, match="spatial mismatch"): - compute_video_metrics(gt, pred_path, mode="vision") - - -# --------------------------------------------------------------------------- -# derive_match_key_and_group — generic path-structure-based pairing rule -# --------------------------------------------------------------------------- - - -def test_derive_match_key_and_group_user_tree_cosmos_nano(tmp_path): - """Tree 1: /cosmos_nano_t2w/episode_*/vision.mp4 → key=episode_*, group=cosmos_nano_t2w.""" - p = tmp_path / "cosmos_nano_t2w" / "episode_002345_clip000" / "vision.mp4" - p.parent.mkdir(parents=True) - p.touch() - key, group = derive_match_key_and_group(p, tmp_path) - assert key == "episode_002345_clip000" - assert group == "cosmos_nano_t2w" - - -def test_derive_match_key_and_group_user_tree_mixed_modality(tmp_path): - """Tree 2: /mixed_modality_*/t2v/episode_*/vision.mp4 → group=mixed_modality_*/t2v.""" - p = tmp_path / "mixed_modality_sft_8b_0507e" / "t2v" / "episode_002345_clip000" / "vision.mp4" - p.parent.mkdir(parents=True) - p.touch() - key, group = derive_match_key_and_group(p, tmp_path) - assert key == "episode_002345_clip000" - assert group == "mixed_modality_sft_8b_0507e/t2v" - - -def test_derive_match_key_and_group_flat_layout(tmp_path): - """Flat: //vision.mp4 → key=, group empty string.""" - p = tmp_path / "clip0" / "vision.mp4" - p.parent.mkdir(parents=True) - p.touch() - key, group = derive_match_key_and_group(p, tmp_path) - assert key == "clip0" - assert group == "" - - -def test_derive_match_key_and_group_inference_py_output(tmp_path): - """Canonical inference.py output: //vision.mp4.""" - p = tmp_path / "t2v" / "episode_049683_clip000" / "vision.mp4" - p.parent.mkdir(parents=True) - p.touch() - key, group = derive_match_key_and_group(p, tmp_path) - assert key == "episode_049683_clip000" - assert group == "t2v" - - -def test_derive_match_key_and_group_non_vision_filename_uses_stem(tmp_path): - """If basename isn't vision.*, the filename stem becomes the key (no parent-dir drop).""" - p = tmp_path / "sub" / "foo.mp4" - p.parent.mkdir(parents=True) - p.touch() - key, group = derive_match_key_and_group(p, tmp_path) - assert key == "foo" - assert group == "sub" - - -def test_derive_match_key_and_group_rejects_path_outside_predictions_dir(tmp_path): - other = tmp_path.parent / "elsewhere" / "vision.mp4" - with pytest.raises(ValueError, match="not under predictions_dir"): - derive_match_key_and_group(other, tmp_path) - - -# --------------------------------------------------------------------------- -# score_only end-to-end (no model, CPU) -# --------------------------------------------------------------------------- - - -def test_score_only_end_to_end(tmp_path, monkeypatch): - """Build a synthetic GT dir + predictions tree, run score_only, assert sidecars + aggregate.""" - from cosmos_framework.inference.args import OmniSetupOverrides - - gt_dir = tmp_path / "gt" - pred_dir = tmp_path / "preds" - out_dir = tmp_path / "out" - gt_dir.mkdir() - pred_dir.mkdir() - out_dir.mkdir() - - # Three clips; T=5, 32x32. - keys = ["clipA", "clipB", "clipC"] - g = torch.Generator().manual_seed(0) - for k in keys: - frames = torch.randint(0, 256, (3, 5, 32, 32), generator=g, dtype=torch.int64).to(torch.uint8) - _write_synthetic_mp4(gt_dir / f"{k}.mp4", frames) - bucket = pred_dir / "model_x" / "t2v" / k - bucket.mkdir(parents=True) - # Pred = lightly-perturbed GT so PSNR is sane (codec lossy regardless). - pred = (frames.float() + 4).clamp(0, 255).to(torch.uint8) - _write_synthetic_mp4(bucket / "vision.mp4", pred) - - args = EvalArgs( - setup=OmniSetupOverrides(output_dir=out_dir, checkpoint_path=""), - dataset=DatasetArgs(model_mode="vision"), - gt_dir=gt_dir, - predictions_dir=pred_dir, - predictions_glob="**/vision.mp4", - ) - eval_vision(args) - - # Per-sample sidecars. - for k in keys: - m = json.loads((out_dir / "model_x/t2v" / k / "metrics.json").read_text()) - assert m["mode"] == "model_x/t2v" - assert m["name"] == k - assert "psnr" in m - - # Aggregate. - agg = json.loads((out_dir / "metrics_aggregate.json").read_text()) - assert "model_x/t2v" in agg - for metric in ("psnr",): - entry = agg["model_x/t2v"][metric] - assert entry["count"] == 3 - assert "mean" in entry - - -def test_score_only_missing_gt_logs_warning_and_skips(tmp_path, caplog): - from cosmos_framework.inference.args import OmniSetupOverrides - - gt_dir = tmp_path / "gt" - pred_dir = tmp_path / "preds" - out_dir = tmp_path / "out" - gt_dir.mkdir() - pred_dir.mkdir() - out_dir.mkdir() - - g = torch.Generator().manual_seed(1) - frames = torch.randint(0, 256, (3, 5, 32, 32), generator=g, dtype=torch.int64).to(torch.uint8) - # GT only has clipA; pred has both clipA + clipZ. - _write_synthetic_mp4(gt_dir / "clipA.mp4", frames) - for k in ("clipA", "clipZ"): - d = pred_dir / "m" / k - d.mkdir(parents=True) - _write_synthetic_mp4(d / "vision.mp4", frames) - - args = EvalArgs( - setup=OmniSetupOverrides(output_dir=out_dir, checkpoint_path=""), - dataset=DatasetArgs(model_mode="vision"), - gt_dir=gt_dir, - predictions_dir=pred_dir, - predictions_glob="**/vision.mp4", - ) - with caplog.at_level("WARNING"): - eval_vision(args) - - # clipA scored, clipZ skipped. - assert (out_dir / "m" / "clipA" / "metrics.json").exists() - assert not (out_dir / "m" / "clipZ" / "metrics.json").exists() - agg = json.loads((out_dir / "metrics_aggregate.json").read_text()) - assert agg["m"]["psnr"]["count"] == 1 - - -def test_score_only_single_mode_bucket(tmp_path): - """Predictions under one mode subdir → one bucket in aggregate (no t2v/i2v/v2v assumption).""" - from cosmos_framework.inference.args import OmniSetupOverrides - - gt_dir = tmp_path / "gt" - pred_dir = tmp_path / "preds" - out_dir = tmp_path / "out" - gt_dir.mkdir() - pred_dir.mkdir() - out_dir.mkdir() - - keys = ["clipA", "clipB"] - g = torch.Generator().manual_seed(0) - for k in keys: - frames = torch.randint(0, 256, (3, 5, 32, 32), generator=g, dtype=torch.int64).to(torch.uint8) - _write_synthetic_mp4(gt_dir / f"{k}.mp4", frames) - d = pred_dir / "t2v" / k - d.mkdir(parents=True) - _write_synthetic_mp4(d / "vision.mp4", (frames.float() + 4).clamp(0, 255).to(torch.uint8)) - - args = EvalArgs( - setup=OmniSetupOverrides(output_dir=out_dir, checkpoint_path=""), - dataset=DatasetArgs(model_mode="vision"), - gt_dir=gt_dir, - predictions_dir=pred_dir, - predictions_glob="**/vision.mp4", - ) - eval_vision(args) - - agg = json.loads((out_dir / "metrics_aggregate.json").read_text()) - assert set(agg.keys()) == {"t2v"} - assert agg["t2v"]["psnr"]["count"] == 2 - - -def test_score_only_flat_layout_uses_default_bucket(tmp_path): - """Flat ``//vision.mp4`` (no mode subfolder) → bucket=='default'.""" - from cosmos_framework.inference.args import OmniSetupOverrides - - gt_dir = tmp_path / "gt" - pred_dir = tmp_path / "preds" - out_dir = tmp_path / "out" - gt_dir.mkdir() - pred_dir.mkdir() - out_dir.mkdir() - - keys = ["clipA", "clipB"] - g = torch.Generator().manual_seed(0) - for k in keys: - frames = torch.randint(0, 256, (3, 5, 32, 32), generator=g, dtype=torch.int64).to(torch.uint8) - _write_synthetic_mp4(gt_dir / f"{k}.mp4", frames) - d = pred_dir / k - d.mkdir(parents=True) - _write_synthetic_mp4(d / "vision.mp4", (frames.float() + 4).clamp(0, 255).to(torch.uint8)) - - args = EvalArgs( - setup=OmniSetupOverrides(output_dir=out_dir, checkpoint_path=""), - dataset=DatasetArgs(model_mode="vision"), - gt_dir=gt_dir, - predictions_dir=pred_dir, - predictions_glob="**/vision.mp4", - ) - eval_vision(args) - - agg = json.loads((out_dir / "metrics_aggregate.json").read_text()) - assert set(agg.keys()) == {"default"} - assert agg["default"]["psnr"]["count"] == 2