Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .agents/skills/cosmos3-inference/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ All paths below are relative to the cosmos3 package root (`../../../` from this
| Which model should I use? (Nano vs Super, memory, shift) | `README.md` § Models |
| Which modality? (t2i, t2v, i2v, examples) | `README.md` § Modalities |
| What parallelism preset? (latency vs throughput) | `README.md` § Inference |
| How do I lower GPU memory / offload to CPU? (`--offload-stages`) | `docs/inference.md` § CPU Offloading |
| What input fields are available? (prompt, vision_path, num_frames, ...) | `docs/inference.md` § Sample Arguments |
| What are the default parameter values? | `cosmos_framework/inference/defaults/<model_mode>/sample_args.json` (per-modality JSON) |
| How do I use custom defaults? | `docs/inference.md` § Custom Defaults |
Expand Down
1 change: 1 addition & 0 deletions .claude/skills/cosmos3-inference/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ All paths below are relative to the cosmos3 package root (`../../../` from this
| Which model should I use? (Nano vs Super, memory, shift) | `README.md` § Models |
| Which modality? (t2i, t2v, i2v, examples) | `README.md` § Modalities |
| What parallelism preset? (latency vs throughput) | `README.md` § Inference |
| How do I lower GPU memory / offload to CPU? (`--offload-stages`) | `docs/inference.md` § CPU Offloading |
| What input fields are available? (prompt, vision_path, num_frames, ...) | `docs/inference.md` § Sample Arguments |
| What are the default parameter values? | `cosmos_framework/inference/defaults/<model_mode>/sample_args.json` (per-modality JSON) |
| How do I use custom defaults? | `docs/inference.md` § Custom Defaults |
Expand Down
26 changes: 26 additions & 0 deletions cosmos_framework/inference/args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,32 @@ def check_model_equal(actual: pydantic.BaseModel, expected: pydantic.BaseModel):
check_model_equal(OmniSetupOverrides.model_validate(args.model_dump()).build_setup(), args)


def test_offload_stages(tmp_path: Path):
    """Check defaulting, round-tripping and validation of ``offload_stages``.

    Covers three things:
      * with no override, the built setup reports an empty tuple;
      * the full ("reasoner", "generator", "vae") tuple survives ``build_setup``;
      * values outside the allowed literals are rejected when the overrides
        object is constructed.
    """
    common = {
        "checkpoint_path": DEFAULT_CHECKPOINT_NAME,
        "output_dir": tmp_path / "outputs",
    }

    def _overrides(**extra):
        # Every case shares the same checkpoint/output settings.
        return OmniSetupOverrides(**common, **extra)

    # Nothing requested -> offloading is off.
    assert _overrides().build_setup().offload_stages == ()

    # All arena stages pass through build_setup unchanged.
    all_stages = ("reasoner", "generator", "vae")
    assert _overrides(offload_stages=all_stages).build_setup().offload_stages == all_stages

    # "guardrails" is not an --offload-stages value (it has its own flag), and
    # unknown names are rejected as well.
    for rejected in (("guardrails",), ("bogus",)):
        with pytest.raises(pydantic.ValidationError):
            _overrides(offload_stages=rejected)


def test_sample_args(tmp_path: Path):
setup_args = OmniSetupOverrides(
checkpoint_path=DEFAULT_CHECKPOINT_NAME,
Expand Down
14 changes: 14 additions & 0 deletions cosmos_framework/inference/common/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@
Training = Suppress


# Closed set of stage names accepted by ``--offload-stages`` (single-GPU CPU
# offloading). Fields typed with this alias only validate for these three
# strings. Guardrail offloading is deliberately not a member: it has its own
# dedicated flag, ``--offload-guardrail-models``.
OffloadStage = Literal["reasoner", "generator", "vae"]


# Dot-prefixed, lower-case file suffixes grouped by media kind.
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp"]
VIDEO_EXTENSIONS = [".mp4"]
# Union of the two lists above (image suffixes first, then video suffixes).
MEDIA_EXTENSIONS = IMAGE_EXTENSIONS + VIDEO_EXTENSIONS
Expand Down Expand Up @@ -643,6 +648,7 @@ class SetupArgs(ABC, CheckpointArgs, ParallelismArgs, GuardrailArgs):
warmup: pydantic.NonNegativeInt
max_model_len: pydantic.PositiveInt | None
max_num_seqs: pydantic.PositiveInt | None
offload_stages: tuple[OffloadStage, ...]

# Subclass must implement these fields/methods
# ------------------------------------------------------------
Expand Down Expand Up @@ -693,6 +699,14 @@ class SetupOverrides(ABC, CheckpointOverrides, ParallelismOverrides, GuardrailOv
max_num_seqs: pydantic.PositiveInt | None = 1
"""Maximum number of sequences per batch. When set, samples are packed into
batches by number of sequences."""
offload_stages: tuple[OffloadStage, ...] = ()
"""Single-GPU CPU-offload stages. Each named component is offloaded to pinned CPU
storage and staged into one reusable GPU arena only while in use, reducing peak
GPU memory. Choices: 'reasoner' / 'generator' (the MoT towers — enabling either
runs the understanding pathway once as a prefill that caches the per-layer K/V, then
runs the denoise loop generator-only) and 'vae' (the vision tokenizer, staged around
encode/decode). Empty = off (joint path, unchanged). Single-GPU only; incompatible
with CUDA graphs. Guardrail offloading has its own flag, --offload-guardrail-models."""

def _build_setup(self):
pass
Expand Down
198 changes: 144 additions & 54 deletions cosmos_framework/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from cosmos_framework.inference.common.inference import Inference, sync_distributed_errors
from cosmos_framework.inference.common.init import get_rank, get_world_size
from cosmos_framework.inference.model import Cosmos3OmniConfig, Cosmos3OmniModel
from cosmos_framework.inference.offloading import OffloadPipeline
from cosmos_framework.inference.vision import (
build_conditioned_video_batch,
build_image_edit_batch,
Expand All @@ -50,7 +51,7 @@
from cosmos_framework.tools.visualize.video import save_img_or_video
from cosmos_framework.configs.base.defaults.compile import CompileConfig
from cosmos_framework.configs.base.defaults.parallelism import ParallelismConfig
from cosmos_framework.model.vfm.omni_mot_model import OmniMoTModel
from cosmos_framework.model.vfm.omni_mot_model import OmniMoTModel, cpu_offload_materialization
from cosmos_framework.model.vfm.vlm.qwen3_vl.utils import _SYSTEM_PROMPT_IMAGE_EDITING
from cosmos_framework.model.vfm.upsampler.prompts import is_upsampled_prompt

Expand All @@ -59,6 +60,43 @@

UpsampleTask = Literal["t2i", "t2v", "i2v"]

# Reasoner/generator CPU-offload group names (single-GPU). The two MoT pathways
# are interleaved per layer, so the groups gather the per-layer understanding vs
# generation submodules (plus the generation-side diffusion heads) by reference.
# The string values are the same names accepted by ``--offload-stages``.
REASONER_OFFLOAD_PART = "reasoner"
GENERATOR_OFFLOAD_PART = "generator"
VAE_OFFLOAD_PART = "vae"
# Offload stages that ride the diffusion GPU arena. Guardrail offloading is not
# one of these stages: it is handled separately by the guardrail runners' own
# CPU-offload path and is selected with its own flag
# (``--offload-guardrail-models``), not via ``--offload-stages``.
ARENA_OFFLOAD_PARTS = (REASONER_OFFLOAD_PART, GENERATOR_OFFLOAD_PART, VAE_OFFLOAD_PART)


def build_omni_offload_parts(model: OmniMoTModel) -> dict[str, "torch.nn.Module"]:
    """Collect the model's offloadable module groups, keyed by part name.

    The network itself defines the reasoner/generator split through
    ``model.net.offload_module_groups()``; each returned list of submodules is
    wrapped in an ``nn.ModuleList`` that references the existing objects, so the
    module tree is not modified. When the vision tokenizer exposes an
    ``nn.Module`` (either directly as ``.model`` or one level deeper as
    ``.model.model``), it is added under the ``vae`` part name.

    Args:
        model: The loaded model; must provide ``net.offload_module_groups()``
            and a ``tokenizer_vision_gen`` attribute.

    Returns:
        Mapping from part name to the module (or ``nn.ModuleList``) for that
        part. The ``vae`` entry is absent when no ``nn.Module`` could be found
        on the vision tokenizer. Callers pick the subset they want to stage.
    """
    import torch.nn as nn

    # e.g. {"reasoner": [...], "generator": [...]} as defined by the network.
    parts: dict[str, nn.Module] = {}
    for group_name, members in model.net.offload_module_groups().items():
        parts[group_name] = nn.ModuleList(members)

    # Locate the tokenizer's underlying network. ``tokenizer_vision_gen.model``
    # may be a thin non-Module wrapper whose own ``.model`` is the real
    # nn.Module (e.g. Wan2pt2VAEInterface -> WanVAE -> WanVAE_), so unwrap once.
    candidate = getattr(model.tokenizer_vision_gen, "model", None)
    if not (candidate is None or isinstance(candidate, nn.Module)):
        candidate = getattr(candidate, "model", None)

    if isinstance(candidate, nn.Module):
        parts[VAE_OFFLOAD_PART] = candidate
    return parts


_BatchItem = TypeVar("_BatchItem")

Expand Down Expand Up @@ -1016,60 +1054,106 @@ def _create(cls, setup_args: SetupArgs, **kwargs: Any) -> Self:
sampler_override = setup_args.sampler
parallelism_config = cls._get_parallelism_config(setup_args)
compile_config = cls._get_compile_config(setup_args)
if setup_args.checkpoint_type == CheckpointType.DCP and setup_args.config_file_type == ConfigFileType.MODULE:
from cosmos_framework.inference.common.config import save_config
from cosmos_framework.utils.vfm.model_loader import load_model_from_checkpoint

if not setup_args.experiment:
raise ValueError("'experiment' is required")
if not setup_args.config_file:
raise ValueError("'config_file' is required")

Cosmos3OmniModel.before_load_model()
model, config = load_model_from_checkpoint(
experiment_name=setup_args.experiment,
config_file=setup_args.config_file,
checkpoint_path=setup_args.checkpoint_path,
credential_path=setup_args.credential_path or None,
parallelism_config=attrs.asdict(parallelism_config),
compile_config=attrs.asdict(compile_config),
load_ema_to_reg=setup_args.use_ema_weights,
experiment_opts=[
*setup_args.experiment_overrides,
f"model.config.rectified_flow_inference_config.scheduler_type={sampler_override}",
],
use_cache_checkpoint=setup_args.checkpoint_cache_dir is not None,
cache_checkpoint_rootdir=str(setup_args.checkpoint_cache_dir or ""),
)
model = cast("OmniMoTModel", model)
Cosmos3OmniModel.after_load_model(model)
save_config(config, setup_args.output_dir)
else:
checkpoint_path = setup_args.download_checkpoint()
if setup_args.config_file_type == ConfigFileType.MODULE:
config = None
# Two-phase materialization for single-GPU CPU offloading: build the reasoner/
# generator towers directly on CPU during model construction so they never occupy
# GPU memory (the checkpoint shards load straight into CPU tensors). No-op when
# those stages aren't requested.
net_offload_parts = tuple(
s for s in setup_args.offload_stages if s in (REASONER_OFFLOAD_PART, GENERATOR_OFFLOAD_PART)
)
with cpu_offload_materialization(net_offload_parts):
if setup_args.checkpoint_type == CheckpointType.DCP and setup_args.config_file_type == ConfigFileType.MODULE:
from cosmos_framework.inference.common.config import save_config
from cosmos_framework.utils.vfm.model_loader import load_model_from_checkpoint

if not setup_args.experiment:
raise ValueError("'experiment' is required")
if not setup_args.config_file:
raise ValueError("'config_file' is required")

Cosmos3OmniModel.before_load_model()
model, config = load_model_from_checkpoint(
experiment_name=setup_args.experiment,
config_file=setup_args.config_file,
checkpoint_path=setup_args.checkpoint_path,
credential_path=setup_args.credential_path or None,
parallelism_config=attrs.asdict(parallelism_config),
compile_config=attrs.asdict(compile_config),
load_ema_to_reg=setup_args.use_ema_weights,
experiment_opts=[
*setup_args.experiment_overrides,
f"model.config.rectified_flow_inference_config.scheduler_type={sampler_override}",
],
use_cache_checkpoint=setup_args.checkpoint_cache_dir is not None,
cache_checkpoint_rootdir=str(setup_args.checkpoint_cache_dir or ""),
)
model = cast("OmniMoTModel", model)
Cosmos3OmniModel.after_load_model(model)
save_config(config, setup_args.output_dir)
else:
model_dict = setup_args.load_model_config_dict()
if setup_args.vlm_processor_from_checkpoint:
# Source the VLM processor from the loaded checkpoint's own
# bundled files instead of the repository hardcoded in the
# model config. Drops the redundant base-model download.
tokenizer_cfg = model_dict["config"]["vlm_config"]["tokenizer"]
tokenizer_cfg.pop("repository", None)
tokenizer_cfg.pop("revision", None)
tokenizer_cfg.pop("subdir", None)
tokenizer_cfg["tokenizer_type"] = str(checkpoint_path)
config = Cosmos3OmniConfig(model=model_dict)
model = Cosmos3OmniModel.from_pretrained_dcp(
checkpoint_path,
config=config,
parallelism_config=parallelism_config,
compile_config=compile_config,
).model
if model.config.rectified_flow_inference_config.scheduler_type != sampler_override:
model.config.rectified_flow_inference_config.scheduler_type = sampler_override
model.set_up_scheduler_and_sampler()
log.debug(f"Sampler overridden to: {sampler_override}")
checkpoint_path = setup_args.download_checkpoint()
if setup_args.config_file_type == ConfigFileType.MODULE:
config = None
else:
model_dict = setup_args.load_model_config_dict()
if setup_args.vlm_processor_from_checkpoint:
# Source the VLM processor from the loaded checkpoint's own
# bundled files instead of the repository hardcoded in the
# model config. Drops the redundant base-model download.
tokenizer_cfg = model_dict["config"]["vlm_config"]["tokenizer"]
tokenizer_cfg.pop("repository", None)
tokenizer_cfg.pop("revision", None)
tokenizer_cfg.pop("subdir", None)
tokenizer_cfg["tokenizer_type"] = str(checkpoint_path)
config = Cosmos3OmniConfig(model=model_dict)
model = Cosmos3OmniModel.from_pretrained_dcp(
checkpoint_path,
config=config,
parallelism_config=parallelism_config,
compile_config=compile_config,
).model
if model.config.rectified_flow_inference_config.scheduler_type != sampler_override:
model.config.rectified_flow_inference_config.scheduler_type = sampler_override
model.set_up_scheduler_and_sampler()
log.debug(f"Sampler overridden to: {sampler_override}")

# Single-GPU CPU offloading (opt-in via --offload-stages). The diffusion parts
# (reasoner / generator / vae) time-share one reusable GPU arena; 'guardrails'
# is handled separately by the guardrail runners. Default-off; when off the
# model's ``memory`` stays None (joint path, unchanged).
arena_stages = tuple(s for s in setup_args.offload_stages if s in ARENA_OFFLOAD_PARTS)
if arena_stages:
if compile_config.use_cuda_graphs:
raise NotImplementedError(
"CPU offloading is incompatible with CUDA graphs (staging rebinds weight "
"tensors between calls, which breaks captured static addresses). "
"Disable --use-cuda-graphs or --offload-stages."
)
world = setup_args.dp_shard_size * setup_args.cp_size * setup_args.cfgp_size
if world != 1:
raise NotImplementedError(
f"CPU offloading is single-GPU only (dp_shard*cp*cfgp must be 1, got {world})."
)
available_parts = build_omni_offload_parts(model)
missing = [s for s in arena_stages if s not in available_parts]
if missing:
raise ValueError(f"Requested offload stage(s) {missing} are unavailable for this model.")
offloader = OffloadPipeline(
stages=arena_stages,
parts={s: available_parts[s] for s in arena_stages},
device=torch.device("cuda", torch.cuda.current_device()),
pin_memory=True,
)
offloader.initialize()
# Drive staging from the model/decode sites without leaking inference-side
# machinery into the model: callers invoke ``_offload_stage_fn(part)``, which
# is a no-op for parts that aren't being offloaded (so call sites are uniform).
model._offloader = offloader # keep alive
model._offload_stage_fn = lambda part: offloader.context(part) if offloader.has_part(part) else None
model._reasoner_generator_split = bool(
{REASONER_OFFLOAD_PART, GENERATOR_OFFLOAD_PART} & set(arena_stages)
)
log.info(f"Enabled single-GPU CPU offloading for stages: {', '.join(arena_stages)}.")

vae_decode_stream: torch.cuda.Stream | None = None
if setup_args.use_separate_pipeline_vision_decode_gpu:
Expand Down Expand Up @@ -1468,6 +1552,12 @@ def decode_vision(vision_latent: torch.Tensor) -> torch.Tensor:

with self._get_timer(f"{self.model.__class__.__name__}.decode"):
output_vision = outputs.pop("vision")
# Stage the VAE onto the GPU arena for decode (no-op unless 'vae' is an
# offloaded stage). The denoise loop left the generator staged; this
# rebinds it back to CPU and brings the VAE in.
_stage_fn = getattr(self.model, "_offload_stage_fn", None)
if _stage_fn is not None:
_stage_fn(VAE_OFFLOAD_PART)
decoded_vision = [decode_vision(vision) for vision in output_vision]
outputs["vision"] = [cast(torch.Tensor, vision) for vision in decoded_vision]
if self.vae_decode_stream is not None:
Expand Down
Loading