Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 59 additions & 6 deletions cosmos_framework/inference/_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing_extensions import Self

from cosmos_framework.data.vfm.utils import VIDEO_RES_SIZE_INFO
from cosmos_framework.inference.metrics import compute_action_mse, compute_psnr
from cosmos_framework.inference.args import (
IMAGE_ONLY_RESOLUTIONS,
AspectRatio,
Expand Down Expand Up @@ -94,6 +95,63 @@ def _get_video_dims(path: Path) -> tuple[int, int, int]:
return int(stream.width), int(stream.height), frame_count


def _compute_video_metrics(gt_video_cthw_uint8: torch.Tensor, pred_path: Path, mode: str) -> dict[str, float]:
    """Score a predicted video file against an in-memory ground-truth video.

    The prediction is decoded from ``pred_path`` (reading at most one frame
    more than the ground truth holds), cropped on its last two dimensions to
    the ground-truth size, and compared with ``compute_psnr``.

    Args:
        gt_video_cthw_uint8: Ground-truth video. The indexing below treats
            dim 1 as time and the last two dims as the spatial extent
            (presumably C, T, H, W as the parameter name suggests).
        pred_path: Path of the predicted media file to decode.
        mode: Mode string; only ``"vision"`` tolerates a frame-count mismatch.

    Returns:
        A one-entry dict ``{"psnr": ...}``.

    Raises:
        ValueError: If the cropped prediction still differs spatially from the
            ground truth, or if the shapes differ in any other way and
            ``mode`` is not ``"vision"``.
    """
    # Imported lazily, as in the original, so the module import stays light.
    from cosmos_framework.inference.vision import read_media_frames
    from cosmos_framework.utils import log

    reference = gt_video_cthw_uint8
    frame_budget = reference.shape[1] + 1
    decoded, _ = read_media_frames(pred_path, max_frames=frame_budget)
    # Crop only; a prediction smaller than the reference is caught below.
    candidate = decoded[..., : reference.shape[-2], : reference.shape[-1]]

    if candidate.shape == reference.shape:
        return {"psnr": compute_psnr(reference, candidate)}

    if candidate.shape[-2:] != reference.shape[-2:]:
        raise ValueError(
            f"video spatial mismatch: gt {tuple(reference.shape)} vs pred {tuple(candidate.shape)} ({pred_path})"
        )
    if mode != "vision":
        raise ValueError(
            f"video shape mismatch: gt {tuple(reference.shape)} vs pred {tuple(candidate.shape)} ({pred_path})"
        )

    # Vision mode: compare only the frames both clips have in common.
    shared_t = min(reference.shape[1], candidate.shape[1])
    notice = (
        f"vision frame-count mismatch trimmed to {shared_t} (gt T={reference.shape[1]}, pred T={candidate.shape[1]}, "
        f"{pred_path}); likely due to generation aligned to 4k+1 frames "
        "(latent temporal factor 4 + 1 conditioning frame), while the on-disk GT "
        "keeps the raw clip length."
    )
    log.info(notice)
    return {"psnr": compute_psnr(reference[:, :shared_t], candidate[:, :shared_t])}


def _compute_action_metrics(gt_action_td: torch.Tensor, pred_action_list: list) -> dict[str, float]:
    """Compare a predicted action sequence with its ground truth.

    Args:
        gt_action_td: Ground-truth action tensor (presumably time x dim, per
            the parameter name).
        pred_action_list: Predicted actions as nested Python lists; converted
            to a ``float32`` tensor before comparison.

    Returns:
        A one-entry dict ``{"action_mse": ...}`` computed by
        ``compute_action_mse``.

    Raises:
        ValueError: If the converted prediction's shape differs from the
            ground-truth shape.
    """
    predicted = torch.tensor(pred_action_list, dtype=torch.float32)
    if predicted.shape == gt_action_td.shape:
        return {"action_mse": compute_action_mse(gt_action_td, predicted)}
    raise ValueError(f"action shape mismatch: gt {tuple(gt_action_td.shape)} vs pred {tuple(predicted.shape)}")


def _compute_sample_metrics(
    mode: ModelMode,
    gt_video_cthw: torch.Tensor | None,
    gt_action_td: torch.Tensor | None,
    sample_output: SampleOutputs,
    sample_dir: Path,
    vision_extension: str,
) -> dict[str, float | str]:
    """Collect the metrics that apply to one sample, selected by ``mode``.

    The result always carries two string labels, ``"mode"`` (``mode.value``)
    and ``"name"`` (``sample_dir.name``). Forward-dynamics and policy samples
    add the video metrics; inverse-dynamics and policy samples add the action
    metrics. Any other mode yields just the two labels.

    Args:
        mode: Model mode of the sample; decides which metrics are computed.
        gt_video_cthw: Ground-truth video, required for video metrics.
        gt_action_td: Ground-truth action tensor, required for action metrics.
        sample_output: Inference output; the predicted action is read from the
            ``"action"`` entry of the first output's ``content``.
        sample_dir: Sample directory; the predicted video is expected at
            ``sample_dir / f"vision{vision_extension}"``.
        vision_extension: Suffix (including the dot) of the predicted video.

    Returns:
        A dict mixing the string labels with the numeric metric values, which
        is why values are typed ``float | str`` rather than ``float``.

    Raises:
        ValueError: If a required ground truth or the predicted action is
            missing, or if a metric helper rejects the shapes.
    """
    # "mode" and "name" are labels (strings); the metric entries merged in
    # below are numeric.
    out: dict[str, float | str] = {"mode": mode.value, "name": sample_dir.name}
    if mode in (ModelMode.FORWARD_DYNAMICS, ModelMode.POLICY):
        if gt_video_cthw is None:
            raise ValueError(f"mode={mode.value!r} requires GT video but data_batch had none")
        out.update(_compute_video_metrics(gt_video_cthw, sample_dir / f"vision{vision_extension}", mode.value))
    if mode in (ModelMode.INVERSE_DYNAMICS, ModelMode.POLICY):
        # Only the first output is inspected; an empty outputs list is treated
        # the same as a missing "action" entry.
        pred_action = sample_output.outputs[0].content.get("action") if sample_output.outputs else None
        if pred_action is None:
            raise ValueError(f"mode={mode.value!r} requires predicted action but content has none")
        if gt_action_td is None:
            raise ValueError(f"mode={mode.value!r} requires GT action but data_batch had none")
        out.update(_compute_action_metrics(gt_action_td, pred_action))
    return out


def _omni_after_script(runner: ScriptRunner, cfg: ScriptConfig) -> None:
inference_dir = runner.output_dir / "inference"
sample_outputs_list = _check_inference_output([runner.input_dir / "omni/*json"], inference_dir)
Expand Down Expand Up @@ -218,15 +276,10 @@ def _check_action_golden(sample_outputs: SampleOutputs, sample_dir: Path) -> lis
if psnr_min is None and mse_max is None:
return []

from cosmos_framework.scripts.eval_utils import compute_sample_metrics

mode = ModelMode(sample_outputs.args["model_mode"])
gt_video = _load_canonical_gt_video(sample_dir) if mode in (ModelMode.FORWARD_DYNAMICS, ModelMode.POLICY) else None
gt_action = _load_gt_action(extra["golden_action_path"], sample_dir) if mse_max is not None else None
# `compute_sample_metrics` parses mode from `name.split("/")[-2]`.
metrics = compute_sample_metrics(
f"{mode.value}/{sample_dir.name}", gt_video, gt_action, sample_outputs, sample_dir, ".mp4"
)
metrics = _compute_sample_metrics(mode, gt_video, gt_action, sample_outputs, sample_dir, ".mp4")

failures: list[str] = []
if psnr_min is not None and "psnr" in metrics:
Expand Down
201 changes: 0 additions & 201 deletions cosmos_framework/scripts/eval_utils.py

This file was deleted.

Loading