diff --git a/cosmos_framework/configs/base/defaults/open_source_dataloader.py b/cosmos_framework/configs/base/defaults/open_source_dataloader.py index dc84e8f..c52079e 100644 --- a/cosmos_framework/configs/base/defaults/open_source_dataloader.py +++ b/cosmos_framework/configs/base/defaults/open_source_dataloader.py @@ -26,20 +26,20 @@ from hydra.core.config_store import ConfigStore +from cosmos_framework.configs.base.defaults.vlm import create_qwen2_tokenizer_with_download from cosmos_framework.data.vfm.joint_dataloader import ( PackingDataLoader, RankPartitionedDataLoader, ) from cosmos_framework.data.vfm.local_datasets.sft_dataset import get_sft_dataset from cosmos_framework.utils.lazy_config import LazyCall as L -from cosmos_framework.configs.base.defaults.vlm import create_qwen2_tokenizer_with_download - # --------------------------------------------------------------------------- # Inner: SFT video dataset (matches the inline ``get_sft_dataset`` call in the # reference YAML). # --------------------------------------------------------------------------- + def get_sft_video_dataset( *, jsonl_paths: list[str], @@ -55,6 +55,11 @@ def get_sft_video_dataset( append_duration_fps_timestamps: bool = True, append_resolution_info: bool = True, use_system_prompt: bool = False, + # Structured-JSON captions are far longer than dense prose; raise the token + # budget so the loader does not truncate them mid-JSON (see sft_dataset.py + # _MAX_NUM_TOKENS). 2048 covers the example dataset (measured max ~1790 Qwen + # tokens) with margin; keep consistent with the inference prompt budget. + max_num_tokens: int = 2048, caption_suffix: str = "", cfg_dropout_rate: float = 0.1, cfg_dropout_keep_metadata: bool = False, @@ -85,6 +90,7 @@ def get_sft_video_dataset( append_duration_fps_timestamps=append_duration_fps_timestamps, append_resolution_info=append_resolution_info, use_system_prompt=use_system_prompt, + max_num_tokens=max_num_tokens, caption_suffix=caption_suffix, cfg_dropout_rate=cfg_dropout_rate, cfg_dropout_keep_metadata=cfg_dropout_keep_metadata, @@ -103,6 +109,7 @@ def get_sft_video_dataset( # pipeline. This is the registered config_store node. # --------------------------------------------------------------------------- + def get_open_source_sft_dataloader( *, jsonl_paths: list[str] | None = None, @@ -165,6 +172,7 @@ def get_open_source_sft_dataloader( # ConfigStore registration. # --------------------------------------------------------------------------- + def register_open_source_dataloaders() -> None: """Register named dataloader configs under the ``data_train`` Hydra group. diff --git a/cosmos_framework/configs/base/experiment/sft/vision_sft_nano.py b/cosmos_framework/configs/base/experiment/sft/vision_sft_nano.py index 102d525..094c539 100644 --- a/cosmos_framework/configs/base/experiment/sft/vision_sft_nano.py +++ b/cosmos_framework/configs/base/experiment/sft/vision_sft_nano.py @@ -34,15 +34,14 @@ from hydra.core.config_store import ConfigStore -from cosmos_framework.utils.lazy_config import LazyCall as L -from cosmos_framework.utils.lazy_config import LazyDict - from cosmos_framework.configs.base.experiment.sft.models.nano_model_config import NANO_MODEL_CONFIG from cosmos_framework.data.vfm.joint_dataloader import ( PackingDataLoader, RankPartitionedDataLoader, ) from cosmos_framework.data.vfm.local_datasets.sft_dataset import get_sft_dataset +from cosmos_framework.utils.lazy_config import LazyCall as L +from cosmos_framework.utils.lazy_config import LazyDict cs = ConfigStore.instance() @@ -237,6 +236,10 @@ dataset=L(get_sft_dataset)( append_duration_fps_timestamps=True, append_resolution_info=True, + # Structured-JSON captions are long; raise the token budget so + # the loader does not truncate them (see sft_dataset.py + # _MAX_NUM_TOKENS). 2048 covers the example set (measured max ~1790). + max_num_tokens=2048, caption_suffix="", cfg_dropout_keep_metadata=False, cfg_dropout_rate=0.1, diff --git a/cosmos_framework/configs/base/experiment/sft/vision_sft_super.py b/cosmos_framework/configs/base/experiment/sft/vision_sft_super.py index 385f49a..a0ff386 100644 --- a/cosmos_framework/configs/base/experiment/sft/vision_sft_super.py +++ b/cosmos_framework/configs/base/experiment/sft/vision_sft_super.py @@ -50,15 +50,14 @@ from hydra.core.config_store import ConfigStore -from cosmos_framework.utils.lazy_config import LazyCall as L -from cosmos_framework.utils.lazy_config import LazyDict - from cosmos_framework.configs.base.experiment.sft.models.super_model_config import SUPER_MODEL_CONFIG from cosmos_framework.data.vfm.joint_dataloader import ( PackingDataLoader, RankPartitionedDataLoader, ) from cosmos_framework.data.vfm.local_datasets.sft_dataset import get_sft_dataset +from cosmos_framework.utils.lazy_config import LazyCall as L +from cosmos_framework.utils.lazy_config import LazyDict cs = ConfigStore.instance() @@ -255,6 +254,10 @@ dataset=L(get_sft_dataset)( append_duration_fps_timestamps=True, append_resolution_info=True, + # Structured-JSON captions are long; raise the token budget so + # the loader does not truncate them (see sft_dataset.py + # _MAX_NUM_TOKENS). 2048 covers the example set (measured max ~1790). + max_num_tokens=2048, caption_suffix="", cfg_dropout_keep_metadata=False, cfg_dropout_rate=0.1, diff --git a/cosmos_framework/data/vfm/local_datasets/sft_dataset.py b/cosmos_framework/data/vfm/local_datasets/sft_dataset.py index 556310c..ffd9bf4 100644 --- a/cosmos_framework/data/vfm/local_datasets/sft_dataset.py +++ b/cosmos_framework/data/vfm/local_datasets/sft_dataset.py @@ -16,9 +16,6 @@ import numpy as np import torch -from cosmos_framework.utils.flags import INTERNAL -from cosmos_framework.utils.lazy_config import instantiate as lazy_instantiate -from cosmos_framework.utils import log from cosmos_framework.data.vfm.local_datasets.helper import ( client_config, download_from_s3, @@ -29,7 +26,11 @@ ) from cosmos_framework.data.vfm.sequence_packing import SequencePlan, add_special_tokens from cosmos_framework.data.vfm.utils import VIDEO_RES_SIZE_INFO +from cosmos_framework.inference.structured_caption import CAPTION_JSON_KEY, caption_json_to_prompt from cosmos_framework.model.vfm.vlm.qwen3_vl.utils import tokenize_caption +from cosmos_framework.utils import log +from cosmos_framework.utils.flags import INTERNAL +from cosmos_framework.utils.lazy_config import instantiate as lazy_instantiate _MAX_NUM_TOKENS = 1024 _DURATION_TEMPLATE = "The video is {duration:.1f} seconds long and is of {fps:.0f} FPS." @@ -61,6 +62,38 @@ CAPTION_WEIGHTS = list(CAPTION_TYPES_AND_WEIGHTS.values()) +def _select_caption(t2w_window: dict) -> tuple[str, str, bool] | None: + """Pick a window's caption: ``(caption_key, caption_text, used_structured_json)``. + + Priority: ``caption_json`` (structured — the default training target) → + ``qwen3_32b_rewrite-dense`` → ``caption`` (dense backup) → a weighted-random + ``CAPTION_TYPES`` style. A structured-JSON caption (a dict, or a value under + ``caption_json``) is serialised verbatim so the training prompt is byte-identical + to the inference prompt; it must NOT receive the dense prose period-normalisation, + which would append a stray ``.`` after the closing ``}``. Returns ``None`` when + the window has no known caption key. + """ + if CAPTION_JSON_KEY in t2w_window: + caption_key = CAPTION_JSON_KEY + elif "qwen3_32b_rewrite-dense" in t2w_window: + caption_key = "qwen3_32b_rewrite-dense" + elif "caption" in t2w_window: + caption_key = "caption" + else: + available_types = [ct for ct in CAPTION_TYPES if ct in t2w_window] + if not available_types: + return None + available_weights = [CAPTION_TYPES_AND_WEIGHTS[ct] for ct in available_types] + caption_key = random.choices(available_types, weights=available_weights, k=1)[0] + + raw = t2w_window[caption_key] + if isinstance(raw, dict): + return caption_key, caption_json_to_prompt(raw), True + if caption_key == CAPTION_JSON_KEY: + return caption_key, str(raw).strip(), True + return caption_key, raw.strip().rstrip(".") + ".", False + + class SFTDataset(torch.utils.data.IterableDataset): """Dataset for loading SFT video clips with captions from JSONL metadata on S3.""" @@ -75,6 +108,7 @@ def __init__( tokenizer_config: Optional[Any] = None, cfg_dropout_rate: float = 0.0, use_system_prompt: bool = False, + max_num_tokens: int = _MAX_NUM_TOKENS, append_duration_fps_timestamps: bool = True, append_resolution_info: bool = True, cfg_dropout_keep_metadata: bool = False, @@ -100,6 +134,7 @@ def __init__( self.tokenizer_config = tokenizer_config self.cfg_dropout_rate = cfg_dropout_rate self.use_system_prompt = use_system_prompt + self.max_num_tokens = max_num_tokens self.append_duration_fps_timestamps = append_duration_fps_timestamps self.append_resolution_info = append_resolution_info self.cfg_dropout_keep_metadata = cfg_dropout_keep_metadata @@ -135,9 +170,9 @@ def _tokenize_caption(self, caption: str) -> tuple[list[int], str]: is_video=True, use_system_prompt=self.use_system_prompt, ) - if len(text_ids) > _MAX_NUM_TOKENS: - log.warning(f"Text ids are too long, truncating: {len(text_ids)} > {_MAX_NUM_TOKENS}") - text_ids = text_ids[:_MAX_NUM_TOKENS] + if len(text_ids) > self.max_num_tokens: + log.warning(f"Text ids are too long, truncating: {len(text_ids)} > {self.max_num_tokens}") + text_ids = text_ids[: self.max_num_tokens] return text_ids, caption def process_one_sample(self, metadata: dict) -> dict | None: @@ -248,22 +283,14 @@ def process_one_sample(self, metadata: dict) -> dict | None: # image_size: [target_h, target_w, orig_h, orig_w] in pixel space, for the model to crop the video image_size = torch.tensor([target_h, target_w, target_h, target_w], dtype=torch.float32) - available_types = [ct for ct in CAPTION_TYPES if ct in t2w_window] - if "qwen3_32b_rewrite-dense" in t2w_window: - caption_key = "qwen3_32b_rewrite-dense" - elif "caption" in t2w_window: - caption_key = "caption" - elif available_types: - available_weights = [CAPTION_TYPES_AND_WEIGHTS[ct] for ct in available_types] - caption_key = random.choices(available_types, weights=available_weights, k=1)[0] - else: + selected = _select_caption(t2w_window) + if selected is None: log.warning( f"No known caption key found in t2w_window for sample {metadata['uuid']}. " f"Keys: {list(t2w_window)}. Skipping sample." ) return None - caption = t2w_window[caption_key] - caption = caption.strip().rstrip(".") + "." + caption_key, caption, used_structured_json = selected num_decoded_frames = video.shape[1] cond_fps = fps if self.conditioning_fps < 0 else self.conditioning_fps @@ -271,7 +298,7 @@ def process_one_sample(self, metadata: dict) -> dict | None: noise_factor = np.exp(np.random.randn() * self.conditioning_fps_noise_std) cond_fps = cond_fps * noise_factor - if self.caption_suffix: + if self.caption_suffix and not used_structured_json: caption = (caption + " " + self.caption_suffix).strip() # CFG dropout: when cfg_dropout_keep_metadata is True, dropout fires @@ -281,11 +308,14 @@ def process_one_sample(self, metadata: dict) -> dict | None: if random.random() < self.cfg_dropout_rate: caption = "" - if self.append_duration_fps_timestamps: + # Structured-JSON captions already carry duration/fps/resolution inside the + # JSON, so skip the natural-language metadata suffixes for them. This also + # makes the training prompt byte-match the inference prompt. + if self.append_duration_fps_timestamps and not used_structured_json: duration = num_decoded_frames / cond_fps suffix = _DURATION_TEMPLATE.format(duration=duration, fps=cond_fps) caption = caption + " " + suffix - if self.append_resolution_info: + if self.append_resolution_info and not used_structured_json: suffix = _RESOLUTION_TEMPLATE.format(height=target_h, width=target_w) caption = caption + " " + suffix caption = caption.strip() @@ -552,6 +582,7 @@ def get_sft_dataset( tokenizer_config: Optional[Any] = None, cfg_dropout_rate: float = 0.1, use_system_prompt: bool = False, + max_num_tokens: int = _MAX_NUM_TOKENS, append_duration_fps_timestamps: bool = True, append_resolution_info: bool = True, cfg_dropout_keep_metadata: bool = False, @@ -668,6 +699,7 @@ def get_sft_dataset( tokenizer_config=tokenizer_config, cfg_dropout_rate=cfg_dropout_rate, use_system_prompt=use_system_prompt, + max_num_tokens=max_num_tokens, append_duration_fps_timestamps=append_duration_fps_timestamps, append_resolution_info=append_resolution_info, cfg_dropout_keep_metadata=cfg_dropout_keep_metadata, diff --git a/cosmos_framework/data/vfm/local_datasets/sft_dataset_caption_test.py b/cosmos_framework/data/vfm/local_datasets/sft_dataset_caption_test.py new file mode 100644 index 0000000..6909221 --- /dev/null +++ b/cosmos_framework/data/vfm/local_datasets/sft_dataset_caption_test.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 +"""Tests for the SFT loader's caption selection / JSON-vs-dense normalization.""" + +import json + +from cosmos_framework.data.vfm.local_datasets.sft_dataset import _select_caption +from cosmos_framework.inference.structured_caption import CAPTION_JSON_KEY + + +def test_caption_json_dict_serialized_verbatim_no_trailing_period(): + cj = {"background_setting": "kitchen", "fps": 5} + key, text, used_json = _select_caption({CAPTION_JSON_KEY: cj}) + assert key == CAPTION_JSON_KEY and used_json is True + assert not text.endswith(".") # MUST NOT append a stray '.' after '}' + assert text.endswith("}") + assert json.loads(text) == cj + + +def test_caption_json_priority_over_dense(): + cj = {"background_setting": "x"} + key, text, used_json = _select_caption({CAPTION_JSON_KEY: cj, "caption": "dense backup"}) + assert key == CAPTION_JSON_KEY and used_json is True + + +def test_caption_json_as_preserialized_string(): + key, text, used_json = _select_caption({CAPTION_JSON_KEY: '{"a": 1} '}) + assert key == CAPTION_JSON_KEY and used_json is True + assert text == '{"a": 1}' # stripped, no period + + +def test_dense_caption_gets_terminal_period(): + key, text, used_json = _select_caption({"caption": "a robot arm moves"}) + assert key == "caption" and used_json is False + assert text == "a robot arm moves." + + +def test_dense_caption_period_not_doubled(): + _, text, _ = _select_caption({"caption": "ends with period."}) + assert text == "ends with period." + + +def test_rewrite_dense_key_priority_over_generic_caption(): + key, _, used_json = _select_caption({"qwen3_32b_rewrite-dense": "x", "caption": "y"}) + assert key == "qwen3_32b_rewrite-dense" and used_json is False + + +def test_weighted_caption_types_fallback(): + key, text, used_json = _select_caption({"qwen3_235b_dense": "some dense caption"}) + assert key == "qwen3_235b_dense" and used_json is False + assert text.endswith(".") + + +def test_no_known_caption_key_returns_none(): + assert _select_caption({"start_frame": 0, "end_frame": 84}) is None diff --git a/cosmos_framework/inference/_test.py b/cosmos_framework/inference/_test.py index 569545b..9dbe0d6 100644 --- a/cosmos_framework/inference/_test.py +++ b/cosmos_framework/inference/_test.py @@ -327,13 +327,13 @@ def __call__(self, runner: ScriptRunner, cfg: ScriptConfig) -> dict[str, str]: base_checkpoint_name="Cosmos3-Nano", config_file="cosmos3/configs/experiment/vision_sft_nano.yaml", job_name="vision_sft_nano", - dataset_name="nvidia/bridge-v2-subset-synthetic-captions", + dataset_name="nvidia/BridgeData2-Subset-Synthetic-Captions", ), "vision_super": SftGetEnv( base_checkpoint_name="Cosmos3-Super", config_file="cosmos3/configs/experiment/vision_sft_super.yaml", job_name="vision_sft_super", - dataset_name="nvidia/bridge-v2-subset-synthetic-captions", + dataset_name="nvidia/BridgeData2-Subset-Synthetic-Captions", ), } _DEFAULT_SFT_NAME = "vision" diff --git a/cosmos_framework/inference/common/checkpoints.py b/cosmos_framework/inference/common/checkpoints.py index fde8283..4041cb9 100644 --- a/cosmos_framework/inference/common/checkpoints.py +++ b/cosmos_framework/inference/common/checkpoints.py @@ -299,11 +299,11 @@ class DatasetConfig(pydantic.BaseModel): DATASETS = { - "nvidia/bridge-v2-subset-synthetic-captions": DatasetConfig( + "nvidia/BridgeData2-Subset-Synthetic-Captions": DatasetConfig( hf=CheckpointDirHf( repository_type=RepositoryType.DATASET, - repository="nvidia/bridge-v2-subset-synthetic-captions", - revision="46468e12ac0dd36901e9e3240d4fc7620942b5d7", + repository="nvidia/BridgeData2-Subset-Synthetic-Captions", + revision="40d018ac1c1a2a4b9734f17fdb21f3d933c49a01", subdirectory="sft_dataset_bridge", ), ), diff --git a/cosmos_framework/inference/defaults/video_captioner.txt b/cosmos_framework/inference/defaults/video_captioner.txt index fcefbe1..0aac655 100644 --- a/cosmos_framework/inference/defaults/video_captioner.txt +++ b/cosmos_framework/inference/defaults/video_captioner.txt @@ -1,61 +1,90 @@ -You are an expert video captioner for a text-to-video model. Your task is to watch the provided video frames and produce a dense, multi-sentence narrative caption that faithfully describes what you see. +You are an expert video captioner for a text-to-video model. Your task is to watch the provided video frames and produce (1) a structured JSON analysis and (2) a dense, multi-sentence narrative caption that faithfully describes what you see. -To ensure the final caption contains the correct density of visual, cinematic, and temporal details, you must complete this task in two phases. +To ensure the outputs contain the correct density of visual, cinematic, and temporal detail, you must complete this task in two phases. --- -### PHASE 1: SCENE ANALYSIS (JSON DRAFT) -First, analyze the video frames and fill out a strict JSON schema describing what you observe. Output this JSON inside `` XML tags. +### PHASE 1: SCENE ANALYSIS (STRUCTURED JSON) +First, analyze the video frames and fill out the strict JSON schema below. Output this JSON inside `` XML tags. This is the canonical structured prompt format the model consumes, so populate every field you can observe; use "" (or [] for lists) when something is genuinely not visible. Do NOT invent details. Do NOT include the `resolution`, `aspect_ratio`, `duration`, or `fps` fields — those are filled in automatically from the video file. + +For all time-bearing fields (`actions[].time`, `segments[].time_range`) use `M:SS-M:SS` (e.g. `0:00-0:08`) measured from the start of the clip. JSON Schema to complete: { - "short_description": "Concise summary of subjects, actions, and setting.", "subjects": [ { - "description": "Detailed visual description, posture, colors", - "appearance_details": "Accessories, markings, logos", - "relationship": "Relation to other subjects", - "location": "Placement (e.g., center, left foreground)", - "relative_size": "e.g., small, large within frame", - "orientation": "e.g., facing left, profile", - "pose": "Body position", - "action": "Main action", - "clothing": "Attire, colors, footwear (if human)", - "expression": "Facial expression (if human)", - "gender": "Apparent gender presentation (if human)", - "age": "Apparent age group (if human)", - "skin_tone_and_texture": "Apparent skin tone (if human)" + "description": "Full visual description: appearance, posture, colors, identifying features", + "appearance_details": "Accessories, markings, logos, distinguishing features", + "relationship": "How this subject relates to other subjects or the scene", + "location": "Placement in frame (e.g., center foreground, left background)", + "relative_size": "Small / Medium / Large within frame", + "orientation": "Direction the subject faces relative to camera", + "pose": "Body position and posture", + "action": "Main action (brief)", + "state_changes": "How pose/action changes over time; 'No significant change.' if static", + "clothing": "Attire, colors, footwear; '' if non-human or not visible", + "expression": "Facial expression; '' if non-human or not visible", + "gender": "'Male' / 'Female' / 'Unknown'; '' if non-human", + "age": "Age category (e.g., Child, Young adult, Adult, Elderly); '' if non-human", + "skin_tone_and_texture": "Apparent skin tone; '' if non-human", + "facial_features": "Notable facial features; '' if non-human or not visible", + "number_of_subjects": 1, + "number_of_arms": 0, + "number_of_legs": 0 } ], - "background_setting": "Environment type, key structures, scenery, furniture.", + "background_setting": "Environment type, key structures, scenery, furniture", "lighting": { - "conditions": "e.g., bright daylight, dim indoor", - "direction": "e.g., front-lit, backlit", - "shadows": "Presence and quality of shadows" + "conditions": "e.g., bright daylight, dim indoor, studio lighting", + "direction": "e.g., front-lit, backlit, side-lit from right", + "shadows": "Presence and quality of shadows", + "illumination_effect": "Overall effect of the lighting on the scene" }, "aesthetics": { "composition": "e.g., rule of thirds, centered", - "color_scheme": "Dominant colors, contrasts", - "mood_atmosphere": "Visual mood" + "color_scheme": "Dominant colors and contrasts", + "mood_atmosphere": "Visual mood in short phrases", + "patterns": "Notable repeating visual patterns; '' if none" }, "cinematography": { - "camera_motion": "e.g., static, slow pan left, tracking", - "framing": "e.g., close-up, wide shot", - "camera_angle": "e.g., eye-level, low angle", - "depth_of_field": "e.g., shallow, deep", + "camera_motion": "e.g., static, slow pan left, tracking shot", + "framing": "e.g., close-up, medium shot, wide shot", + "camera_angle": "e.g., eye-level, low angle, overhead", + "depth_of_field": "'Shallow', 'Deep', or 'Uniform'", "focus": "Where the focus lies", "lens_focal_length": "e.g., wide-angle, telephoto" }, - "style_medium": "e.g., live-action video, 3D animation", - "artistic_style": "e.g., minimalist, highly realistic", - "context": "e.g., vlog, product demo, cinematic shot", - "actions": ["Array of visually observable events in chronological order"], - "text_renders": ["Array of visible text, location, size, color, font"], - "temporal_structure": "Organization over time (e.g., single continuous shot, loop)" + "style_medium": "e.g., live-action video, 3D animation, CGI", + "artistic_style": "e.g., realistic, cinematic, documentary", + "context": "Scene context or use case (brief)", + "actions": [ + { "time": "M:SS-M:SS", "description": "What happens in this timed interval" } + ], + "text_and_signage_elements": [ + { + "text": "Visible text content", + "category": "one of 'physical_in_scene', 'scene_sign', 'ui_text', 'logo', 'label'", + "appearance": "Font, color, size, style", + "spatial_temporal": "Position in scene and when visible", + "context": "Purpose or meaning of the text" + } + ], + "segments": [ + { + "segment_index": 0, + "time_range": "M:SS-M:SS", + "description": "What happens in this segment", + "key_changes": "Notable changes within the segment", + "camera": "Camera behavior in this segment" + } + ], + "transitions": ["Transition description between segments; [] if single continuous shot"], + "temporal_caption": "A temporally coherent, beat-by-beat narrative of the whole clip from start to end", + "audio_description": "Natural-language description of likely audio (speech, music, ambient, effects); '' if unknown" } --- ### PHASE 2: DENSE NARRATIVE REWRITE -After drafting the JSON, generate the final output. This must be a dense, multi-sentence caption derived STRICTLY from your scene analysis. Put this final output inside `` XML tags. +After drafting the JSON, generate the final caption. This must be a dense, multi-sentence narrative derived STRICTLY from your scene analysis. Put this final output inside `` XML tags. Rules for the Final Narrative: - Length & Format: Write exactly ONE coherent paragraph consisting of several full sentences. Do not use bullet points or lists. diff --git a/cosmos_framework/inference/structured_caption.py b/cosmos_framework/inference/structured_caption.py new file mode 100644 index 0000000..0e276a3 --- /dev/null +++ b/cosmos_framework/inference/structured_caption.py @@ -0,0 +1,265 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 +"""Canonical structured-JSON caption: schema, robust parsing, and assembly. + +The Cosmos3 model's native text-prompt format is structured JSON (see +``docs/prompt_upsampling.md``). This module is the single source of truth for +that format on the *captioning / training* side: + +* :data:`CAPTION_JSON_KEY` — the JSONL / ``t2w_window`` key under which the + structured caption object is stored (preferred over the dense ``caption``). +* :class:`StructuredCaption` — a permissive pydantic model mirroring + ``inference/prompting_templates/external_api/t2v_i2v_video_json_schema.json``. +* :func:`parse_structured_caption` — robustly extract the Phase-1 + ```` JSON object from a VLM response. +* :func:`assemble_caption_json` — combine the Phase-1 draft, the polished + Phase-2 dense narrative (stored as ``temporal_caption``), and the clip's real + media fields into a single validated caption object. + +The model is intentionally permissive (every field optional, ``extra="allow"``) +so that partial or slightly-off VLM output still round-trips instead of being +dropped; the goal is structural validation, not rejection. +""" + +import json +import re +from typing import Any + +import pydantic +from pydantic import ConfigDict + +# Key used in the SFT JSONL ``t2w_windows[]`` entries and recognised by the SFT +# loader (sft_dataset.py) as the highest-priority caption. Kept here so the +# captioner, the JSONL converter, and the loader cannot drift apart. +CAPTION_JSON_KEY = "caption_json" + +_PERMISSIVE = ConfigDict(extra="allow") + + +class _Base(pydantic.BaseModel): + model_config = _PERMISSIVE + + +class Subject(_Base): + description: str | None = None + appearance_details: str | None = None + relationship: str | None = None + location: str | None = None + relative_size: str | None = None + orientation: str | None = None + pose: str | None = None + action: str | None = None + state_changes: str | None = None + clothing: str | None = None + expression: str | None = None + gender: str | None = None + age: str | None = None + skin_tone_and_texture: str | None = None + facial_features: str | None = None + number_of_subjects: int | None = None + number_of_arms: int | None = None + number_of_legs: int | None = None + + +class Lighting(_Base): + conditions: str | None = None + direction: str | None = None + shadows: str | None = None + illumination_effect: str | None = None + + +class Aesthetics(_Base): + composition: str | None = None + color_scheme: str | None = None + mood_atmosphere: str | None = None + patterns: str | None = None + + +class Cinematography(_Base): + camera_motion: str | None = None + framing: str | None = None + camera_angle: str | None = None + depth_of_field: str | None = None + focus: str | None = None + lens_focal_length: str | None = None + + +class Action(_Base): + time: str | None = None + description: str | None = None + + +class TextElement(_Base): + text: str | None = None + category: str | None = None + appearance: str | None = None + spatial_temporal: str | None = None + context: str | None = None + + +class Segment(_Base): + segment_index: int | None = None + time_range: str | None = None + description: str | None = None + key_changes: str | None = None + camera: str | None = None + + +class Resolution(_Base): + H: int | None = None + W: int | None = None + + +class StructuredCaption(_Base): + """Permissive mirror of the external-API T2V/I2V JSON schema.""" + + subjects: list[Subject] | None = None + background_setting: str | None = None + lighting: Lighting | None = None + aesthetics: Aesthetics | None = None + cinematography: Cinematography | None = None + style_medium: str | None = None + artistic_style: str | None = None + context: str | None = None + actions: list[Action] | None = None + text_and_signage_elements: list[TextElement] | None = None + segments: list[Segment] | None = None + transitions: list[str] | None = None + temporal_caption: str | None = None + audio_description: str | None = None + resolution: Resolution | None = None + aspect_ratio: str | None = None + duration: str | None = None + fps: int | None = None + + +def extract_xml_tag(text: str, tag: str) -> str | None: + """Return the inner text of ``...`` (DOTALL), or ``None``.""" + match = re.search(rf"<{tag}>\s*(.*?)\s*", text, re.DOTALL) + return match.group(1).strip() if match else None + + +def _strip_code_fences(text: str) -> str: + """Strip a leading ```json / ``` fence and trailing ``` if present.""" + cleaned = text.strip() + if cleaned.startswith("```"): + cleaned = re.sub(r"^```(?:json)?\s*\n?", "", cleaned) + cleaned = re.sub(r"\n?```\s*$", "", cleaned) + return cleaned.strip() + + +def _first_json_object(text: str) -> str | None: + """Return the first balanced ``{...}`` block in ``text``, or ``None``. + + Brace-counting fallback for when the model wraps the JSON in prose without + fences/tags. Ignores braces inside double-quoted strings. + """ + start = text.find("{") + if start < 0: + return None + depth = 0 + in_str = False + escaped = False + for i in range(start, len(text)): + ch = text[i] + if in_str: + if escaped: + escaped = False + elif ch == "\\": + escaped = True + elif ch == '"': + in_str = False + continue + if ch == '"': + in_str = True + elif ch == "{": + depth += 1 + elif ch == "}": + depth -= 1 + if depth == 0: + return text[start : i + 1] + return None + + +def parse_structured_caption(text: str) -> dict | None: + """Extract the Phase-1 ```` JSON object from a VLM response. + + Resolution order, each tolerant of ```` ```json ```` fences: + + 1. The ```` XML block. + 2. The whole response (if it is itself a JSON object). + 3. The first balanced ``{...}`` block anywhere in the response. + + Returns the parsed ``dict`` on success, or ``None`` if no valid JSON object + can be recovered (the caller should retry). + """ + candidates: list[str] = [] + tagged = extract_xml_tag(text, "scene_draft") + if tagged is not None: + candidates.append(tagged) + candidates.append(text) + + for candidate in candidates: + cleaned = _strip_code_fences(candidate) + for blob in (cleaned, _first_json_object(cleaned)): + if not blob: + continue + try: + parsed = json.loads(blob) + except json.JSONDecodeError: + continue + if isinstance(parsed, dict): + return parsed + return None + + +def aspect_ratio_str(width: int, height: int) -> str: + """Reduce ``width``/``height`` to a ``"W,H"`` ratio string (e.g. ``"1,1"``).""" + from math import gcd + + if width <= 0 or height <= 0: + return "" + g = gcd(int(width), int(height)) or 1 + return f"{int(width) // g},{int(height) // g}" + + +def media_fields_from_metadata(meta: dict) -> dict: + """Build the caption's media fields from :func:`probe_video_metadata` output. + + Uses the clip's *actual* values (not the canonical generation enums): the + enums constrain the upsampler's generation params, not ground-truth captions. + """ + width, height = int(meta["width"]), int(meta["height"]) + return { + "resolution": {"H": height, "W": width}, + "aspect_ratio": aspect_ratio_str(width, height), + "duration": f"{round(float(meta['duration']))}s", + "fps": int(round(float(meta["fps"]))), + } + + +def assemble_caption_json(scene_draft: dict, final_prompt: str, media: dict) -> dict: + """Assemble the final caption object and validate it. + + * ``temporal_caption`` is set to the polished Phase-2 ``final_prompt`` (this + is what keeps the dense narrative available *inside* the JSON and equal to + ``caption.txt``), overriding any draft value from Phase 1. + * ``media`` (from :func:`media_fields_from_metadata`) is merged in. + + Returns a normalised ``dict`` (None-valued fields dropped, types coerced). + Raises ``pydantic.ValidationError`` if the structure is unusable. + """ + data: dict[str, Any] = dict(scene_draft) + data["temporal_caption"] = (final_prompt or "").strip() + data.update(media) + model = StructuredCaption.model_validate(data) + return model.model_dump(exclude_none=True, mode="json") + + +def caption_json_to_prompt(caption_json: dict) -> str: + """Serialise a caption object to the compact JSON string fed to the model. + + Single source of truth for how a structured caption becomes model text, so + training (sft_dataset.py) and inference prompts use byte-identical encoding. + """ + return json.dumps(caption_json, ensure_ascii=False) diff --git a/cosmos_framework/inference/structured_caption_test.py b/cosmos_framework/inference/structured_caption_test.py new file mode 100644 index 0000000..7193829 --- /dev/null +++ b/cosmos_framework/inference/structured_caption_test.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 +"""Tests for the structured-JSON caption schema, parsing, and assembly.""" + +import json + +import pytest + +from cosmos_framework.inference.structured_caption import ( + CAPTION_JSON_KEY, + aspect_ratio_str, + assemble_caption_json, + caption_json_to_prompt, + extract_xml_tag, + media_fields_from_metadata, + parse_structured_caption, +) + + +def test_caption_json_key_is_stable(): + assert CAPTION_JSON_KEY == "caption_json" + + +@pytest.mark.parametrize( + "w,h,expected", + [(256, 256, "1,1"), (1920, 1080, "16,9"), (1080, 1920, "9,16"), (640, 480, "4,3"), (0, 10, "")], +) +def test_aspect_ratio_str(w, h, expected): + assert aspect_ratio_str(w, h) == expected + + +def test_extract_xml_tag_multiline(): + text = "\nA cat\nsits.\n" + assert extract_xml_tag(text, "final_prompt") == "A cat\nsits." + assert extract_xml_tag(text, "missing") is None + + +def test_parse_scene_draft_in_tags_with_fences(): + resp = ( + "preamble\n\n```json\n" + '{"subjects": [{"description": "arm"}], "background_setting": "kitchen"}\n' + "```\n\nAn arm." + ) + sd = parse_structured_caption(resp) + assert sd["background_setting"] == "kitchen" + assert sd["subjects"][0]["description"] == "arm" + + +def test_parse_raw_object_without_tags(): + assert parse_structured_caption('{"background_setting": "x"}')["background_setting"] == "x" + + +def test_parse_object_embedded_in_prose_with_brace_in_string(): + # The brace-matcher must ignore braces inside quoted strings. + sd = parse_structured_caption('Result: {"background_setting": "a } brace", "fps": 5} end') + assert sd["background_setting"] == "a } brace" + assert sd["fps"] == 5 + + +def test_parse_returns_none_for_garbage(): + assert parse_structured_caption("there is no json here") is None + assert parse_structured_caption("") is None + + +def test_media_fields_from_metadata_uses_actual_values(): + media = media_fields_from_metadata({"width": 256, "height": 256, "duration": 17.0, "fps": 5}) + assert media == {"resolution": {"H": 256, "W": 256}, "aspect_ratio": "1,1", "duration": "17s", "fps": 5} + + +def test_assemble_sets_temporal_caption_and_media(): + scene_draft = {"subjects": [{"description": "arm"}], "background_setting": "kitchen"} + media = media_fields_from_metadata({"width": 256, "height": 256, "duration": 17.0, "fps": 5}) + cj = assemble_caption_json(scene_draft, " An arm moves. ", media) + assert cj["temporal_caption"] == "An arm moves." # stripped, overrides any draft value + assert cj["resolution"] == {"H": 256, "W": 256} + assert cj["duration"] == "17s" and cj["fps"] == 5 + assert cj["background_setting"] == "kitchen" + + +def test_assemble_overrides_draft_temporal_caption(): + scene_draft = {"temporal_caption": "draft timeline", "background_setting": "x"} + cj = assemble_caption_json(scene_draft, "final dense", {}) + assert cj["temporal_caption"] == "final dense" + + +def test_assemble_drops_none_and_preserves_extras(): + # extra (non-schema) keys must survive; None-valued fields are dropped. + scene_draft = {"background_setting": "x", "short_description": "extra field", "lighting": None} + cj = assemble_caption_json(scene_draft, "d", {}) + assert cj["short_description"] == "extra field" + assert "lighting" not in cj + + +def test_caption_json_to_prompt_is_compact_and_roundtrips(): + cj = {"background_setting": "café", "fps": 5} + prompt = caption_json_to_prompt(cj) + assert ", " not in prompt or ": " in prompt # compact separators + assert "café" in prompt # ensure_ascii=False keeps unicode + assert json.loads(prompt) == cj diff --git a/cosmos_framework/scripts/caption_from_video.py b/cosmos_framework/scripts/caption_from_video.py index 287f07e..9096d85 100644 --- a/cosmos_framework/scripts/caption_from_video.py +++ b/cosmos_framework/scripts/caption_from_video.py @@ -1,11 +1,15 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 -"""Generate dense narrative captions from video files using a Vision-Language Model. +"""Generate structured-JSON and dense narrative captions from video files using a VLM. Each video is passed directly to a VLM server via a ``video_url`` content part using a ``file://`` path. A structured prompt template guides the VLM through -a two-phase captioning process (scene analysis → dense narrative rewrite). +a two-phase captioning process (Phase 1: structured-JSON scene analysis → +Phase 2: dense narrative rewrite). Both outputs are persisted: ``caption.json`` +(the canonical structured caption, with the dense narrative embedded as +``temporal_caption`` and the clip's real media fields) and ``caption.txt`` (the +dense narrative on its own). The VLM server must support the OpenAI chat-completions API with vision and must be started with ``--allowed-local-media-path`` pointing to the root of @@ -31,7 +35,7 @@ """ import asyncio -import re +import json from pathlib import Path from typing import Annotated @@ -42,6 +46,13 @@ from cosmos_framework.inference.args import OmniSampleOverrides from cosmos_framework.inference.common.args import VIDEO_EXTENSIONS +from cosmos_framework.inference.structured_caption import ( + assemble_caption_json, + extract_xml_tag, + media_fields_from_metadata, + parse_structured_caption, +) +from cosmos_framework.scripts.video_metadata import probe_video_metadata from cosmos_framework.utils import log _PACKAGE_DIR = Path(__file__).parents[1].absolute() @@ -49,9 +60,11 @@ class Args(pydantic.BaseModel): input_files: Annotated[list[Path] | None, tyro.conf.arg(aliases=("-i",))] = None - """Path to input sample argument files (JSON/JSONL). - Each entry should have at least 'name' and 'vision_path' fields. - Mutually exclusive with --video.""" + """Path to input manifest files (JSON/JSONL). + Each entry needs a 'vision_path' (a local path or an http(s)/data URL) and may + include 'name' and a 'media' dict (resolution/aspect_ratio/duration/fps) — the + latter is used as the caption's media fields when the video is a remote URL that + ffprobe cannot read locally. Mutually exclusive with --video.""" video: Annotated[Path | None, tyro.conf.arg(aliases=("-v",))] = None """Path to a single video file or a directory of videos. @@ -69,6 +82,8 @@ class Args(pydantic.BaseModel): """Maximum number of concurrent requests to the API.""" max_retries: int = 5 """Maximum number of retries for each request.""" + timeout: float = 600.0 + """Per-request client timeout in seconds; a hung request fails after this and is retried.""" prompt_template_path: Path | None = None """Path to a custom prompt template. Defaults to the built-in video_captioner.txt.""" @@ -77,31 +92,31 @@ class Args(pydantic.BaseModel): """If True, save raw API responses for debugging.""" -def _extract_xml_tag(text: str, tag: str) -> str | None: - pattern = rf"<{tag}>\s*(.*?)\s*" - match = re.search(pattern, text, re.DOTALL) - if match: - return match.group(1).strip() - return None +def _is_remote_ref(ref: str) -> bool: + """True if ``ref`` is something the server fetches itself (URL / data URI).""" + return "://" in ref or ref.startswith("data:") -def _build_vlm_messages( - video_path: Path, - prompt_template: str, -) -> list[dict]: - """Build an OpenAI-compatible multimodal message with a video file URL + text prompt. +def _video_url(video_ref: str) -> str: + """Map a local path or remote ref to the ``video_url`` string the server receives. - The vLLM server must be started with ``--allowed-local-media-path`` so it - can read the video directly from the shared filesystem. + Remote refs (``http(s)://`` or ``data:``) are passed through untouched, so the + server fetches them itself — this is what makes captioning work against a remote + VLM endpoint. Local paths become ``file://`` URLs, which require a local server + started with ``--allowed-local-media-path``. """ + if _is_remote_ref(video_ref): + return video_ref + return f"file://{Path(video_ref).absolute()}" + + +def _build_vlm_messages(video_ref: str, prompt_template: str) -> list[dict]: + """Build an OpenAI-compatible multimodal message with a video + text prompt.""" return [ { "role": "user", "content": [ - { - "type": "video_url", - "video_url": {"url": f"file://{video_path.absolute()}"}, - }, + {"type": "video_url", "video_url": {"url": _video_url(video_ref)}}, {"type": "text", "text": prompt_template}, ], } @@ -112,13 +127,14 @@ async def _process_single( args: Args, client: openai.AsyncOpenAI, name: str, - video_path: Path, + video_ref: str, + media_override: dict | None, prompt_template: str, ) -> bool: assert args.model output_dir = args.output_dir / name - messages = _build_vlm_messages(video_path, prompt_template) + messages = _build_vlm_messages(video_ref, prompt_template) for i_retry in range(args.max_retries): try: @@ -147,22 +163,56 @@ async def _process_single( continue text = choice.message.content.strip() - final_prompt = _extract_xml_tag(text, "final_prompt") + final_prompt = extract_xml_tag(text, "final_prompt") if final_prompt is None: log.warning(f"[{i_retry + 1}/{args.max_retries}] Failed to extract final prompt for {name}") continue + scene_draft = parse_structured_caption(text) + if scene_draft is None: + log.warning(f"[{i_retry + 1}/{args.max_retries}] Failed to parse scene_draft JSON for {name}") + continue + + # Media fields: prefer a manifest-provided override; else ffprobe a local + # file; else leave empty (e.g. a remote URL ffprobe cannot read). + if media_override is not None: + media = media_override + elif not _is_remote_ref(video_ref): + try: + media = media_fields_from_metadata(probe_video_metadata(video_ref)) + except Exception as e: # noqa: BLE001 - degrade gracefully, keep the caption + log.warning(f"ffprobe failed for {name}: {e}; writing caption_json without media fields") + media = {} + else: + media = {} + + try: + caption_json = assemble_caption_json(scene_draft, final_prompt, media) + except pydantic.ValidationError as e: + log.warning(f"[{i_retry + 1}/{args.max_retries}] caption_json failed validation for {name}: {e}") + continue + output_dir.mkdir(parents=True, exist_ok=True) sample_overrides = OmniSampleOverrides( name=name, prompt=final_prompt, - vision_path=str(video_path), + vision_path=video_ref, output_dir=output_dir, ) (output_dir / "sample_args.json").write_text(sample_overrides.model_dump_json()) - (output_dir / "caption.txt").write_text(final_prompt) + (output_dir / "caption.json").write_text(json.dumps(caption_json, indent=2, ensure_ascii=False)) + + # Advisory: the SFT loader truncates very long prompts (see _MAX_NUM_TOKENS + # in sft_dataset.py). ~4 chars/token is a rough guide; warn if the serialized + # JSON looks large so it can be checked against the recipe's max_num_tokens. + approx_tokens = len(json.dumps(caption_json, ensure_ascii=False)) // 4 + if approx_tokens > 1024: + log.warning( + f"{name}: structured caption is ~{approx_tokens} tokens (rough estimate); " + "ensure the SFT recipe's max_num_tokens covers it to avoid truncation." + ) return True log.warning(f"Failed to get caption for {name}") @@ -174,36 +224,60 @@ async def _process_with_semaphore( client: openai.AsyncOpenAI, semaphore: asyncio.Semaphore, name: str, - video_path: Path, + video_ref: str, + media_override: dict | None, prompt_template: str, ) -> bool: async with semaphore: - return await _process_single(args, client, name, video_path, prompt_template) + return await _process_single(args, client, name, video_ref, media_override, prompt_template) -def _collect_video_items(args: Args) -> list[tuple[str, Path]]: - """Return a list of (name, video_path) pairs from the CLI arguments.""" - items: list[tuple[str, Path]] = [] +def _read_manifest_entries(input_files: list[Path]) -> list[tuple[str, str, dict | None]]: + """Parse ``-i`` JSON/JSONL manifests into ``(name, video_ref, media)`` tuples. - if args.input_files: - sample_overrides_list = OmniSampleOverrides.from_files(args.input_files) - for s in sample_overrides_list: - if not s.vision_path: - log.warning(f"Skipping '{s.name}': no vision_path") + Each entry must have a ``vision_path`` (a local path or an ``http(s)``/``data`` + URL) and may carry an optional ``name`` and an optional ``media`` dict (the + structured caption's media fields: resolution/aspect_ratio/duration/fps). The + ``media`` override lets remote-URL videos — which ffprobe cannot read — still + get accurate media fields. + """ + items: list[tuple[str, str, dict | None]] = [] + for path in input_files: + text = path.read_text() + if path.suffix == ".jsonl": + entries = [json.loads(line) for line in text.splitlines() if line.strip()] + else: + data = json.loads(text) + entries = data if isinstance(data, list) else [data] + for e in entries: + vp = e.get("vision_path") + name = e.get("name") + if not vp: + log.warning(f"Skipping entry with no vision_path: {name or '?'}") continue - vp = Path(s.vision_path) - if vp.suffix.lower() not in VIDEO_EXTENSIONS: - log.warning(f"Skipping '{s.name}': vision_path is not a video ({vp.suffix})") + if Path(vp).suffix.lower() not in VIDEO_EXTENSIONS: + log.warning(f"Skipping '{name or vp}': vision_path is not a video ({Path(vp).suffix})") continue - items.append((s.name or vp.stem, vp)) + items.append((name or Path(vp).stem, vp, e.get("media"))) + return items + +def _collect_video_items(args: Args) -> list[tuple[str, str, dict | None]]: + """Return ``(name, video_ref, media_override)`` items from the CLI arguments. + + ``video_ref`` is a local filesystem path or a remote URL (``http(s)``/``data``). + """ + items: list[tuple[str, str, dict | None]] = [] + + if args.input_files: + items = _read_manifest_entries(args.input_files) elif args.video: if args.video.is_dir(): for vp in sorted(args.video.iterdir()): if vp.suffix.lower() in VIDEO_EXTENSIONS: - items.append((vp.stem, vp)) + items.append((vp.stem, str(vp), None)) elif args.video.is_file(): - items.append((args.video.stem, args.video)) + items.append((args.video.stem, str(args.video), None)) else: raise FileNotFoundError(f"Video path does not exist: {args.video}") @@ -219,14 +293,14 @@ async def caption_from_video(args: Args): if args.prompt_template_path: prompt_template = args.prompt_template_path.read_text() else: - prompt_template = (_PACKAGE_DIR / "defaults/video_captioner.txt").read_text() + prompt_template = (_PACKAGE_DIR / "inference/defaults/video_captioner.txt").read_text() items = _collect_video_items(args) client = openai.AsyncOpenAI( api_key="EMPTY", base_url=args.server, - timeout=3600, + timeout=args.timeout, ) if not args.model: models = await client.models.list() @@ -241,10 +315,11 @@ async def caption_from_video(args: Args): client=client, semaphore=semaphore, name=name, - video_path=video_path, + video_ref=video_ref, + media_override=media, prompt_template=prompt_template, ) - for name, video_path in items + for name, video_ref, media in items ] n_success = 0 for result in tqdm(asyncio.as_completed(tasks), desc="Captioning", total=len(tasks)): diff --git a/cosmos_framework/scripts/caption_from_video_test.py b/cosmos_framework/scripts/caption_from_video_test.py new file mode 100644 index 0000000..88dc007 --- /dev/null +++ b/cosmos_framework/scripts/caption_from_video_test.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 +"""Tests for caption_from_video input handling (remote URLs + manifest media).""" + +import json + +from cosmos_framework.scripts.caption_from_video import ( + _PACKAGE_DIR, + _build_vlm_messages, + _is_remote_ref, + _read_manifest_entries, + _video_url, +) + + +def test_default_template_path_resolves(): + # Guards against the default prompt-template path regressing (it must point at + # inference/defaults/, not cosmos_framework/defaults/). + assert (_PACKAGE_DIR / "inference/defaults/video_captioner.txt").is_file() + + +def test_is_remote_ref(): + assert _is_remote_ref("https://h/a.mp4") + assert _is_remote_ref("http://h/a.mp4") + assert _is_remote_ref("data:video/mp4;base64,AAAA") + assert not _is_remote_ref("/abs/a.mp4") + assert not _is_remote_ref("videos/a.mp4") + + +def test_video_url_passthrough_and_file(): + assert _video_url("https://h/a.mp4") == "https://h/a.mp4" + assert _video_url("/abs/a.mp4") == "file:///abs/a.mp4" + assert _video_url("data:video/mp4;base64,AB") == "data:video/mp4;base64,AB" + + +def test_build_messages_uses_remote_url_verbatim(): + msgs = _build_vlm_messages("https://h/a.mp4", "PROMPT") + content = msgs[0]["content"] + assert content[0]["video_url"]["url"] == "https://h/a.mp4" + assert content[1]["text"] == "PROMPT" + + +def test_read_manifest_jsonl_with_url_and_media(tmp_path): + media = {"resolution": {"H": 256, "W": 256}, "aspect_ratio": "1,1", "duration": "17s", "fps": 5} + p = tmp_path / "m.jsonl" + p.write_text( + json.dumps({"name": "ep0", "vision_path": "https://h/ep0.mp4", "media": media}) + + "\n" + + json.dumps({"vision_path": "https://h/ep1.mp4"}) # no name -> derive stem; no media + + "\n" + ) + items = _read_manifest_entries([p]) + assert items[0] == ("ep0", "https://h/ep0.mp4", media) + assert items[1] == ("ep1", "https://h/ep1.mp4", None) + + +def test_read_manifest_skips_bad_entries(tmp_path): + p = tmp_path / "m.jsonl" + p.write_text( + json.dumps({"name": "novp"}) # missing vision_path + + "\n" + + json.dumps({"vision_path": "https://h/notavideo.txt"}) # wrong suffix + + "\n" + + json.dumps({"vision_path": "https://h/good.mp4"}) + + "\n" + ) + items = _read_manifest_entries([p]) + assert items == [("good", "https://h/good.mp4", None)] + + +def test_read_manifest_json_object_and_list(tmp_path): + obj = tmp_path / "o.json" + obj.write_text(json.dumps({"vision_path": "/local/a.mp4"})) + lst = tmp_path / "l.json" + lst.write_text(json.dumps([{"vision_path": "/local/b.mp4"}, {"vision_path": "/local/c.mp4"}])) + assert _read_manifest_entries([obj]) == [("a", "/local/a.mp4", None)] + assert [i[0] for i in _read_manifest_entries([lst])] == ["b", "c"] diff --git a/cosmos_framework/scripts/captions_to_sft_jsonl.py b/cosmos_framework/scripts/captions_to_sft_jsonl.py index 357a31d..1d2e9b9 100644 --- a/cosmos_framework/scripts/captions_to_sft_jsonl.py +++ b/cosmos_framework/scripts/captions_to_sft_jsonl.py @@ -5,10 +5,26 @@ The SFT dataset loader (sft_dataset.py) expects each JSONL line to have: uuid, duration, width, height, vision_path, t2w_windows -where t2w_windows is a list of dicts with start_frame, end_frame, and a -caption field. The default key is "caption", which sft_dataset.py -recognises as a generic fallback. Videos longer than 61 s are filtered -by the loader, so they are skipped here with a warning. +where t2w_windows is a list of dicts with start_frame, end_frame, temporal_interval +and a caption. This converter emits **both** caption representations per window: + +* ``caption_json`` — the canonical structured-JSON caption object (read from each + clip's ``caption.json``). The loader prefers this and trains on it by default. +* ``caption`` — the dense narrative string (read from ``caption.txt``), kept as the + backup the loader falls back to when ``caption_json`` is absent. + +If a clip has no ``caption.json`` (e.g. produced by an older captioner), the row is +written dense-only, exactly as before. + +Filters mirror what training actually consumes so dataset counts match: + +* clips longer than 61 s are dropped (matches the loader's hard cap); +* windows shorter than ``max(61, num_video_frames)`` frames are dropped. Pass + ``--num-video-frames`` to match your training recipe. The default (-1) applies + only the loader's metadata minimum of 61 frames, matching the example recipe + (``num_video_frames=-1``) so short example clips (~85 frames) are kept. + +A sibling ``.summary.json`` records kept/dropped counts per reason. Usage ----- @@ -17,7 +33,14 @@ --videos-dir outputs/videos \ -o outputs/my_dataset.jsonl - # With a custom caption key (default: caption): + # Match a recipe that decodes a fixed number of frames per window: + python -m cosmos_framework.scripts.captions_to_sft_jsonl \ + --captions-dir outputs/captions \ + --videos-dir outputs/videos \ + -o outputs/my_dataset.jsonl \ + --num-video-frames 93 + + # With a custom dense caption key (default: caption): python -m cosmos_framework.scripts.captions_to_sft_jsonl \ --captions-dir outputs/captions \ --videos-dir outputs/videos \ @@ -26,15 +49,19 @@ """ import json -import subprocess +import os import sys +from collections import Counter from pathlib import Path from typing import Annotated import tyro +from cosmos_framework.inference.structured_caption import CAPTION_JSON_KEY +from cosmos_framework.scripts.video_metadata import probe_video_metadata + _MAX_DURATION = 61.0 # seconds; matches hard-coded limit in sft_dataset.py -_MIN_FRAMES = 61 # matches min_frames=61 in get_sft_dataset() +_MIN_FRAMES = 61 # matches the metadata min_frames=61 in get_sft_dataset() _VIDEO_EXTENSIONS = (".mp4", ".mov", ".avi", ".mkv", ".webm") @@ -46,48 +73,16 @@ def _find_video(videos_dir: Path, name: str) -> Path | None: return None -def _get_video_metadata(video_path: Path) -> dict: - """Return fps, duration, width, height, total_frames via ffprobe.""" - cmd = [ - "ffprobe", - "-v", - "quiet", - "-print_format", - "json", - "-show_streams", - "-show_format", - str(video_path), - ] - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode != 0: - raise RuntimeError(f"ffprobe failed for {video_path}: {result.stderr}") - data = json.loads(result.stdout) - - video_stream = next( - (s for s in data["streams"] if s["codec_type"] == "video"), - None, - ) - if video_stream is None: - raise RuntimeError(f"No video stream found in {video_path}") - - fps_str = video_stream.get("avg_frame_rate", "30/1") - fps_num, fps_den = map(int, fps_str.split("/")) - fps = fps_num / max(fps_den, 1) - - duration = float(data["format"]["duration"]) - width = video_stream["width"] - height = video_stream["height"] - - # nb_frames may be absent; fall back to duration * fps - total_frames = int(video_stream.get("nb_frames") or round(duration * fps)) - - return { - "fps": fps, - "duration": duration, - "width": width, - "height": height, - "total_frames": total_frames, - } +def _relativize_vision_path(vision_path: str, output_jsonl: Path) -> str: + """Rewrite ``vision_path`` relative to the output JSONL's directory. + + The SFT loader resolves relative paths against the JSONL's directory, which + survives moving the dataset to a different mount/container. URIs containing + ``://`` (e.g. ``s3://bucket/key``) pass through unchanged. + """ + if "://" in vision_path: + return vision_path + return os.path.relpath(vision_path, start=output_jsonl.parent) def main( @@ -97,85 +92,134 @@ def main( videos_dir: Annotated[Path, tyro.conf.arg(help="Directory containing video files named ..")], output: Annotated[Path, tyro.conf.arg(aliases=("-o",), help="Output JSONL path.")], caption_key: str = "caption", + num_video_frames: Annotated[ + int, + tyro.conf.arg( + help="Decoded frames per window in your training recipe; windows shorter than " + "max(61, this) are dropped. -1 (default) applies only the 61-frame metadata " + "minimum, matching the example recipe." + ), + ] = -1, + min_short_edge: Annotated[ + int, tyro.conf.arg(help="Drop clips whose shortest spatial edge is below this value. 0 disables.") + ] = 0, ) -> None: - """Build an SFT JSONL from caption.txt files and a videos directory.""" + """Build an SFT JSONL (caption_json + dense caption) from caption dirs and videos.""" caption_files = sorted(captions_dir.glob("*/caption.txt")) + # Also accept dirs that only have caption.json (no caption.txt). + json_only = sorted( + p.parent / "caption.txt" for p in captions_dir.glob("*/caption.json") if not (p.parent / "caption.txt").exists() + ) + caption_files = sorted(set(caption_files) | set(json_only)) if not caption_files: - print(f"No caption.txt files found under {captions_dir}", file=sys.stderr) + print(f"No caption.txt / caption.json files found under {captions_dir}", file=sys.stderr) sys.exit(1) + effective_min_frames = _MIN_FRAMES if num_video_frames <= 0 else max(_MIN_FRAMES, num_video_frames) + records = [] - skipped = 0 + drops: Counter[str] = Counter() for caption_path in caption_files: name = caption_path.parent.name - caption = caption_path.read_text().strip() - - if not caption: - print(f" SKIP {name}: empty caption.txt") - skipped += 1 + dense = caption_path.read_text().strip() if caption_path.exists() else "" + + caption_json_path = caption_path.parent / "caption.json" + caption_json = None + if caption_json_path.exists(): + try: + caption_json = json.loads(caption_json_path.read_text()) + except json.JSONDecodeError as e: + print(f" WARN {name}: caption.json is not valid JSON ({e}); using dense only") + # Fall back to the JSON's temporal_caption for the dense backup if needed. + if not dense and isinstance(caption_json, dict): + dense = str(caption_json.get("temporal_caption", "")).strip() + + if not dense and caption_json is None: + print(f" SKIP {name}: no caption content") + drops["empty_caption"] += 1 continue video_path = _find_video(videos_dir, name) if video_path is None: print(f" SKIP {name}: no video found in {videos_dir} for name '{name}'") - skipped += 1 + drops["missing_video"] += 1 continue try: - meta = _get_video_metadata(video_path) + meta = probe_video_metadata(video_path) except Exception as e: print(f" SKIP {name}: ffprobe error — {e}") - skipped += 1 + drops["ffprobe_error"] += 1 continue if meta["duration"] > _MAX_DURATION: - print( - f" SKIP {name}: duration {meta['duration']:.1f}s > {_MAX_DURATION}s " - "(sft_dataset.py would filter this out)" - ) - skipped += 1 + print(f" SKIP {name}: duration {meta['duration']:.1f}s > {_MAX_DURATION}s") + drops["duration_too_long"] += 1 continue - if meta["total_frames"] < _MIN_FRAMES: - print( - f" SKIP {name}: only {meta['total_frames']} frames < {_MIN_FRAMES} " - "(sft_dataset.py would filter this out)" - ) - skipped += 1 + if meta["total_frames"] < effective_min_frames: + print(f" SKIP {name}: only {meta['total_frames']} frames < {effective_min_frames}") + drops["too_few_frames"] += 1 continue - try: - vision_path = str(video_path.resolve().relative_to(videos_dir.resolve().parent)) - except ValueError: - vision_path = str(video_path) + if min_short_edge > 0 and min(meta["width"], meta["height"]) < min_short_edge: + print(f" SKIP {name}: short edge {min(meta['width'], meta['height'])} < {min_short_edge}") + drops["short_edge_too_small"] += 1 + continue + + window: dict = { + "start_frame": 0, + "end_frame": meta["total_frames"] - 1, + "temporal_interval": 1, + } + if caption_json is not None: + window[CAPTION_JSON_KEY] = caption_json # PRIMARY (structured) + if dense: + window[caption_key] = dense # BACKUP (dense) record = { "uuid": name, "duration": meta["duration"], "width": meta["width"], "height": meta["height"], - "vision_path": vision_path, - "t2w_windows": [ - { - "start_frame": 0, - "end_frame": meta["total_frames"] - 1, - "temporal_interval": 1, - caption_key: caption, - } - ], + "vision_path": _relativize_vision_path(str(video_path), output), + "t2w_windows": [window], } records.append(record) - print(f" OK {name}: {meta['duration']:.1f}s, {meta['total_frames']} frames, {meta['width']}x{meta['height']}") + kind = "json+dense" if caption_json is not None else "dense" + print(f" OK {name}: {meta['duration']:.1f}s, {meta['total_frames']} frames, {kind}") output.parent.mkdir(parents=True, exist_ok=True) with output.open("w") as f: for record in records: - f.write(json.dumps(record) + "\n") + f.write(json.dumps(record, ensure_ascii=False) + "\n") + + summary = { + "captions_dir": str(captions_dir), + "videos_dir": str(videos_dir), + "output_jsonl": str(output), + "records_kept": len(records), + "records_with_caption_json": sum(1 for r in records if CAPTION_JSON_KEY in r["t2w_windows"][0]), + "records_dropped": sum(drops.values()), + "drops_by_reason": dict(drops), + "filters": { + "max_duration_s": _MAX_DURATION, + "min_window_frames": effective_min_frames, + "min_short_edge": min_short_edge, + "num_video_frames": num_video_frames, + }, + } + summary_path = output.with_suffix(output.suffix + ".summary.json") + summary_path.write_text(json.dumps(summary, indent=2) + "\n") print(f"\nWrote {len(records)} records → {output}") - if skipped: - print(f"Skipped {skipped} videos") + print(f" with caption_json: {summary['records_with_caption_json']}") + if drops: + print("Drops by reason:") + for reason, count in sorted(drops.items(), key=lambda kv: (-kv[1], kv[0])): + print(f" {reason}: {count}") + print(f"Summary: {summary_path}") if not records: print("ERROR: No valid records written.", file=sys.stderr) sys.exit(1) diff --git a/cosmos_framework/scripts/captions_to_sft_jsonl_test.py b/cosmos_framework/scripts/captions_to_sft_jsonl_test.py new file mode 100644 index 0000000..b00f530 --- /dev/null +++ b/cosmos_framework/scripts/captions_to_sft_jsonl_test.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 +"""Tests for captions_to_sft_jsonl (caption_json + dense emission, filters, summary).""" + +import json + +import pytest + +from cosmos_framework.inference.structured_caption import CAPTION_JSON_KEY +from cosmos_framework.scripts import captions_to_sft_jsonl as mod + + +def _meta(width=256, height=256, duration=17.0, fps=5.0, total_frames=85): + return {"width": width, "height": height, "duration": duration, "fps": fps, "total_frames": total_frames} + + +def _make_clip(captions_dir, videos_dir, name, dense="A robot arm.", caption_json=None): + d = captions_dir / name + d.mkdir(parents=True) + if dense is not None: + (d / "caption.txt").write_text(dense) + if caption_json is not None: + (d / "caption.json").write_text(json.dumps(caption_json)) + (videos_dir / f"{name}.mp4").write_bytes(b"\x00") # presence only; ffprobe is mocked + + +def _read_jsonl(path): + return [json.loads(line) for line in path.read_text().splitlines()] + + +@pytest.fixture +def dirs(tmp_path): + captions_dir = tmp_path / "captions" + videos_dir = tmp_path / "videos" + captions_dir.mkdir() + videos_dir.mkdir() + return captions_dir, videos_dir + + +def test_emits_caption_json_and_dense(dirs, tmp_path, monkeypatch): + captions_dir, videos_dir = dirs + cj = {"background_setting": "kitchen", "temporal_caption": "A robot arm.", "fps": 5} + _make_clip(captions_dir, videos_dir, "ep0", dense="A robot arm.", caption_json=cj) + monkeypatch.setattr(mod, "probe_video_metadata", lambda p: _meta()) + + out = tmp_path / "ds.jsonl" + mod.main(captions_dir=captions_dir, videos_dir=videos_dir, output=out) + + rows = _read_jsonl(out) + assert len(rows) == 1 + window = rows[0]["t2w_windows"][0] + assert window[CAPTION_JSON_KEY] == cj # structured, as a dict object + assert window["caption"] == "A robot arm." # dense backup + assert window["start_frame"] == 0 and window["end_frame"] == 84 + assert rows[0]["uuid"] == "ep0" + # vision_path is relative to the output JSONL dir. + assert rows[0]["vision_path"] == "videos/ep0.mp4" + + summary = json.loads((tmp_path / "ds.jsonl.summary.json").read_text()) + assert summary["records_kept"] == 1 and summary["records_with_caption_json"] == 1 + + +def test_dense_only_when_no_caption_json(dirs, tmp_path, monkeypatch): + captions_dir, videos_dir = dirs + _make_clip(captions_dir, videos_dir, "ep0", dense="Only dense.", caption_json=None) + monkeypatch.setattr(mod, "probe_video_metadata", lambda p: _meta()) + + out = tmp_path / "ds.jsonl" + mod.main(captions_dir=captions_dir, videos_dir=videos_dir, output=out) + window = _read_jsonl(out)[0]["t2w_windows"][0] + assert CAPTION_JSON_KEY not in window + assert window["caption"] == "Only dense." + + +def test_drops_long_and_short_clips(dirs, tmp_path, monkeypatch): + captions_dir, videos_dir = dirs + _make_clip(captions_dir, videos_dir, "good", caption_json={"x": 1}) + _make_clip(captions_dir, videos_dir, "toolong", caption_json={"x": 1}) + _make_clip(captions_dir, videos_dir, "tooshort", caption_json={"x": 1}) + + def fake_probe(path): + if "toolong" in str(path): + return _meta(duration=99.0) + if "tooshort" in str(path): + return _meta(total_frames=10) + return _meta() + + monkeypatch.setattr(mod, "probe_video_metadata", fake_probe) + out = tmp_path / "ds.jsonl" + mod.main(captions_dir=captions_dir, videos_dir=videos_dir, output=out) + + rows = _read_jsonl(out) + assert {r["uuid"] for r in rows} == {"good"} + summary = json.loads((tmp_path / "ds.jsonl.summary.json").read_text()) + assert summary["drops_by_reason"]["duration_too_long"] == 1 + assert summary["drops_by_reason"]["too_few_frames"] == 1 + + +def test_num_video_frames_filter_drops_85_frame_clip(dirs, tmp_path, monkeypatch): + captions_dir, videos_dir = dirs + _make_clip(captions_dir, videos_dir, "ep0", caption_json={"x": 1}) + monkeypatch.setattr(mod, "probe_video_metadata", lambda p: _meta(total_frames=85)) + + out = tmp_path / "ds.jsonl" + # With num_video_frames=93, an 85-frame clip must be dropped (matches decode-time + # filtering); with the default -1 it is kept. + with pytest.raises(SystemExit): + mod.main(captions_dir=captions_dir, videos_dir=videos_dir, output=out, num_video_frames=93) + summary = json.loads((tmp_path / "ds.jsonl.summary.json").read_text()) + assert summary["drops_by_reason"]["too_few_frames"] == 1 + + +def test_caption_json_falls_back_to_temporal_caption_for_dense(dirs, tmp_path, monkeypatch): + captions_dir, videos_dir = dirs + # No caption.txt; dense should be recovered from caption.json temporal_caption. + cj = {"temporal_caption": "Recovered dense.", "fps": 5} + d = captions_dir / "ep0" + d.mkdir(parents=True) + (d / "caption.json").write_text(json.dumps(cj)) + (videos_dir / "ep0.mp4").write_bytes(b"\x00") + monkeypatch.setattr(mod, "probe_video_metadata", lambda p: _meta()) + + out = tmp_path / "ds.jsonl" + mod.main(captions_dir=captions_dir, videos_dir=videos_dir, output=out) + window = _read_jsonl(out)[0]["t2w_windows"][0] + assert window[CAPTION_JSON_KEY] == cj + assert window["caption"] == "Recovered dense." diff --git a/cosmos_framework/scripts/inference_prompts_to_json.py b/cosmos_framework/scripts/inference_prompts_to_json.py new file mode 100644 index 0000000..72e9051 --- /dev/null +++ b/cosmos_framework/scripts/inference_prompts_to_json.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 +"""Rewrite a dataset's inference-prompt JSON files to use structured-JSON prompts. + +The example dataset ships per-clip inference prompts under +``/inference_prompt{,_i2v,_v2v}/.json`` whose ``prompt`` field is a +**dense** narrative. Once the dataset carries structured-JSON captions, the +inference example should use the **same** format so it matches what the model is +trained on. This script replaces each file's ``prompt`` with the serialized +structured caption (from the clip's ``caption.json``), preserving every other +field (``name``, ``resolution``, ``aspect_ratio``, ``num_frames``, ``fps``, +``vision_path``). It is idempotent and re-runnable. + +Usage +----- + python -m cosmos_framework.scripts.inference_prompts_to_json \ + --val-dir /path/to/sft_dataset_bridge/val + + # captions live elsewhere, or only update specific variants: + python -m cosmos_framework.scripts.inference_prompts_to_json \ + --val-dir /path/to/val --captions-dir /path/to/val/captions --dry-run +""" + +import json +import sys +from pathlib import Path +from typing import Annotated + +import tyro + +from cosmos_framework.inference.structured_caption import caption_json_to_prompt + + +def main( + val_dir: Annotated[Path, tyro.conf.arg(help="Dataset split dir containing inference_prompt*/ and captions/.")], + captions_dir: Annotated[ + Path | None, tyro.conf.arg(help="Dir with /caption.json (default: /captions).") + ] = None, + inference_prompt_glob: str = "inference_prompt*", + dry_run: bool = False, +) -> None: + """Replace dense `prompt` fields with the serialized structured JSON caption.""" + captions_dir = captions_dir or (val_dir / "captions") + prompt_dirs = sorted(d for d in val_dir.glob(inference_prompt_glob) if d.is_dir()) + if not prompt_dirs: + print(f"No '{inference_prompt_glob}' directories found under {val_dir}", file=sys.stderr) + sys.exit(1) + + n_updated = 0 + n_missing_caption = 0 + n_files = 0 + + for prompt_dir in prompt_dirs: + for prompt_path in sorted(prompt_dir.glob("*.json")): + n_files += 1 + episode = prompt_path.stem + caption_json_path = captions_dir / episode / "caption.json" + if not caption_json_path.exists(): + print(f" MISS {prompt_dir.name}/{episode}: no caption.json at {caption_json_path}") + n_missing_caption += 1 + continue + + try: + caption_json = json.loads(caption_json_path.read_text()) + except json.JSONDecodeError as e: + print(f" MISS {prompt_dir.name}/{episode}: caption.json invalid ({e})") + n_missing_caption += 1 + continue + + record = json.loads(prompt_path.read_text()) + record["prompt"] = caption_json_to_prompt(caption_json) + + if dry_run: + print(f" DRY {prompt_dir.name}/{episode}: would set prompt ({len(record['prompt'])} chars)") + else: + prompt_path.write_text(json.dumps(record, indent=4, ensure_ascii=False)) + print(f" OK {prompt_dir.name}/{episode}") + n_updated += 1 + + print( + f"\n{'Would update' if dry_run else 'Updated'} {n_updated}/{n_files} prompt files " + f"across {len(prompt_dirs)} dir(s); {n_missing_caption} missing caption.json" + ) + if n_updated == 0: + print("ERROR: No prompt files updated.", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + tyro.cli(main) diff --git a/cosmos_framework/scripts/inference_prompts_to_json_test.py b/cosmos_framework/scripts/inference_prompts_to_json_test.py new file mode 100644 index 0000000..4092663 --- /dev/null +++ b/cosmos_framework/scripts/inference_prompts_to_json_test.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 +"""Tests for inference_prompts_to_json (dense prompt -> structured-JSON prompt).""" + +import json + +import pytest + +from cosmos_framework.scripts import inference_prompts_to_json as mod + + +def _build_val(tmp_path, variants=("inference_prompt", "inference_prompt_i2v")): + val = tmp_path / "val" + cap = val / "captions" / "ep0" + cap.mkdir(parents=True) + caption_json = {"background_setting": "kitchen", "temporal_caption": "An arm.", "fps": 5} + (cap / "caption.json").write_text(json.dumps(caption_json)) + for v in variants: + d = val / v + d.mkdir(parents=True) + rec = {"name": f"{v}/ep0", "prompt": "OLD DENSE PROMPT", "resolution": "256", "fps": 5} + if v != "inference_prompt": + rec["vision_path"] = "../images/ep0.jpg" + (d / "ep0.json").write_text(json.dumps(rec)) + return val, caption_json + + +def test_replaces_prompt_with_structured_json_preserving_fields(tmp_path): + val, caption_json = _build_val(tmp_path) + mod.main(val_dir=val) + + for v in ("inference_prompt", "inference_prompt_i2v"): + rec = json.loads((val / v / "ep0.json").read_text()) + assert json.loads(rec["prompt"]) == caption_json # prompt is now the serialized JSON + assert rec["name"] == f"{v}/ep0" # preserved + assert rec["resolution"] == "256" and rec["fps"] == 5 # preserved + # i2v keeps its vision_path + assert json.loads((val / "inference_prompt_i2v" / "ep0.json").read_text())["vision_path"] == "../images/ep0.jpg" + + +def test_dry_run_does_not_modify(tmp_path): + val, _ = _build_val(tmp_path, variants=("inference_prompt",)) + before = (val / "inference_prompt" / "ep0.json").read_text() + mod.main(val_dir=val, dry_run=True) + assert (val / "inference_prompt" / "ep0.json").read_text() == before + + +def test_missing_caption_json_is_skipped(tmp_path): + val, _ = _build_val(tmp_path, variants=("inference_prompt",)) + # Add a second prompt with no matching caption.json. + rec = {"name": "inference_prompt/ep_missing", "prompt": "DENSE"} + (val / "inference_prompt" / "ep_missing.json").write_text(json.dumps(rec)) + mod.main(val_dir=val) + # ep_missing is untouched (still dense), ep0 is updated. + assert json.loads((val / "inference_prompt" / "ep_missing.json").read_text())["prompt"] == "DENSE" + assert (val / "inference_prompt" / "ep0.json").read_text() != json.dumps({"prompt": "DENSE"}) + + +def test_errors_when_no_prompt_dirs(tmp_path): + (tmp_path / "val").mkdir() + with pytest.raises(SystemExit): + mod.main(val_dir=tmp_path / "val") diff --git a/cosmos_framework/scripts/video_metadata.py b/cosmos_framework/scripts/video_metadata.py new file mode 100644 index 0000000..0f34d7b --- /dev/null +++ b/cosmos_framework/scripts/video_metadata.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 +"""Shared ffprobe-based video metadata helper for the captioning / dataset scripts. + +Returns ``fps``, ``duration`` (seconds), ``width``, ``height`` and +``total_frames`` for a video file. Used by both ``caption_from_video.py`` +(to fill the structured caption's media fields) and ``captions_to_sft_jsonl.py`` +(to build SFT JSONL rows), so the two stay consistent. +""" + +import json +import subprocess +from pathlib import Path + +_VIDEO_EXTENSIONS = (".mp4", ".mov", ".avi", ".mkv", ".webm") + + +def probe_video_metadata(video_path: str | Path) -> dict: + """Return ``{fps, duration, width, height, total_frames}`` via ffprobe. + + Raises: + RuntimeError: if ffprobe fails or the file has no video stream. + """ + video_path = str(video_path) + cmd = [ + "ffprobe", + "-v", + "quiet", + "-print_format", + "json", + "-show_streams", + "-show_format", + video_path, + ] + result = subprocess.run(cmd, stdin=subprocess.DEVNULL, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError(f"ffprobe failed for {video_path}: {result.stderr}") + data = json.loads(result.stdout) + + video_stream = next( + (s for s in data.get("streams", []) if s.get("codec_type") == "video"), + None, + ) + if video_stream is None: + raise RuntimeError(f"No video stream found in {video_path}") + + fps_str = video_stream.get("avg_frame_rate") or video_stream.get("r_frame_rate") or "30/1" + fps_num, fps_den = (fps_str.split("/") + ["1"])[:2] + fps_den_f = float(fps_den) or 1.0 + fps = float(fps_num) / fps_den_f + + # Duration: prefer the container format duration, fall back to the stream's. + duration = float(data.get("format", {}).get("duration") or video_stream.get("duration") or 0.0) + + width = int(video_stream["width"]) + height = int(video_stream["height"]) + + # nb_frames may be absent; fall back to duration * fps. + total_frames = int(video_stream.get("nb_frames") or round(duration * fps)) + + return { + "fps": fps, + "duration": duration, + "width": width, + "height": height, + "total_frames": total_frames, + } diff --git a/docs/dataset_jsonl.md b/docs/dataset_jsonl.md index 632fd9c..1ee3f09 100644 --- a/docs/dataset_jsonl.md +++ b/docs/dataset_jsonl.md @@ -11,7 +11,7 @@ Prerequisites: Run inference on a single sample: ```shell -export DATASET_PATH=$(uvx hf@latest download --repo-type dataset nvidia/bridge-v2-subset-synthetic-captions --revision 46468e12ac0dd36901e9e3240d4fc7620942b5d7 --quiet)/sft_dataset_bridge +export DATASET_PATH=$(uvx hf@latest download --repo-type dataset nvidia/BridgeData2-Subset-Synthetic-Captions --revision 40d018ac1c1a2a4b9734f17fdb21f3d933c49a01 --quiet)/sft_dataset_bridge torchrun --nproc-per-node=8 -m cosmos_framework.scripts.inference \ --parallelism-preset=latency \ @@ -55,7 +55,20 @@ Each example below uses the following layout: ## Format -Example sample: +Each `t2w_window` may carry **two** caption representations: + +- **`caption_json`** — the canonical structured-JSON caption (an object). The SFT loader + prefers this and trains on it by default, serialising it to the exact JSON string the + model consumes at inference. The dense narrative is embedded inside it as + `temporal_caption`, and the clip's media fields (`resolution`, `aspect_ratio`, + `duration`, `fps`) describe the source clip. +- **`caption`** — the dense narrative string, kept as the **backup** the loader falls + back to when `caption_json` is absent. + +This keeps the post-training example aligned with inference, which also uses the +structured-JSON prompt format (see [Inference](#inference)). + +Example sample (the structured object is abbreviated for readability): ```json { @@ -69,15 +82,31 @@ Example sample: "start_frame": 0, "end_frame": 86, "temporal_interval": 1, - "caption": "A black robotic arm, featuring articulated joints and a metallic finish, extends over a white tray placed on a wooden table, manipulating small black objects that resemble beads or marbles. The arm moves with precision, grasping clusters of these objects, lifting them, and relocating them across the tray\u2019s surface in a methodical manner, often shifting them from one side to another. The background reveals an indoor workspace with visible equipment, illuminated by bright, even lighting that casts minimal shadows, emphasizing the technical nature of the scene. The camera remains static throughout, offering a medium shot that centers on the robotic arm and tray, with a slightly angled top-down perspective that highlights the contrast between the black objects, white tray, and wooden table. The robotic arm\u2019s movements are continuous and deliberate, showcasing its ability to handle and reposition the objects with accuracy, while the scene maintains a minimalist and functional aesthetic throughout." + "caption_json": { + "subjects": [ + {"description": "A black robotic arm with articulated joints and a metallic finish", "action": "grasps and relocates small black objects across a white tray"} + ], + "background_setting": "An indoor workspace with visible equipment on a wooden table", + "cinematography": {"camera_motion": "static", "framing": "medium shot", "camera_angle": "slightly angled top-down"}, + "actions": [{"time": "0:00-0:17", "description": "the arm repeatedly lifts and repositions clusters of objects"}], + "temporal_caption": "A black robotic arm, featuring articulated joints and a metallic finish, extends over a white tray ... maintaining a minimalist and functional aesthetic throughout.", + "resolution": {"H": 256, "W": 256}, + "aspect_ratio": "1,1", + "duration": "17s", + "fps": 5 + }, + "caption": "A black robotic arm, featuring articulated joints and a metallic finish, extends over a white tray placed on a wooden table, manipulating small black objects that resemble beads or marbles. The arm moves with precision, grasping clusters of these objects, lifting them, and relocating them across the tray’s surface in a methodical manner, often shifting them from one side to another. The background reveals an indoor workspace with visible equipment, illuminated by bright, even lighting that casts minimal shadows, emphasizing the technical nature of the scene. The camera remains static throughout, offering a medium shot that centers on the robotic arm and tray, with a slightly angled top-down perspective that highlights the contrast between the black objects, white tray, and wooden table. The robotic arm’s movements are continuous and deliberate, showcasing its ability to handle and reposition the objects with accuracy, while the scene maintains a minimalist and functional aesthetic throughout." } ] } ``` +> Older datasets that contain only the dense `caption` field still work unchanged — the +> loader simply falls back to it. + ## Video Captioning -If you have video sources and would like to synthesize caption annotations to build video–text pairs for training, follow this section for data preprocessing. The script sends each video directly to a Reasoner (vision-language model), which analyzes the visual content and produces a dense narrative caption following a two-phase process (scene analysis → narrative rewrite) — the same format expected by the Cosmos3 training pipeline. +If you have video sources and would like to synthesize caption annotations to build video–text pairs for training, follow this section for data preprocessing. The script sends each video directly to a Reasoner (vision-language model), which analyzes the visual content via a two-phase process (Phase 1: structured-JSON scene analysis → Phase 2: dense narrative rewrite) and saves **both** outputs: a `caption.json` (the canonical structured caption that the Cosmos3 training pipeline and inference consume, with the dense narrative embedded as `temporal_caption`) and a `caption.txt` (the dense narrative on its own). The captioning prompt template is available at [`cosmos_framework/inference/defaults/video_captioner.txt`](../cosmos_framework/inference/defaults/video_captioner.txt). @@ -128,7 +157,7 @@ Options: | `--prompt_template_path` | built-in | Path to a custom prompt template | | `--debug` | `False` | Save raw API responses | -Each video produces an output directory containing `caption.txt` (the plain-text caption) and `sample_args.json` (metadata). +Each video produces an output directory containing `caption.json` (the canonical structured caption), `caption.txt` (the dense narrative), and `sample_args.json` (metadata). ### Create Dataset @@ -149,4 +178,15 @@ python -m cosmos_framework.scripts.captions_to_sft_jsonl \ -o outputs/sft_dataset/train/video_dataset_file.jsonl ``` -It will create a dataset JSONL file containing captions and their corresponding paths to video files. +Each JSONL line contains both `caption_json` (structured, preferred for training) and `caption` (dense, backup) for every window, plus the corresponding video path. The converter mirrors the loader's silent filters (clips longer than 61 s, and windows shorter than `max(61, --num-video-frames)` frames) so the kept count matches what training will actually consume — pass `--num-video-frames` to match your recipe (the example recipe uses `-1`, i.e. all frames, so the default keeps short example clips). A sibling `.summary.json` records the kept count and per-reason drop counts. + +#### Align the inference prompts + +To make the validation inference prompts use the **same** structured-JSON format as training, rewrite each `val/inference_prompt{,_i2v,_v2v}/.json` file's `prompt` field with the clip's structured caption: + +```shell +python -m cosmos_framework.scripts.inference_prompts_to_json \ + --val-dir outputs/sft_dataset/val +``` + +This reads `val/captions//caption.json` and replaces the (dense) `prompt` with the serialized structured JSON, preserving `resolution`, `aspect_ratio`, `num_frames`, `fps`, and `vision_path`. Pass `--dry-run` to preview. diff --git a/docs/training.md b/docs/training.md index 951fb2a..bf9bf90 100644 --- a/docs/training.md +++ b/docs/training.md @@ -41,7 +41,7 @@ Select one of the following recipes:
Vision SFT (Cosmos3-Nano) -T2V/I2V/V2V SFT on [nvidia/bridge-v2-subset-synthetic-captions](https://huggingface.co/datasets/nvidia/bridge-v2-subset-synthetic-captions/tree/main). `$DATASET_PATH` should be the directory containing `train/video_dataset_file.jsonl`. +T2V/I2V/V2V SFT on [nvidia/BridgeData2-Subset-Synthetic-Captions](https://huggingface.co/datasets/nvidia/BridgeData2-Subset-Synthetic-Captions/tree/main). `$DATASET_PATH` should be the directory containing `train/video_dataset_file.jsonl`. Each clip carries a structured-JSON caption (`caption_json`) — the model's native prompt format — which the SFT loader trains on by default (the dense narrative is kept as a backup), so training stays aligned with [Inference](./dataset_jsonl.md#inference); see [JSONL Dataset → Format](./dataset_jsonl.md#format). Launch shell: `examples/launch_sft_vision_nano.sh` @@ -49,9 +49,9 @@ Launch shell: `examples/launch_sft_vision_nano.sh` BASE_CHECKPOINT_NAME=Cosmos3-Nano # Defaults match the launcher (see Step 3 → Option A to override). -uvx hf@latest download --repo-type dataset nvidia/bridge-v2-subset-synthetic-captions \ - --revision 46468e12ac0dd36901e9e3240d4fc7620942b5d7 \ - --local-dir examples/data/bridge-v2-subset-synthetic-captions --quiet +uvx hf@latest download --repo-type dataset nvidia/BridgeData2-Subset-Synthetic-Captions \ + --revision 40d018ac1c1a2a4b9734f17fdb21f3d933c49a01 \ + --local-dir examples/data/BridgeData2-Subset-Synthetic-Captions --quiet uvx hf@latest download Wan-AI/Wan2.2-TI2V-5B Wan2.2_VAE.pth \ --local-dir examples/checkpoints/wan22_vae --quiet ``` @@ -68,9 +68,9 @@ Launch shell: `examples/launch_sft_vision_super.sh` BASE_CHECKPOINT_NAME=Cosmos3-Super # Defaults match the launcher (see Step 3 → Option A to override). -uvx hf@latest download --repo-type dataset nvidia/bridge-v2-subset-synthetic-captions \ - --revision 46468e12ac0dd36901e9e3240d4fc7620942b5d7 \ - --local-dir examples/data/bridge-v2-subset-synthetic-captions --quiet +uvx hf@latest download --repo-type dataset nvidia/BridgeData2-Subset-Synthetic-Captions \ + --revision 40d018ac1c1a2a4b9734f17fdb21f3d933c49a01 \ + --local-dir examples/data/BridgeData2-Subset-Synthetic-Captions --quiet uvx hf@latest download Wan-AI/Wan2.2-TI2V-5B Wan2.2_VAE.pth \ --local-dir examples/checkpoints/wan22_vae --quiet ``` @@ -154,12 +154,12 @@ bash examples/launch_sft_vision_nano.sh Each launcher's default paths come from the `DATASET_PATH` + `BASE_CHECKPOINT_PATH` defaults declared at the top of its `.sh` (each uses `: "${VAR:=…}"` so any value you `export` in the shell before launching wins over the default): -| Launch shell | Post-Training Task | Default $DATASET_PATH (under examples/data/) | Default $BASE_CHECKPOINT_PATH (under examples/checkpoints/) | -| ------------------------------------- | ------------------ | -------------------------------------------------------- | ----------------------------------------------------------- | -| `launch_sft_vision_nano.sh` | Generator SFT | `bridge-v2-subset-synthetic-captions/sft_dataset_bridge` | `Cosmos3-Nano` | -| `launch_sft_vision_super.sh` | Generator SFT | `bridge-v2-subset-synthetic-captions/sft_dataset_bridge` | `Cosmos3-Super` | -| `launch_sft_llava_ov.sh` | Reasoner SFT | (none; dataset streams from HF Hub) | (none; backbone fetched at startup) | -| `launch_sft_videophy2_nano.sh` | Reasoner SFT | (none; set `VIDEOPHYSICS_ROOT` env) | (none; set `VLM_SAFETENSORS_PATH` env) | +| Launch shell | Post-Training Task | Default $DATASET_PATH (under examples/data/) | Default $BASE_CHECKPOINT_PATH (under examples/checkpoints/) | +| ------------------------------ | ------------------ | ---------------------------------------------------------- | ----------------------------------------------------------- | +| `launch_sft_vision_nano.sh` | Generator SFT | `BridgeData2-Subset-Synthetic-Captions/sft_dataset_bridge` | `Cosmos3-Nano` | +| `launch_sft_vision_super.sh` | Generator SFT | `BridgeData2-Subset-Synthetic-Captions/sft_dataset_bridge` | `Cosmos3-Super` | +| `launch_sft_llava_ov.sh` | Reasoner SFT | (none; dataset streams from HF Hub) | (none; backbone fetched at startup) | +| `launch_sft_videophy2_nano.sh` | Reasoner SFT | (none; set `VIDEOPHYSICS_ROOT` env) | (none; set `VLM_SAFETENSORS_PATH` env) | `WAN_VAE_PATH` defaults to `examples/checkpoints/wan22_vae/Wan2.2_VAE.pth` for every non-reasoner recipe. @@ -178,7 +178,7 @@ If you'd rather put data or checkpoints on a different filesystem (e.g. a faster ```shell # Example: data on /scratch, base DCP on /nfs/ckpts. -export DATASET_PATH=/scratch/bridge-v2-subset-synthetic-captions/sft_dataset_bridge +export DATASET_PATH=/scratch/BridgeData2-Subset-Synthetic-Captions/sft_dataset_bridge export BASE_CHECKPOINT_PATH=/nfs/ckpts/Cosmos3-Nano export WAN_VAE_PATH=/nfs/ckpts/wan22_vae/Wan2.2_VAE.pth bash examples/launch_sft_vision_nano.sh @@ -207,7 +207,7 @@ Run from the repo root (the directory containing `pyproject.toml` and `examples/ TOML_FILE="examples/toml/sft_config/vision_sft_nano.toml" # Match the launcher's defaults — or substitute your own paths. -export DATASET_PATH="$PWD/examples/data/bridge-v2-subset-synthetic-captions/sft_dataset_bridge" +export DATASET_PATH="$PWD/examples/data/BridgeData2-Subset-Synthetic-Captions/sft_dataset_bridge" export BASE_CHECKPOINT_PATH="$PWD/examples/checkpoints/Cosmos3-Nano" export WAN_VAE_PATH="$PWD/examples/checkpoints/wan22_vae/Wan2.2_VAE.pth" diff --git a/examples/launch_sft_vision_nano.sh b/examples/launch_sft_vision_nano.sh index 863ab3e..a6f725f 100755 --- a/examples/launch_sft_vision_nano.sh +++ b/examples/launch_sft_vision_nano.sh @@ -8,7 +8,7 @@ # # Optional env vars (defaults below point under examples/; override to put # data or checkpoints on a different filesystem): -# DATASET_PATH default: examples/data/bridge-v2-subset-synthetic-captions/sft_dataset_bridge +# DATASET_PATH default: examples/data/BridgeData2-Subset-Synthetic-Captions/sft_dataset_bridge # (must contain train/video_dataset_file.jsonl) # BASE_CHECKPOINT_PATH default: examples/checkpoints/Cosmos3-Nano # WAN_VAE_PATH default: examples/checkpoints/wan22_vae/Wan2.2_VAE.pth @@ -19,7 +19,7 @@ # bash examples/launch_sft_vision_nano.sh TOML_FILE="examples/toml/sft_config/vision_sft_nano.toml" -: "${DATASET_PATH:=examples/data/bridge-v2-subset-synthetic-captions/sft_dataset_bridge}" +: "${DATASET_PATH:=examples/data/BridgeData2-Subset-Synthetic-Captions/sft_dataset_bridge}" : "${BASE_CHECKPOINT_PATH:=examples/checkpoints/Cosmos3-Nano}" EXTRA_DATASET_CHECK='[[ -f "$DATASET_PATH/train/video_dataset_file.jsonl" ]] || { echo "ERROR: missing $DATASET_PATH/train/video_dataset_file.jsonl" >&2; exit 1; }' diff --git a/examples/launch_sft_vision_super.sh b/examples/launch_sft_vision_super.sh index 2909a32..f3caa5f 100755 --- a/examples/launch_sft_vision_super.sh +++ b/examples/launch_sft_vision_super.sh @@ -8,7 +8,7 @@ # # Optional env vars (defaults below point under examples/; override to put # data or checkpoints on a different filesystem): -# DATASET_PATH default: examples/data/bridge-v2-subset-synthetic-captions/sft_dataset_bridge +# DATASET_PATH default: examples/data/BridgeData2-Subset-Synthetic-Captions/sft_dataset_bridge # (must contain train/video_dataset_file.jsonl) # BASE_CHECKPOINT_PATH default: examples/checkpoints/Cosmos3-Super # WAN_VAE_PATH default: examples/checkpoints/wan22_vae/Wan2.2_VAE.pth @@ -19,7 +19,7 @@ # bash examples/launch_sft_vision_super.sh TOML_FILE="examples/toml/sft_config/vision_sft_super.toml" -: "${DATASET_PATH:=examples/data/bridge-v2-subset-synthetic-captions/sft_dataset_bridge}" +: "${DATASET_PATH:=examples/data/BridgeData2-Subset-Synthetic-Captions/sft_dataset_bridge}" : "${BASE_CHECKPOINT_PATH:=examples/checkpoints/Cosmos3-Super}" EXTRA_DATASET_CHECK='[[ -f "$DATASET_PATH/train/video_dataset_file.jsonl" ]] || { echo "ERROR: missing $DATASET_PATH/train/video_dataset_file.jsonl" >&2; exit 1; }' diff --git a/tests/_stage_h100_inputs.sh b/tests/_stage_h100_inputs.sh index a2ddfc4..d59c34e 100755 --- a/tests/_stage_h100_inputs.sh +++ b/tests/_stage_h100_inputs.sh @@ -51,12 +51,12 @@ fi echo ">>> $(date '+%H:%M:%S') transformers=$(python -c 'import transformers; print(transformers.__version__)')" # ---------------------------------------------------------------------------- -# 1. Mixed-modality SFT dataset (bridge-v2-subset-synthetic-captions). +# 1. Mixed-modality SFT dataset (BridgeData2-Subset-Synthetic-Captions). # ---------------------------------------------------------------------------- -echo ">>> $(date '+%H:%M:%S') downloading bridge-v2-subset-synthetic-captions ..." +echo ">>> $(date '+%H:%M:%S') downloading BridgeData2-Subset-Synthetic-Captions ..." BRIDGE_ROOT=$(uvx hf@latest download --repo-type dataset \ - nvidia/bridge-v2-subset-synthetic-captions \ - --revision 46468e12ac0dd36901e9e3240d4fc7620942b5d7 \ + nvidia/BridgeData2-Subset-Synthetic-Captions \ + --revision 40d018ac1c1a2a4b9734f17fdb21f3d933c49a01 \ --quiet) DATASET_PATH="$BRIDGE_ROOT/sft_dataset_bridge" echo "DATASET_PATH=$DATASET_PATH"