Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 176 additions & 29 deletions cosmos_framework/inference/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,27 @@
from cosmos_framework.inference.common.inference import Inference


def _load_transfer_prompt_path(path: str | Path) -> str:
"""Load a transfer prompt from a ``.json`` or plain ``.txt`` file."""
resolved = Path(path)
text = resolved.read_text()
if resolved.suffix.lower() == ".json":
return json.dumps(json.loads(text))
return text.strip()


def _load_transfer_negative_prompt_file(path: str | Path) -> str:
"""Load a JSON negative caption file for transfer inference."""
candidate = Path(path)
if not candidate.is_file():
defaults_path = PACKAGE_DIR / "defaults" / candidate.name
if defaults_path.is_file():
candidate = defaults_path
else:
raise FileNotFoundError(f"Missing negative prompt file: {path} (also checked {defaults_path})")
return json.dumps(json.loads(candidate.read_text()))


@cache
def _load_modality_defaults(model_mode: str) -> dict[str, Any]:
default_file = PACKAGE_DIR / f"defaults/{model_mode}/sample_args.json"
Expand All @@ -59,12 +80,15 @@ def _load_modality_defaults(model_mode: str) -> dict[str, Any]:
Guidance = Annotated[float, pydantic.Field(ge=0, le=7)]
GuidanceInterval = tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]
PromptUpsamplerProbability = Annotated[float, pydantic.Field(ge=0, le=1)]
ControlGuidance = Annotated[float, pydantic.Field(ge=0, le=10)]


