Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions cosmos_framework/configs/base/defaults/open_source_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,20 @@

from hydra.core.config_store import ConfigStore

from cosmos_framework.configs.base.defaults.vlm import create_qwen2_tokenizer_with_download
from cosmos_framework.data.vfm.joint_dataloader import (
PackingDataLoader,
RankPartitionedDataLoader,
)
from cosmos_framework.data.vfm.local_datasets.sft_dataset import get_sft_dataset
from cosmos_framework.utils.lazy_config import LazyCall as L
from cosmos_framework.configs.base.defaults.vlm import create_qwen2_tokenizer_with_download


# ---------------------------------------------------------------------------
# Inner: SFT video dataset (matches the inline ``get_sft_dataset`` call in the
# reference YAML).
# ---------------------------------------------------------------------------


def get_sft_video_dataset(
*,
jsonl_paths: list[str],
Expand All @@ -55,6 +55,11 @@ def get_sft_video_dataset(
append_duration_fps_timestamps: bool = True,
append_resolution_info: bool = True,
use_system_prompt: bool = False,
# Structured-JSON captions are far longer than dense prose; raise the token
# budget so the loader does not truncate them mid-JSON (see sft_dataset.py
# _MAX_NUM_TOKENS). 2048 covers the example dataset (measured max ~1790 Qwen
# tokens) with margin; keep consistent with the inference prompt budget.
max_num_tokens: int = 2048,
caption_suffix: str = "",
cfg_dropout_rate: float = 0.1,
cfg_dropout_keep_metadata: bool = False,
Expand Down Expand Up @@ -85,6 +90,7 @@ def get_sft_video_dataset(
append_duration_fps_timestamps=append_duration_fps_timestamps,
append_resolution_info=append_resolution_info,
use_system_prompt=use_system_prompt,
max_num_tokens=max_num_tokens,
caption_suffix=caption_suffix,
cfg_dropout_rate=cfg_dropout_rate,
cfg_dropout_keep_metadata=cfg_dropout_keep_metadata,
Expand All @@ -103,6 +109,7 @@ def get_sft_video_dataset(
# pipeline. This is the registered config_store node.
# ---------------------------------------------------------------------------


def get_open_source_sft_dataloader(
*,
jsonl_paths: list[str] | None = None,
Expand Down Expand Up @@ -165,6 +172,7 @@ def get_open_source_sft_dataloader(
# ConfigStore registration.
# ---------------------------------------------------------------------------


def register_open_source_dataloaders() -> None:
"""Register named dataloader configs under the ``data_train`` Hydra group.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,14 @@

from hydra.core.config_store import ConfigStore

from cosmos_framework.utils.lazy_config import LazyCall as L
from cosmos_framework.utils.lazy_config import LazyDict

from cosmos_framework.configs.base.experiment.sft.models.nano_model_config import NANO_MODEL_CONFIG
from cosmos_framework.data.vfm.joint_dataloader import (
PackingDataLoader,
RankPartitionedDataLoader,
)
from cosmos_framework.data.vfm.local_datasets.sft_dataset import get_sft_dataset
from cosmos_framework.utils.lazy_config import LazyCall as L
from cosmos_framework.utils.lazy_config import LazyDict

cs = ConfigStore.instance()

Expand Down Expand Up @@ -237,6 +236,10 @@
dataset=L(get_sft_dataset)(
append_duration_fps_timestamps=True,
append_resolution_info=True,
# Structured-JSON captions are long; raise the token budget so
# the loader does not truncate them (see sft_dataset.py
# _MAX_NUM_TOKENS). 2048 covers the example set (measured max ~1790).
max_num_tokens=2048,
caption_suffix="",
cfg_dropout_keep_metadata=False,
cfg_dropout_rate=0.1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,14 @@

from hydra.core.config_store import ConfigStore

from cosmos_framework.utils.lazy_config import LazyCall as L
from cosmos_framework.utils.lazy_config import LazyDict

from cosmos_framework.configs.base.experiment.sft.models.super_model_config import SUPER_MODEL_CONFIG
from cosmos_framework.data.vfm.joint_dataloader import (
PackingDataLoader,
RankPartitionedDataLoader,
)
from cosmos_framework.data.vfm.local_datasets.sft_dataset import get_sft_dataset
from cosmos_framework.utils.lazy_config import LazyCall as L
from cosmos_framework.utils.lazy_config import LazyDict

cs = ConfigStore.instance()

Expand Down Expand Up @@ -255,6 +254,10 @@
dataset=L(get_sft_dataset)(
append_duration_fps_timestamps=True,
append_resolution_info=True,
# Structured-JSON captions are long; raise the token budget so
# the loader does not truncate them (see sft_dataset.py
# _MAX_NUM_TOKENS). 2048 covers the example set (measured max ~1790).
max_num_tokens=2048,
caption_suffix="",
cfg_dropout_keep_metadata=False,
cfg_dropout_rate=0.1,
Expand Down
72 changes: 52 additions & 20 deletions cosmos_framework/data/vfm/local_datasets/sft_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
import numpy as np
import torch

from cosmos_framework.utils.flags import INTERNAL
from cosmos_framework.utils.lazy_config import instantiate as lazy_instantiate
from cosmos_framework.utils import log
from cosmos_framework.data.vfm.local_datasets.helper import (
client_config,
download_from_s3,
Expand All @@ -29,7 +26,11 @@
)
from cosmos_framework.data.vfm.sequence_packing import SequencePlan, add_special_tokens
from cosmos_framework.data.vfm.utils import VIDEO_RES_SIZE_INFO
from cosmos_framework.inference.structured_caption import CAPTION_JSON_KEY, caption_json_to_prompt
from cosmos_framework.model.vfm.vlm.qwen3_vl.utils import tokenize_caption
from cosmos_framework.utils import log
from cosmos_framework.utils.flags import INTERNAL
from cosmos_framework.utils.lazy_config import instantiate as lazy_instantiate

_MAX_NUM_TOKENS = 1024
_DURATION_TEMPLATE = "The video is {duration:.1f} seconds long and is of {fps:.0f} FPS."
Expand Down Expand Up @@ -61,6 +62,38 @@
CAPTION_WEIGHTS = list(CAPTION_TYPES_AND_WEIGHTS.values())


def _select_caption(t2w_window: dict) -> tuple[str, str, bool] | None:
"""Pick a window's caption: ``(caption_key, caption_text, used_structured_json)``.

Priority: ``caption_json`` (structured — the default training target) →
``qwen3_32b_rewrite-dense`` → ``caption`` (dense backup) → a weighted-random
``CAPTION_TYPES`` style. A structured-JSON caption (a dict, or a value under
``caption_json``) is serialised verbatim so the training prompt is byte-identical
to the inference prompt; it must NOT receive the dense prose period-normalisation,
which would append a stray ``.`` after the closing ``}``. Returns ``None`` when
the window has no known caption key.
"""
if CAPTION_JSON_KEY in t2w_window:
caption_key = CAPTION_JSON_KEY
elif "qwen3_32b_rewrite-dense" in t2w_window:
caption_key = "qwen3_32b_rewrite-dense"
elif "caption" in t2w_window:
caption_key = "caption"
else:
available_types = [ct for ct in CAPTION_TYPES if ct in t2w_window]
if not available_types:
return None
available_weights = [CAPTION_TYPES_AND_WEIGHTS[ct] for ct in available_types]
caption_key = random.choices(available_types, weights=available_weights, k=1)[0]

raw = t2w_window[caption_key]
if isinstance(raw, dict):
return caption_key, caption_json_to_prompt(raw), True
if caption_key == CAPTION_JSON_KEY:
return caption_key, str(raw).strip(), True
return caption_key, raw.strip().rstrip(".") + ".", False


class SFTDataset(torch.utils.data.IterableDataset):
"""Dataset for loading SFT video clips with captions from JSONL metadata on S3."""

Expand All @@ -75,6 +108,7 @@ def __init__(
tokenizer_config: Optional[Any] = None,
cfg_dropout_rate: float = 0.0,
use_system_prompt: bool = False,
max_num_tokens: int = _MAX_NUM_TOKENS,
append_duration_fps_timestamps: bool = True,
append_resolution_info: bool = True,
cfg_dropout_keep_metadata: bool = False,
Expand All @@ -100,6 +134,7 @@ def __init__(
self.tokenizer_config = tokenizer_config
self.cfg_dropout_rate = cfg_dropout_rate
self.use_system_prompt = use_system_prompt
self.max_num_tokens = max_num_tokens
self.append_duration_fps_timestamps = append_duration_fps_timestamps
self.append_resolution_info = append_resolution_info
self.cfg_dropout_keep_metadata = cfg_dropout_keep_metadata
Expand Down Expand Up @@ -135,9 +170,9 @@ def _tokenize_caption(self, caption: str) -> tuple[list[int], str]:
is_video=True,
use_system_prompt=self.use_system_prompt,
)
if len(text_ids) > _MAX_NUM_TOKENS:
log.warning(f"Text ids are too long, truncating: {len(text_ids)} > {_MAX_NUM_TOKENS}")
text_ids = text_ids[:_MAX_NUM_TOKENS]
if len(text_ids) > self.max_num_tokens:
log.warning(f"Text ids are too long, truncating: {len(text_ids)} > {self.max_num_tokens}")
text_ids = text_ids[: self.max_num_tokens]
return text_ids, caption

def process_one_sample(self, metadata: dict) -> dict | None:
Expand Down Expand Up @@ -248,30 +283,22 @@ def process_one_sample(self, metadata: dict) -> dict | None:
# image_size: [target_h, target_w, orig_h, orig_w] in pixel space, for the model to crop the video
image_size = torch.tensor([target_h, target_w, target_h, target_w], dtype=torch.float32)

available_types = [ct for ct in CAPTION_TYPES if ct in t2w_window]
if "qwen3_32b_rewrite-dense" in t2w_window:
caption_key = "qwen3_32b_rewrite-dense"
elif "caption" in t2w_window:
caption_key = "caption"
elif available_types:
available_weights = [CAPTION_TYPES_AND_WEIGHTS[ct] for ct in available_types]
caption_key = random.choices(available_types, weights=available_weights, k=1)[0]
else:
selected = _select_caption(t2w_window)
if selected is None:
log.warning(
f"No known caption key found in t2w_window for sample {metadata['uuid']}. "
f"Keys: {list(t2w_window)}. Skipping sample."
)
return None
caption = t2w_window[caption_key]
caption = caption.strip().rstrip(".") + "."
caption_key, caption, used_structured_json = selected

num_decoded_frames = video.shape[1]
cond_fps = fps if self.conditioning_fps < 0 else self.conditioning_fps
if self.conditioning_fps_noise_std > 0:
noise_factor = np.exp(np.random.randn() * self.conditioning_fps_noise_std)
cond_fps = cond_fps * noise_factor

if self.caption_suffix:
if self.caption_suffix and not used_structured_json:
caption = (caption + " " + self.caption_suffix).strip()

# CFG dropout: when cfg_dropout_keep_metadata is True, dropout fires
Expand All @@ -281,11 +308,14 @@ def process_one_sample(self, metadata: dict) -> dict | None:
if random.random() < self.cfg_dropout_rate:
caption = ""

if self.append_duration_fps_timestamps:
# Structured-JSON captions already carry duration/fps/resolution inside the
# JSON, so skip the natural-language metadata suffixes for them. This also
# makes the training prompt byte-match the inference prompt.
if self.append_duration_fps_timestamps and not used_structured_json:
duration = num_decoded_frames / cond_fps
suffix = _DURATION_TEMPLATE.format(duration=duration, fps=cond_fps)
caption = caption + " " + suffix
if self.append_resolution_info:
if self.append_resolution_info and not used_structured_json:
suffix = _RESOLUTION_TEMPLATE.format(height=target_h, width=target_w)
caption = caption + " " + suffix
caption = caption.strip()
Expand Down Expand Up @@ -552,6 +582,7 @@ def get_sft_dataset(
tokenizer_config: Optional[Any] = None,
cfg_dropout_rate: float = 0.1,
use_system_prompt: bool = False,
max_num_tokens: int = _MAX_NUM_TOKENS,
append_duration_fps_timestamps: bool = True,
append_resolution_info: bool = True,
cfg_dropout_keep_metadata: bool = False,
Expand Down Expand Up @@ -668,6 +699,7 @@ def get_sft_dataset(
tokenizer_config=tokenizer_config,
cfg_dropout_rate=cfg_dropout_rate,
use_system_prompt=use_system_prompt,
max_num_tokens=max_num_tokens,
append_duration_fps_timestamps=append_duration_fps_timestamps,
append_resolution_info=append_resolution_info,
cfg_dropout_keep_metadata=cfg_dropout_keep_metadata,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: OpenMDW-1.1
"""Tests for the SFT loader's caption selection / JSON-vs-dense normalization."""

import json

from cosmos_framework.data.vfm.local_datasets.sft_dataset import _select_caption
from cosmos_framework.inference.structured_caption import CAPTION_JSON_KEY


def test_caption_json_dict_serialized_verbatim_no_trailing_period():
cj = {"background_setting": "kitchen", "fps": 5}
key, text, used_json = _select_caption({CAPTION_JSON_KEY: cj})
assert key == CAPTION_JSON_KEY and used_json is True
assert not text.endswith(".") # MUST NOT append a stray '.' after '}'
assert text.endswith("}")
assert json.loads(text) == cj


def test_caption_json_priority_over_dense():
cj = {"background_setting": "x"}
key, text, used_json = _select_caption({CAPTION_JSON_KEY: cj, "caption": "dense backup"})
assert key == CAPTION_JSON_KEY and used_json is True


def test_caption_json_as_preserialized_string():
key, text, used_json = _select_caption({CAPTION_JSON_KEY: '{"a": 1} '})
assert key == CAPTION_JSON_KEY and used_json is True
assert text == '{"a": 1}' # stripped, no period


def test_dense_caption_gets_terminal_period():
key, text, used_json = _select_caption({"caption": "a robot arm moves"})
assert key == "caption" and used_json is False
assert text == "a robot arm moves."


def test_dense_caption_period_not_doubled():
_, text, _ = _select_caption({"caption": "ends with period."})
assert text == "ends with period."


def test_rewrite_dense_key_priority_over_generic_caption():
key, _, used_json = _select_caption({"qwen3_32b_rewrite-dense": "x", "caption": "y"})
assert key == "qwen3_32b_rewrite-dense" and used_json is False


def test_weighted_caption_types_fallback():
key, text, used_json = _select_caption({"qwen3_235b_dense": "some dense caption"})
assert key == "qwen3_235b_dense" and used_json is False
assert text.endswith(".")


def test_no_known_caption_key_returns_none():
assert _select_caption({"start_frame": 0, "end_frame": 84}) is None
4 changes: 2 additions & 2 deletions cosmos_framework/inference/_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,13 +327,13 @@ def __call__(self, runner: ScriptRunner, cfg: ScriptConfig) -> dict[str, str]:
base_checkpoint_name="Cosmos3-Nano",
config_file="cosmos3/configs/experiment/vision_sft_nano.yaml",
job_name="vision_sft_nano",
dataset_name="nvidia/bridge-v2-subset-synthetic-captions",
dataset_name="nvidia/BridgeData2-Subset-Synthetic-Captions",
),
"vision_super": SftGetEnv(
base_checkpoint_name="Cosmos3-Super",
config_file="cosmos3/configs/experiment/vision_sft_super.yaml",
job_name="vision_sft_super",
dataset_name="nvidia/bridge-v2-subset-synthetic-captions",
dataset_name="nvidia/BridgeData2-Subset-Synthetic-Captions",
),
}
_DEFAULT_SFT_NAME = "vision"
Expand Down
6 changes: 3 additions & 3 deletions cosmos_framework/inference/common/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,11 +299,11 @@ class DatasetConfig(pydantic.BaseModel):


DATASETS = {
"nvidia/bridge-v2-subset-synthetic-captions": DatasetConfig(
"nvidia/BridgeData2-Subset-Synthetic-Captions": DatasetConfig(
hf=CheckpointDirHf(
repository_type=RepositoryType.DATASET,
repository="nvidia/bridge-v2-subset-synthetic-captions",
revision="46468e12ac0dd36901e9e3240d4fc7620942b5d7",
repository="nvidia/BridgeData2-Subset-Synthetic-Captions",
revision="40d018ac1c1a2a4b9734f17fdb21f3d933c49a01",
subdirectory="sft_dataset_bridge",
),
),
Expand Down
Loading