class SamplingArgs(ArgsBase):
num_steps: pydantic.PositiveInt
guidance: Guidance
guidance_interval: GuidanceInterval | None
control_guidance: ControlGuidance = 1.0
control_guidance_interval: GuidanceInterval | None = None
normalize_cfg: bool
shift: float
sigma_max: float
Expand All @@ -79,6 +103,13 @@ class SamplingOverrides(OverridesBase):
"""Guidance scale for the diffusion model."""
guidance_interval: Training[GuidanceInterval | None] = None
"""Guidance interval for the diffusion model."""
control_guidance: Training[ControlGuidance | None] = None
"""Control-CFG scale for transfer. The control map stays in the main forward; 1.0 disables
the extra comparison forward that drops control-map vision items. Values > 1.0 blend
velocities from with-maps vs without-maps forwards on the generated clip."""
control_guidance_interval: Training[GuidanceInterval | None] = None
"""Timestep interval [lo, hi] (0–1000) in which control-CFG is applied; None applies
on every step."""
normalize_cfg: Training[bool | None] = None
"""If True, normalize the CFG output."""
shift: Training[float | None] = None
Expand All @@ -102,6 +133,8 @@ def _build_sampling(self, model_config: "OmniMoTModelConfig", sample_meta: "Samp
self.sigma_max = 0.0
return
assert self.num_steps is not None
if self.control_guidance is None:
self.control_guidance = 1.0
if SMOKE:
self.num_steps = min(self.num_steps, 1)

Expand Down Expand Up @@ -300,7 +333,11 @@ def _build_text_data(self, model_config: "OmniMoTModelConfig", sample_meta: Samp
if self.prompt is not None:
pass
elif self.prompt_path is not None:
self.prompt = self.prompt_path.read_text().strip()
transfer_self = cast("_TransferDataBase", self)
if transfer_self.transfer_hints:
self.prompt = _load_transfer_prompt_path(self.prompt_path)
else:
self.prompt = self.prompt_path.read_text().strip()
else:
self.prompt = ""

Expand Down Expand Up @@ -612,6 +649,136 @@ def _build_reasoner_data(self, model_config: "OmniMoTModelConfig", sample_meta:
raise ValueError("Reasoner inference requires a non-empty 'prompt'.")


class _TransferDataBase:
@property
def transfer_hints(self) -> dict[TransferHintKey, TransferOverrides | TransferArgs]:
# Iteration order is `TransferHintKey` enum order, not JSON-key order — keep this
# deterministic so the model sees a stable [ctrl_1, ..., ctrl_N] sequence.
return {key: getattr(self, key.value) for key in TransferHintKey if getattr(self, key.value) is not None}


class TransferDataArgs(ArgsBase, _TransferDataBase):
edge: EdgeTransferArgs | None = None
blur: BlurTransferArgs | None = None
depth: TransferArgs | None = None
seg: TransferArgs | None = None
wsm: TransferArgs | None = None
negative_prompt_file: str | None = None
"""JSON negative caption file for transfer specs (absolute path or filename under ``defaults/``)."""
num_video_frames_per_chunk: pydantic.PositiveInt | None = None
num_conditional_frames: pydantic.NonNegativeInt | None = None
max_frames: pydantic.PositiveInt | None = None
show_control_condition: bool | None = None
show_input: bool | None = None
num_first_chunk_conditional_frames: pydantic.NonNegativeInt | None = None
share_vision_temporal_positions: bool | None = None


class TransferDataOverrides(OverridesBase, _TransferDataBase):
"""Transfer inference overrides — activated when at least one control hint is set."""

edge: EdgeTransferOverrides | None = None
blur: BlurTransferOverrides | None = None
depth: TransferOverrides | None = None
seg: TransferOverrides | None = None
wsm: TransferOverrides | None = None
negative_prompt_file: str | None = None
"""JSON negative caption file for transfer specs (absolute path or filename under ``defaults/``)."""
num_video_frames_per_chunk: pydantic.PositiveInt | None = None
num_conditional_frames: pydantic.NonNegativeInt | None = None
max_frames: pydantic.PositiveInt | None = None
show_control_condition: bool | None = None
show_input: bool | None = None
num_first_chunk_conditional_frames: pydantic.NonNegativeInt | None = None
share_vision_temporal_positions: bool | None = None

@pydantic.model_validator(mode="after")
def _validate_transfer_hints(self) -> Self:
hint_field_names = {k.value for k in TransferHintKey}
transfer_only = [
name
for name in TransferDataOverrides.__annotations__
if name in type(self).model_fields and name not in hint_field_names
]
if any(getattr(self, f) is not None for f in transfer_only) and not self.transfer_hints:
raise ValueError(
f"transfer inference requires at least one control hint ({', '.join(k.value for k in TransferHintKey)})"
)
return self

@override
def download(self, output_dir: Path):
super().download(output_dir)
for config in self.transfer_hints.values():
assert isinstance(config, TransferOverrides)
config.download(output_dir)

_TRANSFER_SAMPLE_DEFAULTS: ClassVar[dict[str, Any]] = {
"num_video_frames_per_chunk": 93,
"num_conditional_frames": 1,
"max_frames": 5000,
"show_control_condition": False,
"show_input": False,
"num_first_chunk_conditional_frames": 0,
"share_vision_temporal_positions": True,
}
_TRANSFER_HINT_DEFAULTS: ClassVar[dict[TransferHintKey, dict[str, Any]]] = {
TransferHintKey.EDGE: {"preset_edge_threshold": PresetEdgeThreshold.MEDIUM},
TransferHintKey.BLUR: {"preset_blur_strength": PresetBlurStrength.MEDIUM},
}
# Tuned guidance / control_guidance per transfer task. Applied when the input JSON omits
# these fields (generic video2video sampling defaults otherwise apply).
_TRANSFER_DEFAULTS: ClassVar[dict[TransferHintKey, dict[str, Any]]] = {
TransferHintKey.EDGE: {"guidance": 3.0, "control_guidance": 1.5, "shift": 10.0},
TransferHintKey.BLUR: {"guidance": 3.0, "control_guidance": 1.5, "shift": 10.0},
TransferHintKey.DEPTH: {"guidance": 3.0, "control_guidance": 1.5, "shift": 10.0},
TransferHintKey.SEG: {"guidance": 3.0, "control_guidance": 2.0, "shift": 10.0},
TransferHintKey.WSM: {
"guidance": 1.0,
"control_guidance": 3.0,
"shift": 10.0,
"num_frames": 101,
"fps": 10,
"num_video_frames_per_chunk": 101,
},
}

def _build_transfer_data(
self,
model_config: "OmniMoTModelConfig",
sample_meta: SampleMeta,
*,
user_fields: frozenset[str] | None = None,
):
self = cast("SampleDataOverrides", self)
hints = self.transfer_hints
if not hints:
return

if self.negative_prompt is None and self.negative_prompt_file is not None:
self.negative_prompt = _load_transfer_negative_prompt_file(self.negative_prompt_file)

for field, default in self._TRANSFER_SAMPLE_DEFAULTS.items():
if getattr(self, field) is None:
setattr(self, field, default)

for hint_key, config in hints.items():
for field, default in self._TRANSFER_HINT_DEFAULTS.get(hint_key, {}).items():
if getattr(config, field) is None:
setattr(config, field, default)

if self.vision_path is None and config.control_path is None:
raise ValueError(
f"transfer inference requires 'vision_path' or a pre-computed 'control_path' (hint: {hint_key})"
)

if len(hints) == 1:
hint_key = next(iter(hints))
for field, value in self._TRANSFER_DEFAULTS[hint_key].items():
if user_fields is None or field not in user_fields:
setattr(self, field, value)


class _SampleDataBase:
@property
def resolved_model_mode(self) -> ModelMode:
Expand Down Expand Up @@ -641,6 +808,7 @@ class SampleDataArgs(
SoundDataArgs,
ActionDataArgs,
ReasonerDataArgs,
TransferDataArgs,
):
model_mode: ModelMode

Expand All @@ -652,6 +820,7 @@ class SampleDataOverrides(
SoundDataOverrides,
ActionDataOverrides,
ReasonerDataOverrides,
TransferDataOverrides,
):
"""Sample data arguments for 'OmniMoTModel.generate_samples'."""

Expand Down Expand Up @@ -790,7 +959,8 @@ def build_sample(self, *, model_config: Any) -> OmniSampleArgs:
else:
defaults = _load_modality_defaults(sample_meta.model_mode)
overrides = self.model_dump(exclude_none=True)
shift_configured = "shift" in overrides or defaults.get("shift") is not None
user_fields = frozenset(overrides)
shift_configured = "shift" in user_fields or defaults.get("shift") is not None
merged_data = _deep_merge(defaults, overrides)
merged_data = {k: v for k, v in merged_data.items() if k in type(self).model_fields}
merged = type(self).model_validate(merged_data)
Expand All @@ -811,6 +981,10 @@ def build_sample(self, *, model_config: Any) -> OmniSampleArgs:

self._build_reasoner_data(model_config=model_config, sample_meta=sample_meta)

self._build_transfer_data(
model_config=model_config, sample_meta=sample_meta, user_fields=user_fields
)

if not shift_configured and not sample_meta.model_mode.is_reasoner:
model_size = self._VLM_MODEL_SIZE[model_config.vlm_config.model_name]
key = (model_size, self.resolution)
Expand Down Expand Up @@ -869,33 +1043,6 @@ def build_sample(self, *, model_config: Any) -> OmniSampleArgs:
revision="main",
),
),
# Task-specialized Super variants published as diffusers HF checkpoints.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

# s3_uri is unused for HF-backed checkpoints (kept for parity with the
# registry schema); the architecture lives in each model YAML.
"Cosmos3-Super-Image2Video": CheckpointConfig(
model_memory_bytes=MODEL_MEMORY_BYTES_BY_SIZE["32B"],
config_file=str(CONFIG_DIR / "model/Cosmos3-Super.yaml"),
s3_uri="s3://bucket1/cosmos3_vfm/cosmos3_ga_image2video/",
hf=CheckpointDirHf(
repository="nvidia/Cosmos3-Super-Image2Video",
revision="main",
),
# Self-contained checkpoint: use its bundled processor instead of
# downloading the base Cosmos3-Super repo just for the tokenizer.
vlm_processor_from_checkpoint=True,
),
"Cosmos3-Super-Text2Image": CheckpointConfig(
model_memory_bytes=MODEL_MEMORY_BYTES_BY_SIZE["32B"],
config_file=str(CONFIG_DIR / "model/Cosmos3-Super.yaml"),
s3_uri="s3://bucket1/cosmos3_vfm/cosmos3_ga_text2image/",
hf=CheckpointDirHf(
repository="nvidia/Cosmos3-Super-Text2Image",
revision="main",
),
# Self-contained checkpoint: use its bundled processor instead of
# downloading the base Cosmos3-Super repo just for the tokenizer.
vlm_processor_from_checkpoint=True,
),
}
DEFAULT_CHECKPOINT_NAME = "Cosmos3-Nano"
DEFAULT_CHECKPOINT = _CHECKPOINTS[DEFAULT_CHECKPOINT_NAME]
Expand Down
Loading