diff --git a/.agents/skills/cosmos3-inference/SKILL.md b/.agents/skills/cosmos3-inference/SKILL.md index 314b6c5..d840c3e 100644 --- a/.agents/skills/cosmos3-inference/SKILL.md +++ b/.agents/skills/cosmos3-inference/SKILL.md @@ -30,6 +30,7 @@ All paths below are relative to the cosmos3 package root (`../../../` from this | Which model should I use? (Nano vs Super, memory, shift) | `README.md` § Models | | Which modality? (t2i, t2v, i2v, examples) | `README.md` § Modalities | | What parallelism preset? (latency vs throughput) | `README.md` § Inference | +| How do I lower GPU memory / offload to CPU? (`--offload-stages`) | `docs/inference.md` § CPU Offloading | | What input fields are available? (prompt, vision_path, num_frames, ...) | `docs/inference.md` § Sample Arguments | | What are the default parameter values? | `cosmos_framework/inference/defaults/<modality>/sample_args.json` (per-modality JSON) | | How do I use custom defaults? | `docs/inference.md` § Custom Defaults | diff --git a/.claude/skills/cosmos3-inference/SKILL.md b/.claude/skills/cosmos3-inference/SKILL.md index 314b6c5..d840c3e 100644 --- a/.claude/skills/cosmos3-inference/SKILL.md +++ b/.claude/skills/cosmos3-inference/SKILL.md @@ -30,6 +30,7 @@ All paths below are relative to the cosmos3 package root (`../../../` from this | Which model should I use? (Nano vs Super, memory, shift) | `README.md` § Models | | Which modality? (t2i, t2v, i2v, examples) | `README.md` § Modalities | | What parallelism preset? (latency vs throughput) | `README.md` § Inference | +| How do I lower GPU memory / offload to CPU? (`--offload-stages`) | `docs/inference.md` § CPU Offloading | | What input fields are available? (prompt, vision_path, num_frames, ...) | `docs/inference.md` § Sample Arguments | | What are the default parameter values? | `cosmos_framework/inference/defaults/<modality>/sample_args.json` (per-modality JSON) | | How do I use custom defaults? 
def test_offload_stages(tmp_path: Path):
    """``offload_stages`` defaults to empty, round-trips, and rejects unknown names."""
    output_dir = tmp_path / "outputs"

    def _overrides(**extra):
        # Shared construction of the overrides object; only the offload-related
        # keyword arguments vary between the cases below.
        return OmniSetupOverrides(
            checkpoint_path=DEFAULT_CHECKPOINT_NAME,
            output_dir=output_dir,
            **extra,
        )

    # Nothing requested -> offloading stays disabled after build_setup().
    assert _overrides().build_setup().offload_stages == ()

    # Every arena stage survives the overrides -> setup conversion unchanged.
    requested = ("reasoner", "generator", "vae")
    assert _overrides(offload_stages=requested).build_setup().offload_stages == requested

    # "guardrails" is configured through its own flag, so it is rejected at
    # validation time exactly like any other unknown stage name.
    for rejected in (("guardrails",), ("bogus",)):
        with pytest.raises(pydantic.ValidationError):
            _overrides(offload_stages=rejected)
+OffloadStage = Literal["reasoner", "generator", "vae"] + + IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp"] VIDEO_EXTENSIONS = [".mp4"] MEDIA_EXTENSIONS = IMAGE_EXTENSIONS + VIDEO_EXTENSIONS @@ -643,6 +648,7 @@ class SetupArgs(ABC, CheckpointArgs, ParallelismArgs, GuardrailArgs): warmup: pydantic.NonNegativeInt max_model_len: pydantic.PositiveInt | None max_num_seqs: pydantic.PositiveInt | None + offload_stages: tuple[OffloadStage, ...] # Subclass must implement these fields/methods # ------------------------------------------------------------ @@ -693,6 +699,14 @@ class SetupOverrides(ABC, CheckpointOverrides, ParallelismOverrides, GuardrailOv max_num_seqs: pydantic.PositiveInt | None = 1 """Maximum number of sequences per batch. When set, samples are packed into batches by number of sequences.""" + offload_stages: tuple[OffloadStage, ...] = () + """Single-GPU CPU-offload stages. Each named component is offloaded to pinned CPU + storage and staged into one reusable GPU arena only while in use, reducing peak + GPU memory. Choices: 'reasoner' / 'generator' (the MoT towers — enabling either + runs the understanding pathway once as a prefill that caches the per-layer K/V, then + runs the denoise loop generator-only) and 'vae' (the vision tokenizer, staged around + encode/decode). Empty = off (joint path, unchanged). Single-GPU only; incompatible + with CUDA graphs. 
Guardrail offloading has its own flag, --offload-guardrail-models.""" def _build_setup(self): pass diff --git a/cosmos_framework/inference/inference.py b/cosmos_framework/inference/inference.py index 1190889..b8d5162 100644 --- a/cosmos_framework/inference/inference.py +++ b/cosmos_framework/inference/inference.py @@ -37,6 +37,7 @@ from cosmos_framework.inference.common.inference import Inference, sync_distributed_errors from cosmos_framework.inference.common.init import get_rank, get_world_size from cosmos_framework.inference.model import Cosmos3OmniConfig, Cosmos3OmniModel +from cosmos_framework.inference.offloading import OffloadPipeline from cosmos_framework.inference.vision import ( build_conditioned_video_batch, build_image_edit_batch, @@ -50,7 +51,7 @@ from cosmos_framework.tools.visualize.video import save_img_or_video from cosmos_framework.configs.base.defaults.compile import CompileConfig from cosmos_framework.configs.base.defaults.parallelism import ParallelismConfig -from cosmos_framework.model.vfm.omni_mot_model import OmniMoTModel +from cosmos_framework.model.vfm.omni_mot_model import OmniMoTModel, cpu_offload_materialization from cosmos_framework.model.vfm.vlm.qwen3_vl.utils import _SYSTEM_PROMPT_IMAGE_EDITING from cosmos_framework.model.vfm.upsampler.prompts import is_upsampled_prompt @@ -59,6 +60,43 @@ UpsampleTask = Literal["t2i", "t2v", "i2v"] +# Reasoner/generator CPU-offload group names (single-GPU). The two MoT pathways +# are interleaved per layer, so the groups gather the per-layer understanding vs +# generation submodules (plus the generation-side diffusion heads) by reference. +REASONER_OFFLOAD_PART = "reasoner" +GENERATOR_OFFLOAD_PART = "generator" +VAE_OFFLOAD_PART = "vae" +# Offload stages that ride the diffusion GPU arena (``guardrails`` is handled +# separately by the guardrail runners' own CPU-offload path). 
def build_omni_offload_parts(model: OmniMoTModel) -> dict[str, "torch.nn.Module"]:
    """Collect the offloadable module groups of an in-tree Cosmos3 model.

    The reasoner/generator partition comes from ``model.net.offload_module_groups()``;
    each returned list of submodules is wrapped in an ``nn.ModuleList`` that references
    the existing module objects, so the model's own module tree is left untouched.
    A ``vae`` entry is added when an ``nn.Module`` can be found behind the vision
    tokenizer.

    Args:
        model: Loaded model exposing ``net.offload_module_groups()`` and
            ``tokenizer_vision_gen``.

    Returns:
        Mapping from offload part name to the module (or ``nn.ModuleList``) holding
        that part's weights. Callers pick the subset they want to stage.
    """
    import torch.nn as nn

    parts: dict[str, nn.Module] = {}
    for group_name, submodules in model.net.offload_module_groups().items():
        parts[group_name] = nn.ModuleList(submodules)

    # The tokenizer's ``.model`` is either the VAE network itself or a thin
    # non-Module wrapper whose own ``.model`` is the network; unwrap at most one
    # extra level and register the result only if it really is an nn.Module.
    candidate = getattr(model.tokenizer_vision_gen, "model", None)
    if candidate is not None and not isinstance(candidate, nn.Module):
        candidate = getattr(candidate, "model", None)
    if isinstance(candidate, nn.Module):
        parts[VAE_OFFLOAD_PART] = candidate
    return parts
directly on CPU during model construction so they never occupy + # GPU memory (the checkpoint shards load straight into CPU tensors). No-op when + # those stages aren't requested. + net_offload_parts = tuple( + s for s in setup_args.offload_stages if s in (REASONER_OFFLOAD_PART, GENERATOR_OFFLOAD_PART) + ) + with cpu_offload_materialization(net_offload_parts): + if setup_args.checkpoint_type == CheckpointType.DCP and setup_args.config_file_type == ConfigFileType.MODULE: + from cosmos_framework.inference.common.config import save_config + from cosmos_framework.utils.vfm.model_loader import load_model_from_checkpoint + + if not setup_args.experiment: + raise ValueError("'experiment' is required") + if not setup_args.config_file: + raise ValueError("'config_file' is required") + + Cosmos3OmniModel.before_load_model() + model, config = load_model_from_checkpoint( + experiment_name=setup_args.experiment, + config_file=setup_args.config_file, + checkpoint_path=setup_args.checkpoint_path, + credential_path=setup_args.credential_path or None, + parallelism_config=attrs.asdict(parallelism_config), + compile_config=attrs.asdict(compile_config), + load_ema_to_reg=setup_args.use_ema_weights, + experiment_opts=[ + *setup_args.experiment_overrides, + f"model.config.rectified_flow_inference_config.scheduler_type={sampler_override}", + ], + use_cache_checkpoint=setup_args.checkpoint_cache_dir is not None, + cache_checkpoint_rootdir=str(setup_args.checkpoint_cache_dir or ""), + ) + model = cast("OmniMoTModel", model) + Cosmos3OmniModel.after_load_model(model) + save_config(config, setup_args.output_dir) else: - model_dict = setup_args.load_model_config_dict() - if setup_args.vlm_processor_from_checkpoint: - # Source the VLM processor from the loaded checkpoint's own - # bundled files instead of the repository hardcoded in the - # model config. Drops the redundant base-model download. 
- tokenizer_cfg = model_dict["config"]["vlm_config"]["tokenizer"] - tokenizer_cfg.pop("repository", None) - tokenizer_cfg.pop("revision", None) - tokenizer_cfg.pop("subdir", None) - tokenizer_cfg["tokenizer_type"] = str(checkpoint_path) - config = Cosmos3OmniConfig(model=model_dict) - model = Cosmos3OmniModel.from_pretrained_dcp( - checkpoint_path, - config=config, - parallelism_config=parallelism_config, - compile_config=compile_config, - ).model - if model.config.rectified_flow_inference_config.scheduler_type != sampler_override: - model.config.rectified_flow_inference_config.scheduler_type = sampler_override - model.set_up_scheduler_and_sampler() - log.debug(f"Sampler overridden to: {sampler_override}") + checkpoint_path = setup_args.download_checkpoint() + if setup_args.config_file_type == ConfigFileType.MODULE: + config = None + else: + model_dict = setup_args.load_model_config_dict() + if setup_args.vlm_processor_from_checkpoint: + # Source the VLM processor from the loaded checkpoint's own + # bundled files instead of the repository hardcoded in the + # model config. Drops the redundant base-model download. + tokenizer_cfg = model_dict["config"]["vlm_config"]["tokenizer"] + tokenizer_cfg.pop("repository", None) + tokenizer_cfg.pop("revision", None) + tokenizer_cfg.pop("subdir", None) + tokenizer_cfg["tokenizer_type"] = str(checkpoint_path) + config = Cosmos3OmniConfig(model=model_dict) + model = Cosmos3OmniModel.from_pretrained_dcp( + checkpoint_path, + config=config, + parallelism_config=parallelism_config, + compile_config=compile_config, + ).model + if model.config.rectified_flow_inference_config.scheduler_type != sampler_override: + model.config.rectified_flow_inference_config.scheduler_type = sampler_override + model.set_up_scheduler_and_sampler() + log.debug(f"Sampler overridden to: {sampler_override}") + + # Single-GPU CPU offloading (opt-in via --offload-stages). 
The diffusion parts + # (reasoner / generator / vae) time-share one reusable GPU arena; 'guardrails' + # is handled separately by the guardrail runners. Default-off; when off the + # model's ``memory`` stays None (joint path, unchanged). + arena_stages = tuple(s for s in setup_args.offload_stages if s in ARENA_OFFLOAD_PARTS) + if arena_stages: + if compile_config.use_cuda_graphs: + raise NotImplementedError( + "CPU offloading is incompatible with CUDA graphs (staging rebinds weight " + "tensors between calls, which breaks captured static addresses). " + "Disable --use-cuda-graphs or --offload-stages." + ) + world = setup_args.dp_shard_size * setup_args.cp_size * setup_args.cfgp_size + if world != 1: + raise NotImplementedError( + f"CPU offloading is single-GPU only (dp_shard*cp*cfgp must be 1, got {world})." + ) + available_parts = build_omni_offload_parts(model) + missing = [s for s in arena_stages if s not in available_parts] + if missing: + raise ValueError(f"Requested offload stage(s) {missing} are unavailable for this model.") + offloader = OffloadPipeline( + stages=arena_stages, + parts={s: available_parts[s] for s in arena_stages}, + device=torch.device("cuda", torch.cuda.current_device()), + pin_memory=True, + ) + offloader.initialize() + # Drive staging from the model/decode sites without leaking inference-side + # machinery into the model: callers invoke ``_offload_stage_fn(part)``, which + # is a no-op for parts that aren't being offloaded (so call sites are uniform). 
+ model._offloader = offloader # keep alive + model._offload_stage_fn = lambda part: offloader.context(part) if offloader.has_part(part) else None + model._reasoner_generator_split = bool( + {REASONER_OFFLOAD_PART, GENERATOR_OFFLOAD_PART} & set(arena_stages) + ) + log.info(f"Enabled single-GPU CPU offloading for stages: {', '.join(arena_stages)}.") vae_decode_stream: torch.cuda.Stream | None = None if setup_args.use_separate_pipeline_vision_decode_gpu: @@ -1468,6 +1552,12 @@ def decode_vision(vision_latent: torch.Tensor) -> torch.Tensor: with self._get_timer(f"{self.model.__class__.__name__}.decode"): output_vision = outputs.pop("vision") + # Stage the VAE onto the GPU arena for decode (no-op unless 'vae' is an + # offloaded stage). The denoise loop left the generator staged; this + # rebinds it back to CPU and brings the VAE in. + _stage_fn = getattr(self.model, "_offload_stage_fn", None) + if _stage_fn is not None: + _stage_fn(VAE_OFFLOAD_PART) decoded_vision = [decode_vision(vision) for vision in output_vision] outputs["vision"] = [cast(torch.Tensor, vision) for vision in decoded_vision] if self.vae_decode_stream is not None: diff --git a/cosmos_framework/inference/offloading.py b/cosmos_framework/inference/offloading.py new file mode 100644 index 0000000..bf766f6 --- /dev/null +++ b/cosmos_framework/inference/offloading.py @@ -0,0 +1,503 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 +# +# Ported from NVIDIA/TensorRT-LLM PR #14095 +# (tensorrt_llm/_torch/visual_gen/offloading.py). +"""Module parameter offloading utilities for single-GPU inference. + +The offload path keeps model loading and quantization unchanged: weights are +loaded into the modules first, then selected module groups are copied into +packed CPU storage. 
At runtime one group at a time is staged into a reusable GPU +arena and the original module parameters/buffers are rebound to views of that +storage. + +This is the model-agnostic core (``ModuleOffloadManager`` + ``OffloadPipeline``). +The Cosmos3-specific wiring (which modules form the offload groups, and where +each group is staged during a generation) lives in ``OmniInference`` +(``cosmos_framework/inference/inference.py``). +""" + +import logging +import time +from contextlib import nullcontext +from dataclasses import dataclass +from typing import Iterator, Mapping, Sequence + +import torch +import torch.nn as nn + +logger = logging.getLogger(__name__) + + +def _align_offset(offset: int, alignment: int = 256) -> int: + return ((offset + alignment - 1) // alignment) * alignment + + +def _format_bytes(num_bytes: int) -> str: + return f"{num_bytes / (1024**3):.2f} GiB" + + +# FlashInfer and other custom kernels can require tensor data pointers to be at +# least 16-byte aligned even for smaller dtypes such as BF16. +_PACKED_TENSOR_ALIGNMENT = 16 + + +OffloadPipelineStage = tuple[str, ...] + + +@dataclass +class _FlatTensorSpec: + owner: nn.Module + name: str + qualified_name: str + is_parameter: bool + shape: tuple[int, ...] + stride: tuple[int, ...] + dtype: torch.dtype + requires_grad: bool + persistent: bool + offset: int + nbytes: int + + +@dataclass +class _GroupLayout: + """Packed storage layout and rebound views for one offload group.""" + + name: str + nbytes: int + specs: list[_FlatTensorSpec] + cpu_storage: torch.Tensor | None = None + cpu_views: tuple[nn.Parameter | torch.Tensor, ...] = () + gpu_views: tuple[nn.Parameter | torch.Tensor, ...] = () + + +class ModuleOffloadManager: + """Pack module groups into CPU storage and stage one group on GPU. + + The manager owns packed byte buffers: + - each layout owns ``cpu_storage`` for one offloaded group. + - ``gpu_arena`` is reused for whichever group is currently active. 
+ + Initializing the manager packs and rebinds one group at a time. This + requires enough host memory to allocate the current group's packed CPU + storage before that group's original tensors are released. + """ + + def __init__( + self, + groups: Mapping[str, nn.Module], + device: torch.device | str, + pin_memory: bool = True, + ) -> None: + if not groups: + raise ValueError("At least one offload group must be provided") + + self.groups = dict(groups) + self.device = torch.device(device) + self.pin_memory = pin_memory + self.gpu_arena: torch.Tensor | None = None + self.layouts: dict[str, _GroupLayout] = {} + self.active_group_name: str | None = None + + for name, module in self.groups.items(): + if not name: + raise ValueError("Offload group names must be non-empty") + if not isinstance(module, nn.Module): + raise TypeError(f"Offload group '{name}' must contain an nn.Module") + + @staticmethod + def _owner_and_name(root: nn.Module, qualified_name: str) -> tuple[nn.Module, str]: + if "." 
not in qualified_name: + return root, qualified_name + module_path, name = qualified_name.rsplit(".", 1) + return root.get_submodule(module_path), name + + @staticmethod + def _tensor_nbytes(tensor: torch.Tensor) -> int: + return tensor.numel() * tensor.element_size() + + @staticmethod + def _storage_key(tensor: torch.Tensor) -> tuple[int, int] | None: + if tensor.numel() == 0: + return None + storage_offset_bytes = tensor.storage_offset() * tensor.element_size() + return tensor.untyped_storage().data_ptr(), storage_offset_bytes + + def _get_alias_spec( + self, + seen_tensors: dict[tuple[int, int], _FlatTensorSpec], + tensor: torch.Tensor, + display_name: str, + ) -> _FlatTensorSpec | None: + key = self._storage_key(tensor) + if key is None: + return None + canonical = seen_tensors.get(key) + if canonical is None: + return None + if self._tensor_nbytes(tensor) != canonical.nbytes or tensor.dtype != canonical.dtype: + raise ValueError( + "Shared parameters or buffers with different sizes or dtypes are " + f"not supported by ModuleOffloadManager: '{display_name}' aliases " + f"'{canonical.qualified_name}'" + ) + return canonical + + def _build_spec( + self, + group_name: str, + group_module: nn.Module, + qualified_name: str, + tensor: torch.Tensor, + is_parameter: bool, + offset: int, + ) -> _FlatTensorSpec: + display_name = f"{group_name}.{qualified_name}" + if not tensor.is_contiguous(): + raise ValueError( + f"Cannot offload non-contiguous tensor '{display_name}' with stride {tuple(tensor.stride())}" + ) + + owner, name = self._owner_and_name(group_module, qualified_name) + return _FlatTensorSpec( + owner=owner, + name=name, + qualified_name=display_name, + is_parameter=is_parameter, + shape=tuple(tensor.shape), + stride=tuple(tensor.stride()), + dtype=tensor.dtype, + requires_grad=tensor.requires_grad if is_parameter else False, + persistent=True if is_parameter else name not in owner._non_persistent_buffers_set, + offset=offset, + 
nbytes=self._tensor_nbytes(tensor), + ) + + def _iter_group_tensors(self, group_module: nn.Module) -> Iterator[tuple[str, torch.Tensor, bool]]: + for qualified_name, param in group_module.named_parameters(recurse=True, remove_duplicate=False): + yield qualified_name, param.detach(), True + for qualified_name, buffer in group_module.named_buffers(recurse=True, remove_duplicate=False): + yield qualified_name, buffer.detach(), False + + def _append_layout_spec( + self, + group_name: str, + group_module: nn.Module, + qualified_name: str, + tensor: torch.Tensor, + is_parameter: bool, + offset: int, + seen_tensors: dict[tuple[int, int], _FlatTensorSpec], + specs: list[_FlatTensorSpec], + ) -> int: + """Append a tensor spec and return the next group-local byte offset. + + This handles three layout concerns in one place: alias reuse, packed + tensor alignment, and spec construction. The offset is local to this + group and is shared by the CPU storage and reusable GPU arena views. + """ + display_name = f"{group_name}.{qualified_name}" + alias = self._get_alias_spec(seen_tensors, tensor, display_name) + if alias is None: + offset = _align_offset(offset, _PACKED_TENSOR_ALIGNMENT) + + spec = self._build_spec( + group_name=group_name, + group_module=group_module, + qualified_name=qualified_name, + tensor=tensor, + is_parameter=is_parameter, + offset=alias.offset if alias is not None else offset, + ) + specs.append(spec) + + if alias is not None: + return offset + + key = self._storage_key(tensor) + if key is not None: + seen_tensors[key] = spec + return offset + spec.nbytes + + def _collect_group_layout(self, group_name: str, group_module: nn.Module) -> _GroupLayout: + """Build the packed storage layout for one named module group.""" + offset = 0 + specs: list[_FlatTensorSpec] = [] + seen_tensors: dict[tuple[int, int], _FlatTensorSpec] = {} + + for qualified_name, tensor, is_parameter in self._iter_group_tensors(group_module): + offset = self._append_layout_spec( + 
group_name=group_name, + group_module=group_module, + qualified_name=qualified_name, + tensor=tensor, + is_parameter=is_parameter, + offset=offset, + seen_tensors=seen_tensors, + specs=specs, + ) + + if not specs: + raise ValueError(f"Offload group '{group_name}' has no parameters or buffers") + + return _GroupLayout(name=group_name, nbytes=_align_offset(offset), specs=specs) + + def _copy_group_to_cpu_storage(self, layout: _GroupLayout) -> None: + if layout.cpu_storage is None: + raise RuntimeError(f"CPU storage for offload group '{layout.name}' has not been allocated") + for spec in layout.specs: + if spec.nbytes == 0: + continue + try: + tensor = getattr(spec.owner, spec.name).detach() + tensor_bytes = tensor.reshape(-1).view(torch.uint8).cpu() + layout.cpu_storage.narrow(0, spec.offset, spec.nbytes).copy_(tensor_bytes) + except RuntimeError as e: + raise RuntimeError( + f"Failed to copy offload tensor '{spec.qualified_name}' " + f"({_format_bytes(spec.nbytes)}, shape={spec.shape}, dtype={spec.dtype}) " + f"to packed CPU storage at offset {spec.offset}." + ) from e + + def _group_size_summary(self) -> str: + return ", ".join(f"{name}={_format_bytes(layout.nbytes)}" for name, layout in self.layouts.items()) + + def _cuda_allocation_hint(self) -> str: + if self.device.type != "cuda": + return "" + return ( + " If this is due to CUDA memory fragmentation, try setting " + "PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True before starting the process." 
+ ) + + def _allocate_cpu_storage(self, num_bytes: int, group_name: str | None = None) -> torch.Tensor: + try: + return torch.empty(num_bytes, dtype=torch.uint8, device="cpu", pin_memory=self.pin_memory) + except RuntimeError as e: + group_context = f", group='{group_name}'" if group_name is not None else "" + raise RuntimeError( + "Failed to allocate packed CPU storage for Cosmos3 offload " + f"({_format_bytes(num_bytes)}, {num_bytes} bytes{group_context}, " + f"pin_memory={self.pin_memory}, groups=[{self._group_size_summary()}])." + ) from e + + def _allocate_gpu_arena(self, num_bytes: int) -> torch.Tensor: + try: + return torch.empty(num_bytes, dtype=torch.uint8, device=self.device) + except RuntimeError as e: + raise RuntimeError( + "Failed to allocate GPU arena for Cosmos3 offload " + f"({_format_bytes(num_bytes)}, {num_bytes} bytes, " + f"device={self.device}, groups=[{self._group_size_summary()}])." + f"{self._cuda_allocation_hint()}" + ) from e + + def _make_views( + self, + layout: _GroupLayout, + storage: torch.Tensor, + ) -> tuple[nn.Parameter | torch.Tensor, ...]: + views: list[nn.Parameter | torch.Tensor] = [] + for spec in layout.specs: + view = storage.narrow(0, spec.offset, spec.nbytes).view(spec.dtype) + view = view.as_strided(spec.shape, spec.stride) + if spec.is_parameter: + views.append(nn.Parameter(view, requires_grad=spec.requires_grad)) + else: + views.append(view) + return tuple(views) + + def _bind_views( + self, + layout: _GroupLayout, + views: tuple[nn.Parameter | torch.Tensor, ...], + ) -> None: + for spec, view in zip(layout.specs, views, strict=True): + if spec.is_parameter: + if not isinstance(view, nn.Parameter): + raise TypeError( + f"Expected offload view '{spec.name}' to be an nn.Parameter, got {type(view).__name__}" + ) + spec.owner.register_parameter(spec.name, view) + else: + if not isinstance(view, torch.Tensor): + raise TypeError( + f"Expected offload view '{spec.name}' to be a torch.Tensor, got {type(view).__name__}" + ) + 
spec.owner.register_buffer(spec.name, view, persistent=spec.persistent) + + def initialize(self) -> None: + """Allocate packed storage, copy current tensors, and bind CPU views.""" + if self.layouts: + raise RuntimeError("ModuleOffloadManager has already been initialized") + + start_time = time.time() + for name, module in self.groups.items(): + layout = self._collect_group_layout(name, module) + self.layouts[name] = layout + + total_cpu_bytes = sum(layout.nbytes for layout in self.layouts.values()) + max_gpu_bytes = max(layout.nbytes for layout in self.layouts.values()) + logger.info( + "Module offload storage layout: " + f"cpu_total={_format_bytes(total_cpu_bytes)}, " + f"gpu_arena={_format_bytes(max_gpu_bytes)}, " + f"groups=[{self._group_size_summary()}], device={self.device}" + ) + + # Pack and rebind one group at a time. This keeps setup simple and fast: + # offloading requires enough host memory to allocate one group's packed + # CPU storage before that group's original tensors are released. + for layout in self.layouts.values(): + logger.info( + f"Module offload packing group into CPU storage: {layout.name} ({_format_bytes(layout.nbytes)})" + ) + layout.cpu_storage = self._allocate_cpu_storage(layout.nbytes, group_name=layout.name) + self._copy_group_to_cpu_storage(layout) + layout.cpu_views = self._make_views(layout, layout.cpu_storage) + self._rebind_to_cpu(layout.name) + + # Every group's parameters/buffers were just rebound to CPU storage, so the + # modules' original on-device weights are now unreferenced. Release that freed + # device memory back to the driver before allocating the arena, so the arena + # (and later activations) reclaim it instead of stacking on top of the + # caching allocator's now-fragmented free blocks. Without this, offloading can + # use MORE device memory than the joint path (the freed weights are cached, not + # returned, and large contiguous activations can't reuse the fragmented blocks). 
+ if self.device.type == "cuda": + torch.cuda.synchronize(self.device) + torch.cuda.empty_cache() + + self.gpu_arena = self._allocate_gpu_arena(max_gpu_bytes) + for layout in self.layouts.values(): + layout.gpu_views = self._make_views(layout, self.gpu_arena) + + logger.info(f"Module offload setup completed in {time.time() - start_time:.2f}s") + + def _get_layout(self, name: str) -> _GroupLayout: + try: + return self.layouts[name] + except KeyError as e: + raise KeyError(f"Unknown offload group '{name}'. Available groups: {sorted(self.layouts)}") from e + + def stage(self, name: str) -> None: + """Stage one offload group into the GPU arena and rebind its tensors.""" + layout = self._get_layout(name) + if self.active_group_name == name: + return + if layout.cpu_storage is None or self.gpu_arena is None: + raise RuntimeError("ModuleOffloadManager must be initialized before staging") + + if self.active_group_name is not None: + self._rebind_to_cpu(self.active_group_name) + self.active_group_name = None + + src = layout.cpu_storage.narrow(0, 0, layout.nbytes) + dst = self.gpu_arena.narrow(0, 0, layout.nbytes) + try: + dst.copy_(src, non_blocking=layout.cpu_storage.is_pinned()) + if self.device.type == "cuda": + torch.cuda.synchronize(self.device) + except RuntimeError as e: + raise RuntimeError( + f"Failed to stage offload group '{name}' ({_format_bytes(layout.nbytes)}) to {self.device}" + ) from e + self._rebind_to_gpu(name) + self.active_group_name = name + + def _rebind_to_cpu(self, name: str) -> None: + layout = self._get_layout(name) + if not layout.cpu_views: + raise RuntimeError("ModuleOffloadManager must be initialized before staging") + self._bind_views(layout, layout.cpu_views) + + def _rebind_to_gpu(self, name: str) -> None: + layout = self._get_layout(name) + if not layout.gpu_views: + raise RuntimeError("ModuleOffloadManager must be initialized before staging") + self._bind_views(layout, layout.gpu_views) + + +class OffloadPipeline: + """Stage offload 
class OffloadPipeline:
    """Explicit, call-site driven staging of offload groups.

    No forward hooks are installed. Pipeline code calls
    ``offload_pipeline.context("group")`` at the relevant call site so the group
    is staged before the model invocation and outside any later CUDA graph
    capture.
    """

    def __init__(
        self,
        stages: Sequence[Sequence[str] | str],
        parts: Mapping[str, nn.Module],
        device: torch.device | str,
        pin_memory: bool = True,
    ) -> None:
        """Normalize the stages and build the underlying offload manager.

        Args:
            stages: Stage definitions; a bare string is a single-part stage, any
                other sequence lists the parts staged together.
            parts: Mapping from part name to the module holding its weights.
            device: Device handed to the manager for the shared arena.
            pin_memory: Forwarded to the manager for CPU storage allocation.

        Raises:
            ValueError: If ``stages`` is empty, a stage is empty, or two stages
                resolve to the same group name.
            KeyError: If a stage names a part missing from ``parts``.
        """
        if not stages:
            raise ValueError("At least one offload pipeline stage must be provided")

        normalized: list[tuple[str, ...]] = []
        for stage in stages:
            normalized.append((stage,) if isinstance(stage, str) else tuple(stage))
        self.stages = tuple(normalized)
        self.parts = dict(parts)
        self.device = torch.device(device)
        self.pin_memory = pin_memory
        self.manager = ModuleOffloadManager(
            groups=self._build_groups(),
            device=self.device,
            pin_memory=self.pin_memory,
        )

        # Reverse index so a call site can name either a part or a whole group.
        part_to_group: dict[str, str] = {}
        for stage in self.stages:
            group_name = self._stage_name(stage)
            for part in stage:
                part_to_group[part] = group_name
        self._group_name_by_part = part_to_group

    def _build_groups(self) -> dict[str, nn.Module]:
        """Resolve every stage into one module keyed by the stage's group name."""
        resolved: dict[str, nn.Module] = {}
        for stage in self.stages:
            group_name = self._stage_name(stage)
            if not stage:
                raise ValueError("Offload pipeline stages must have at least one part")
            if group_name in resolved:
                raise ValueError(f"Duplicate offload pipeline stage: {group_name}")

            members: list[nn.Module] = []
            for part_name in stage:
                try:
                    member = self.parts[part_name]
                except KeyError as e:
                    raise KeyError(
                        f"Unknown offload pipeline part '{part_name}' for stage "
                        f"'{group_name}'. Available parts: {sorted(self.parts)}"
                    ) from e
                members.append(member)

            # A single-part stage uses the part module directly; several parts
            # are bundled in a ModuleList so they pack into one group.
            if len(members) == 1:
                resolved[group_name] = members[0]
            else:
                resolved[group_name] = nn.ModuleList(members)

        return resolved

    def initialize(self) -> None:
        """Allocate and populate backing storage for all configured stages."""
        self.manager.initialize()

    @staticmethod
    def _stage_name(stage: Sequence[str] | str) -> str:
        """Return the group name: the string itself, or its parts joined by ``+``."""
        if isinstance(stage, str):
            return stage
        return "+".join(stage)

    def context(self, part_or_group_name: str):
        """Stage the group containing ``part_or_group_name`` and return a no-op context.

        The staged group is not released when the returned context exits; it
        stays active until another group is staged or ``cleanup()`` is called.
        """
        target_group = self._group_name_by_part.get(part_or_group_name, part_or_group_name)
        self.manager.stage(target_group)
        return nullcontext()

    def has_part(self, part_name: str) -> bool:
        """Return whether ``part_name`` belongs to one of the configured stages."""
        return part_name in self._group_name_by_part

    def cleanup(self) -> None:
        """Return the active group to CPU-backed views."""
        active = self.manager.active_group_name
        if active is None:
            return
        self.manager._rebind_to_cpu(active)
        self.manager.active_group_name = None
class _TwoTower(nn.Module):
    """Toy model with two independent towers, each a ``Linear`` followed by a ``LayerNorm``.

    ``run_a`` and ``run_b`` each evaluate exactly one tower, so a caller can
    exercise one tower at a time.
    """

    def __init__(self, dim: int = 16) -> None:
        super().__init__()
        # tower_a is built completely before tower_b, so the parameter
        # initialisers draw from the RNG in a fixed order.
        for tower_name in ("tower_a", "tower_b"):
            projection = nn.Linear(dim, dim)
            normalisation = nn.LayerNorm(dim)
            setattr(self, tower_name, nn.Sequential(projection, normalisation))

    def run_a(self, x: torch.Tensor) -> torch.Tensor:
        """Apply only ``tower_a`` to ``x``."""
        tower = self.tower_a
        return tower(x)

    def run_b(self, x: torch.Tensor) -> torch.Tensor:
        """Apply only ``tower_b`` to ``x``."""
        tower = self.tower_b
        return tower(x)
pin_memory=(device.type == "cuda"), + ) + offloader.initialize() + + with torch.no_grad(): + offloader.context("tower_a") + out_a = model.run_a(x) + offloader.context("tower_b") + out_b = model.run_b(out_a) + offloader.cleanup() + + torch.testing.assert_close(out_b, ref, atol=0.0, rtol=0.0) + + +def test_stage_swaps_parameter_device_placement(): + """Only the active group's tensors live on the arena; the rest are CPU-backed.""" + device = _device() + model, _, _ = _build(device) + + offloader = OffloadPipeline( + stages=("tower_a", "tower_b"), + parts={"tower_a": model.tower_a, "tower_b": model.tower_b}, + device=device, + pin_memory=(device.type == "cuda"), + ) + offloader.initialize() + + # After initialize (nothing staged) both groups are rebound to CPU storage. + assert model.tower_a[0].weight.device.type == "cpu" + assert model.tower_b[0].weight.device.type == "cpu" + + offloader.context("tower_a") + assert model.tower_a[0].weight.device.type == device.type + assert model.tower_b[0].weight.device.type == "cpu" + + offloader.context("tower_b") + assert model.tower_a[0].weight.device.type == "cpu" + assert model.tower_b[0].weight.device.type == device.type + + # Re-staging the already-active group is a no-op. 
def test_heterogeneous_group_sizes_share_one_arena():
    """A small (vae-like) group co-exists with large (tower-like) groups in one arena.

    Mirrors the Cosmos3 stage mix: the arena is sized to the largest group and the
    smaller group stages into the same buffer (only its own bytes are copied), with
    forward output unchanged.
    """
    device = _device()
    torch.manual_seed(0)
    big_a = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64)).to(device).eval()
    big_b = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64)).to(device).eval()
    small = nn.Linear(64, 64).to(device).eval()  # vae-like: one layer, half of a big group
    x = torch.randn(2, 64, device=device)
    # Reference outputs are taken before offloading, with every module resident.
    with torch.no_grad():
        ref_a, ref_b, ref_small = big_a(x), big_b(x), small(x)

    offloader = OffloadPipeline(
        stages=("big_a", "big_b", "small"),
        parts={"big_a": big_a, "big_b": big_b, "small": small},
        device=device,
        pin_memory=(device.type == "cuda"),
    )
    offloader.initialize()

    manager = offloader.manager
    # Fail with a clear assertion (not an AttributeError on ``None``) if the
    # arena was never allocated; same guard as test_arena_is_shared_across_groups.
    assert manager.gpu_arena is not None

    # Arena is sized to the largest group, not the sum.
    sizes = {name: layout.nbytes for name, layout in manager.layouts.items()}
    assert manager.gpu_arena.numel() == max(sizes.values())
    assert sizes["small"] < sizes["big_a"]

    # Stage big -> small -> big so the small group is copied over a buffer that
    # previously held a larger one. ``context`` stages eagerly and returns a
    # no-op context manager, so the ``with`` form is the documented call shape.
    with torch.no_grad():
        with offloader.context("big_a"):
            out_a = big_a(x)
        with offloader.context("small"):
            out_small = small(x)
        with offloader.context("big_b"):
            out_b = big_b(x)
    offloader.cleanup()

    torch.testing.assert_close(out_a, ref_a, atol=0.0, rtol=0.0)
    torch.testing.assert_close(out_b, ref_b, atol=0.0, rtol=0.0)
    torch.testing.assert_close(out_small, ref_small, atol=0.0, rtol=0.0)
def _two_way_attention_gen_only(
    packed_query_states: FactoredSequencePack | JointSequencePack,
    packed_key_states: FactoredSequencePack | JointSequencePack,
    packed_value_states: FactoredSequencePack | JointSequencePack,
    memory_value: MemoryValue,
):
    """Run two-way attention for the generation tokens only, reusing cached understanding K/V.

    The generation queries attend to the cached understanding keys/values
    (``memory_value.und_k`` / ``memory_value.und_v``) concatenated in front of
    the freshly computed generation keys/values. The understanding output is
    returned empty. Only a single packed sample is supported.

    Raises:
        AssertionError: If the cache has no understanding K/V, or the pack
            holds more than one sample.
    """
    cached_keys = getattr(memory_value, "und_k", None)
    cached_values = getattr(memory_value, "und_v", None)
    assert cached_keys is not None and cached_values is not None, (
        "Generator-only attention requires a populated reasoner K/V cache; run the prefill pass first."
    )

    # ``get_full_only_seq`` is used for all three inputs so that only the
    # generation tokens are picked up. The original author notes that
    # ``get_all_seq`` would re-expand to the packed length and zero-fill the
    # empty understanding slots, adding all-zero keys/values to the attention.
    queries, query_offsets = get_full_only_seq(packed_query_states)  # [N_gen,heads,head_dim]
    fresh_keys = get_full_only_seq(packed_key_states)[0]  # [N_gen,kv_heads,head_dim]
    fresh_values = get_full_only_seq(packed_value_states)[0]  # [N_gen,kv_heads,head_dim]

    assert query_offsets.numel() == 2, (
        "Reasoner/generator-split offload supports a single packed sample per forward."
    )

    # Drop the leading batch dim of the cached tensors, match the fresh
    # dtype, then put the cached entries first.
    keys = torch.cat(
        [cached_keys.squeeze(0).to(fresh_keys.dtype), fresh_keys], dim=0
    )  # [und_len+N_gen,kv_heads,head_dim]
    values = torch.cat(
        [cached_values.squeeze(0).to(fresh_values.dtype), fresh_values], dim=0
    )  # [und_len+N_gen,kv_heads,head_dim]

    num_queries = queries.shape[0]
    num_keys = keys.shape[0]
    # Single sample: the key/value side is one segment spanning [0, num_keys).
    key_offsets = torch.tensor([0, num_keys], device=queries.device, dtype=query_offsets.dtype)

    attended = attention(
        queries.unsqueeze(0),  # [1,N_gen,heads,head_dim]
        keys.unsqueeze(0),  # [1,und_len+N_gen,kv_heads,head_dim]
        values.unsqueeze(0),  # [1,und_len+N_gen,kv_heads,head_dim]
        cumulative_seqlen_Q=query_offsets,
        cumulative_seqlen_KV=key_offsets,
        max_seqlen_Q=num_queries,
        max_seqlen_KV=num_keys,
    )  # [1,N_gen,heads,head_dim]

    generation_out = attended.squeeze(0).flatten(-2, -1)  # [N_gen,heads*head_dim]
    understanding_out = generation_out.new_empty(0, generation_out.shape[-1])  # understanding is empty
    return from_mode_splits(understanding_out, generation_out, packed_query_states)
+ + In the joint two-way path the understanding (causal) pathway attends to + understanding keys only (``is_causal=True`` over the causal sequence) and is + completely independent of the generation tokens, so the per-layer und K/V + and und hidden states produced here are bit-for-bit the same as the joint + pass — which is what makes the generator-only denoise numerically match. + """ + causal_q, causal_q_offsets = get_causal_seq(packed_query_states) + causal_k, causal_k_offsets = get_causal_seq(packed_key_states) + causal_v, _ = get_causal_seq(packed_value_states) + + use_dont_care_mask = causal_q_offsets is causal_k_offsets + + causal_res = attention( + causal_q.unsqueeze(0), # [1,N_und,heads,head_dim] + causal_k.unsqueeze(0), # [1,N_und,heads,head_dim] + causal_v.unsqueeze(0), # [1,N_und,heads,head_dim] + cumulative_seqlen_Q=causal_q_offsets, + cumulative_seqlen_KV=causal_k_offsets, + max_seqlen_Q=packed_query_states["max_causal_len"], + max_seqlen_KV=packed_query_states["max_causal_len"], + is_causal=True, + causal_type=CausalType.DontCare if use_dont_care_mask else CausalType.TopLeft, + ) # [1,N_und,heads,head_dim] + causal_out = causal_res.squeeze(0).flatten(-2, -1) # type: ignore # [N_und,heads*head_dim] + + # The generation (full) sequence is empty in the prefill. + full_out = causal_out.new_empty(0, causal_out.shape[-1]) # [0,heads*head_dim] + return from_mode_splits(causal_out, full_out, packed_query_states) + + def dispatch_attention( packed_query_states: FactoredSequencePack | JointSequencePack, packed_key_states: FactoredSequencePack | JointSequencePack, @@ -304,7 +402,24 @@ def dispatch_attention( natten_metadata: dict | None = None, memory_value: MemoryValue | None = None, ) -> tuple[FactoredSequencePack | JointSequencePack, KVToStore | None]: - assert memory_value is None, "Base dispatch_attention does not handle MemoryValue" + if memory_value is not None: + # Reasoner/generator-split offload path. 
def offload_module_groups(self) -> dict[str, list[torch.nn.Module]]:
    """Split the network's submodules into ``"reasoner"`` and ``"generator"`` lists.

    The reasoner list starts with the text model's ``embed_tokens`` and
    ``norm``; the generator list starts with ``norm_moe_gen`` followed by
    whichever of the optional heads on ``self`` exist and are modules. Every
    decoder layer then contributes its attention projections, q/k norms,
    layer norms and MLP: the plain attributes go to the reasoner and the
    ``*_moe_gen`` attributes go to the generator.

    Returns:
        ``{"reasoner": [...], "generator": [...]}`` holding references to the
        existing submodule objects; nothing is copied or moved.
    """
    text_model = self.language_model.model  # Qwen3VL(Moe)TextModel
    reasoner_modules: list[torch.nn.Module] = [text_model.embed_tokens, text_model.norm]
    generator_modules: list[torch.nn.Module] = [text_model.norm_moe_gen]

    # Optional generation-side heads: keep only those present on this
    # instance and actually an ``nn.Module``.
    optional_heads = ("time_embedder", "vae2llm", "llm2vae", "action2llm", "llm2action", "sound2llm", "llm2sound")
    candidate_heads = (getattr(self, head_name, None) for head_name in optional_heads)
    generator_modules.extend(head for head in candidate_heads if isinstance(head, torch.nn.Module))

    attention_attrs = ("q_proj", "k_proj", "v_proj", "o_proj", "q_norm", "k_norm")
    layer_attrs = ("input_layernorm", "post_attention_layernorm", "mlp")
    # Reasoner attributes carry no suffix; the generator twins end in "_moe_gen".
    destinations = (("", reasoner_modules), ("_moe_gen", generator_modules))

    for wrapped_layer in text_model.layers:
        # Unwrap torch.compile's OptimizedModule so we reference the real submodules.
        decoder_layer = getattr(wrapped_layer, "_orig_mod", wrapped_layer)
        attention_module = decoder_layer.self_attn
        for suffix, bucket in destinations:
            bucket.extend(getattr(attention_module, attr + suffix) for attr in attention_attrs)
            bucket.extend(getattr(decoder_layer, attr + suffix) for attr in layer_attrs)

    return {"reasoner": reasoner_modules, "generator": generator_modules}
The understanding rows are left as zeros (the generator + pathway never reads them). + """ + # Use the embed_tokens weight dtype/device as the reference compute dtype. + # Reading ``.dtype`` / ``.device`` is metadata-only and does not move or + # invoke the (possibly offloaded) reasoner weight. + ref = self.language_model.model.embed_tokens.weight + packed_sequence = torch.zeros( + (packed_seq.sequence_length, self.hidden_size), device=packed_seq.text_indexes.device, dtype=ref.dtype + ) # [N_total,hidden_size] + return packed_sequence, ref.dtype + def _encode_vision( self, packed_seq: PackedSequence, @@ -988,19 +1055,34 @@ def forward( # This is intentional for proper batch norm / dropout behavior # assert self.training, "Cosmos3VFMNetwork only supports training mode" - packed_sequence, target_dtype = self._encode_text(packed_seq) # packed_sequence: [N_total,hidden_size] + # Reasoner/generator-split offload modes (``memory`` is a ReasonerMemoryState): + # - prefill (und_only): encode text only; skip the generation-side encoders so + # the generator weights stay offloaded. The reasoner populates the per-layer + # understanding K/V cache and the forward returns early (no generation output). + # - gen (gen_only): skip the reasoner ``embed_tokens`` (offloaded); the + # understanding K/V come from the prefilled cache. + _und_only = memory is not None and memory.is_und_only() + _gen_only = memory is not None and memory.is_gen_only() + + if _gen_only: + packed_sequence, target_dtype = self._alloc_hidden_states_gen_only(packed_seq) + else: + packed_sequence, target_dtype = self._encode_text(packed_seq) # packed_sequence: [N_total,hidden_size] - # encode vision tokens + # encode vision/action/sound tokens. Skipped during the reasoner prefill: the + # understanding pathway does not attend to generation tokens, so their hidden + # states are irrelevant there and the generation-side encoders (vae2llm, + # time_embedder, ...) must stay offloaded. 
original_latent_shapes: List[Tuple[int, int, int]] | None = None - if self.config.vision_gen: + if self.config.vision_gen and not _und_only: original_latent_shapes = self._encode_vision(packed_seq, packed_sequence, target_dtype, fps_vision) # encode action tokens - if self.config.action_gen: + if self.config.action_gen and not _und_only: self._encode_action(packed_seq, packed_sequence, target_dtype, fps_action) # encode sound tokens - if self.config.sound_gen: + if self.config.sound_gen and not _und_only: self._encode_sound(packed_seq, packed_sequence, target_dtype, fps_sound) assert packed_seq.attn_modes is not None @@ -1076,6 +1158,14 @@ def forward( natten_metadata_list=natten_metadata_list, memory=memory, ) + + if _und_only: + # Reasoner prefill: ``self.language_model`` has populated the per-layer + # understanding K/V cache via ``memory.write_for_layer`` (side effect). + # There is no generation output to decode, so return early before any + # generation-side decode head runs (keeps the generator offloaded). + return dict() + last_hidden_state = get_context_parallel_last_hidden_state( packed_outputs=packed_outputs, parallel_dims=self.parallel_dims, diff --git a/cosmos_framework/model/vfm/mot/reasoner_generator_split_test.py b/cosmos_framework/model/vfm/mot/reasoner_generator_split_test.py new file mode 100644 index 0000000..7e1838d --- /dev/null +++ b/cosmos_framework/model/vfm/mot/reasoner_generator_split_test.py @@ -0,0 +1,108 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 +"""Numerical parity gate for the reasoner/generator split used by CPU offloading. + +The offloaded denoise path runs the understanding ("reasoner") tower once as a +prefill (``ReasonerMemoryState`` mode ``"prefill"``) that caches the per-layer +understanding K/V, then runs the generation tower alone (mode ``"gen"``) reusing +that cache. 
def _build_tiny_text_model(device: torch.device) -> Qwen3VLTextModel:
    """Build a small fp32 ``Qwen3VLTextModel`` in eval mode on ``device``.

    The model is constructed from an in-code config, so no checkpoint is
    needed. Q/K norm is enabled for both the text and the diffusion pathways.

    Returns:
        The model itself (not its config), moved to ``device`` as float32 and
        switched to eval mode.
    """
    # head_dim * num_attention_heads == hidden_size; small dims for a fast test.
    config = Qwen3VLTextConfig(
        vocab_size=128,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=16,
        rope_scaling=None,
        rms_norm_eps=1e-6,
    )
    model = Qwen3VLTextModel(config, qk_norm_for_text=True, qk_norm_for_diffusion=True)
    return model.to(device=device, dtype=torch.float32).eval()
def test_joint_path_unchanged_when_memory_none():
    """Two ``memory=None`` forwards on the same inputs must give bit-identical generation output."""
    device = _device()
    torch.manual_seed(0)
    model = _build_tiny_text_model(device)
    input_pack, attention_meta, position_ids = _build_two_way_pack(model, device)

    # Run the identical joint forward twice and keep only the hidden states.
    hidden_per_run = []
    with torch.no_grad():
        for _ in range(2):
            hidden, _metadata = model(
                input_pack,
                attention_mask=attention_meta,
                position_ids=position_ids,
                memory=None,
            )
            hidden_per_run.append(hidden)

    first_gen, second_gen = (get_gen_seq(hidden) for hidden in hidden_per_run)
    # Zero tolerance: the comparison is for exact equality.
    torch.testing.assert_close(first_gen, second_gen, atol=0.0, rtol=0.0)
""" - q_und_in = self.q_proj(get_und_seq(pack)) # [N_und,num_heads*head_dim] - q_gen_in = self.q_proj_moe_gen(get_gen_seq(pack)) # [N_gen,num_heads*head_dim] - - k_und_in = self.k_proj(get_und_seq(pack)) # [N_und,num_kv_heads*head_dim] - k_gen_in = self.k_proj_moe_gen(get_gen_seq(pack)) # [N_gen,num_kv_heads*head_dim] - - v_und_in = self.v_proj(get_und_seq(pack)) # [N_und,num_kv_heads*head_dim] - v_gen_in = self.v_proj_moe_gen(get_gen_seq(pack)) # [N_gen,num_kv_heads*head_dim] - - q_und = q_und_in.view(-1, self.num_attention_heads, self.head_dim) # [N_und,num_heads,head_dim] - k_und = k_und_in.view(-1, self.num_key_value_heads, self.head_dim) # [N_und,num_kv_heads,head_dim] - v_und = v_und_in.view(-1, self.num_key_value_heads, self.head_dim) # [N_und,num_kv_heads,head_dim] - - q_gen = q_gen_in.view(-1, self.num_attention_heads, self.head_dim) # [N_gen,num_heads,head_dim] - k_gen = k_gen_in.view(-1, self.num_key_value_heads, self.head_dim) # [N_gen,num_kv_heads,head_dim] - v_gen = v_gen_in.view(-1, self.num_key_value_heads, self.head_dim) # [N_gen,num_kv_heads,head_dim] - - q_und = self.q_norm(q_und) # [N_und,num_heads,head_dim] - k_und = self.k_norm(k_und) # [N_und,num_kv_heads,head_dim] - - q_gen = self.q_norm_moe_gen(q_gen) # [N_gen,num_heads,head_dim] - k_gen = self.k_norm_moe_gen(k_gen) # [N_gen,num_kv_heads,head_dim] + und_seq = get_und_seq(pack) # [N_und,hidden_size] + gen_seq = get_gen_seq(pack) # [N_gen,hidden_size] + # True-skip each pathway when its sequence is empty so the corresponding + # weights are never invoked (required for reasoner/generator CPU offload). + # When both are non-empty (the joint path) behavior is unchanged. 
+ has_und = und_seq.shape[0] > 0 + has_gen = gen_seq.shape[0] > 0 packed_cos = packed_position_embeddings[0] packed_sin = packed_position_embeddings[1] - q_und_, k_und_ = self._apply_rotary_pos_emb( - q_und, - k_und, - get_und_seq(packed_cos), - get_und_seq(packed_sin), - unsqueeze_dim=1, - ) # q_und_: [N_und,num_heads,head_dim], k_und_: [N_und,num_kv_heads,head_dim] - q_gen_, k_gen_ = self._apply_rotary_pos_emb( - q_gen, - k_gen, - get_gen_seq(packed_cos), - get_gen_seq(packed_sin), - unsqueeze_dim=1, - ) # q_gen_: [N_gen,num_heads,head_dim], k_gen_: [N_gen,num_kv_heads,head_dim] + if has_und: + q_und = self.q_proj(und_seq).view(-1, self.num_attention_heads, self.head_dim) + k_und = self.k_proj(und_seq).view(-1, self.num_key_value_heads, self.head_dim) + v_und = self.v_proj(und_seq).view(-1, self.num_key_value_heads, self.head_dim) + q_und = self.q_norm(q_und) # [N_und,num_heads,head_dim] + k_und = self.k_norm(k_und) # [N_und,num_kv_heads,head_dim] + q_und_, k_und_ = self._apply_rotary_pos_emb( + q_und, + k_und, + get_und_seq(packed_cos), + get_und_seq(packed_sin), + unsqueeze_dim=1, + ) # q_und_: [N_und,num_heads,head_dim], k_und_: [N_und,num_kv_heads,head_dim] + else: + q_und_ = und_seq.new_empty(0, self.num_attention_heads, self.head_dim) + k_und_ = und_seq.new_empty(0, self.num_key_value_heads, self.head_dim) + v_und = und_seq.new_empty(0, self.num_key_value_heads, self.head_dim) + + if has_gen: + q_gen = self.q_proj_moe_gen(gen_seq).view(-1, self.num_attention_heads, self.head_dim) + k_gen = self.k_proj_moe_gen(gen_seq).view(-1, self.num_key_value_heads, self.head_dim) + v_gen = self.v_proj_moe_gen(gen_seq).view(-1, self.num_key_value_heads, self.head_dim) + q_gen = self.q_norm_moe_gen(q_gen) # [N_gen,num_heads,head_dim] + k_gen = self.k_norm_moe_gen(k_gen) # [N_gen,num_kv_heads,head_dim] + q_gen_, k_gen_ = self._apply_rotary_pos_emb( + q_gen, + k_gen, + get_gen_seq(packed_cos), + get_gen_seq(packed_sin), + unsqueeze_dim=1, + ) # q_gen_: 
[N_gen,num_heads,head_dim], k_gen_: [N_gen,num_kv_heads,head_dim] + else: + q_gen_ = gen_seq.new_empty(0, self.num_attention_heads, self.head_dim) + k_gen_ = gen_seq.new_empty(0, self.num_key_value_heads, self.head_dim) + v_gen = gen_seq.new_empty(0, self.num_key_value_heads, self.head_dim) packed_query_states_ = from_und_gen_splits(q_und_, q_gen_, pack) # [N_und+N_gen,num_heads,head_dim] packed_key_states_ = from_und_gen_splits(k_und_, k_gen_, pack) # [N_und+N_gen,num_kv_heads,head_dim] @@ -599,9 +605,22 @@ def forward( v_und[:und_len].unsqueeze(0), ) - # Apply projections directly to get final results - und_seq = self.o_proj(get_und_seq(packed_attn_output)) # [N_und,hidden_size] - gen_seq = self.o_proj_moe_gen(get_gen_seq(packed_attn_output)) # [N_gen,hidden_size] + # Apply output projections. True-skip each pathway when its sequence is + # empty so the corresponding ``o_proj`` weight is never invoked (required + # for reasoner/generator CPU offload, where the inactive pathway's weights + # live on CPU). When both are non-empty (joint path) behavior is unchanged. 
+ und_attn_out = get_und_seq(packed_attn_output) # [N_und,heads*head_dim] + gen_attn_out = get_gen_seq(packed_attn_output) # [N_gen,heads*head_dim] + und_seq = ( + self.o_proj(und_attn_out) + if has_und + else und_attn_out.new_empty(0, self.hidden_size) + ) # [N_und,hidden_size] + gen_seq = ( + self.o_proj_moe_gen(gen_attn_out) + if has_gen + else gen_attn_out.new_empty(0, self.hidden_size) + ) # [N_gen,hidden_size] return from_und_gen_splits(und_seq, gen_seq, pack), kv_to_store # [N_und+N_gen,hidden_size] def reasoner_forward( @@ -792,8 +811,9 @@ def _impl_forward( if memory is not None: memory.init(hidden_states, device) - # Derive gen_only once (outside compile) if using MemoryState + # Derive gen_only / und_only once (outside compile) if using MemoryState memory_gen_only = memory.is_gen_only() if memory is not None else False + memory_und_only = memory.is_und_only() if memory is not None else False for i, decoder_layer in enumerate(self.layers): # MemoryState: produce read-only MemoryValue for this layer (outside compile) @@ -806,6 +826,7 @@ def _impl_forward( natten_metadata=None if natten_metadata_list is None else natten_metadata_list[i], memory_value=memory_value, gen_only=memory_gen_only, + und_only=memory_und_only, ) # MemoryState: store K/V produced by this layer (outside compile) @@ -835,8 +856,14 @@ def _impl_forward( ) hidden_states_out = zeros_like(hidden_states) - set_und_seq(hidden_states_out, self.norm(get_und_seq(hidden_states))) # [N_und,hidden_size] - set_gen_seq(hidden_states_out, self.norm_moe_gen(get_gen_seq(hidden_states))) # [N_gen,hidden_size] + # Final norms: skip the offloaded tower's norm in the reasoner/generator-split + # modes (``self.norm`` is a reasoner weight; ``self.norm_moe_gen`` a generator + # weight) so the inactive tower's weights are never invoked. The joint path + # (both flags False) runs both, unchanged. 
+ if not memory_gen_only: + set_und_seq(hidden_states_out, self.norm(get_und_seq(hidden_states))) # [N_und,hidden_size] + if not memory_und_only: + set_gen_seq(hidden_states_out, self.norm_moe_gen(get_gen_seq(hidden_states))) # [N_gen,hidden_size] return hidden_states_out, final_lbl_metadata @@ -916,6 +943,7 @@ def forward( natten_metadata: dict | None = None, memory_value: MemoryValue | None = None, gen_only: bool = False, + und_only: bool = False, ) -> tuple[FactoredSequencePack, dict[str, LBLMetadata], KVToStore | None]: """Forward pass with MoT routing and optional memory-augmented attention. @@ -925,6 +953,18 @@ def forward( ``MemoryState.write_for_layer()`` outside the ``torch.compile`` boundary. + Three mutually exclusive execution branches (``und_only`` and + ``gen_only`` are never both True): + + - ``und_only`` (reasoner prefill): run only the understanding pathway; + the generator pathway is skipped and its tokens pass through + unchanged. The understanding K/V are captured via ``kv_to_store``. + - ``gen_only`` (offloaded denoise step): run only the generation + pathway; the understanding pathway is skipped and its tokens pass + through unchanged (the und K/V come from the prefill cache). + - standard (else): the joint path that runs both pathways. This branch + is numerically identical to the pre-offload implementation. + Args: input: Packed sequence with und/gen tokens attention_mask: Attention mask @@ -932,56 +972,68 @@ def forward( natten_metadata: Optional NATTEN metadata for neighborhood attention. memory_value: Read-only tensor container from MemoryState.read_for_layer(). gen_only: When True, skip the understanding pathway (und K/V come from cache). + und_only: When True, run only the understanding pathway (reasoner prefill). 
""" - # Pre-Attention layernorm - pack_norm_out = from_und_gen_splits( - self.input_layernorm(get_und_seq(input)), # [N_und,hidden_size] - self.input_layernorm_moe_gen(get_gen_seq(input)), # [N_gen,hidden_size] - input, - ) # [N_und+N_gen,hidden_size] - - # Self Attention + Residual + und_in = get_und_seq(input) # [N_und,hidden_size] + gen_in = get_gen_seq(input) # [N_gen,hidden_size] + lbl_metadata_dict: dict[str, LBLMetadata] = dict() kv_to_store: KVToStore | None = None - if gen_only: - assert natten_metadata is None - # gen_only: skip und, compute gen tokens only (und K/V come from cache) - _gen_norm = get_gen_seq(pack_norm_out) - gen_pack = from_und_gen_splits( - _gen_norm.new_empty(0, _gen_norm.shape[-1]), - _gen_norm, - pack_norm_out, - ) - # Build position embeddings whose und length matches gen_pack's - # und length (always 0). Required when the outer pack carries - # a padded causal_seq (``pad_for_cuda_graphs=True``): without - # this, the und RoPE inside ``PackedAttentionMoT.forward`` - # would broadcast cos/sin of shape ``(MAX_CAUSAL_LEN, head_dim)`` - # onto a length-0 ``q_und`` / ``k_und`` and crash. When the - # outer pack is unpadded (eager AR path), the und cos/sin - # already have length 0 and this slice is a no-op. - _cos, _sin = packed_position_embeddings - _empty_cos_und = get_und_seq(_cos)[:0] - _empty_sin_und = get_und_seq(_sin)[:0] - gen_position_embeddings = ( - from_und_gen_splits(_empty_cos_und, get_gen_seq(_cos), _cos), - from_und_gen_splits(_empty_sin_und, get_gen_seq(_sin), _sin), + # The single-tower packs below carry an empty opposite pathway, so + # ``PackedAttentionMoT`` true-skips it (``has_und`` / ``has_gen``) and never + # reads its RoPE — ``packed_position_embeddings`` can be passed unchanged. + if und_only: + # REASONER PREFILL: run only the understanding pathway; generation tokens + # pass through unchanged so the generator weights stay offloaded. The + # understanding K/V are captured via kv_to_store for the gen-only steps. 
+ und_norm = self.input_layernorm(und_in) # [N_und,hidden_size] + und_pack = from_und_gen_splits(und_norm, und_norm.new_empty(0, und_norm.shape[-1]), input) + + pack_attn_out, kv_to_store = self.self_attn( + und_pack, attention_mask, packed_position_embeddings, natten_metadata=natten_metadata, memory_value=memory_value ) + residual_und = und_in + get_und_seq(pack_attn_out) # [N_und,hidden_size] + + ln_out_und = self.post_attention_layernorm(residual_und) # [N_und,hidden_size] + und_len = pack_attn_out["_num_causal_tokens"] + mlp_out_und_unpadded, lbl_metadata_und = _run_mlp(self.mlp, ln_out_und[:und_len]) + mlp_out_und = torch.cat([mlp_out_und_unpadded, ln_out_und[und_len:]], dim=0) # [N_und,hidden_size] + if lbl_metadata_und is not None: + lbl_metadata_dict["und"] = lbl_metadata_und + + mlp_out_und_seq = residual_und + mlp_out_und # [N_und,hidden_size] + mlp_out_gen_seq = gen_in # passthrough (generation pathway untouched) + elif gen_only: + assert natten_metadata is None + # GENERATOR-ONLY DENOISE STEP: run only the generation pathway; the + # understanding tokens pass through unchanged (their K/V come from the + # prefill cache) so the reasoner weights stay offloaded. + gen_norm = self.input_layernorm_moe_gen(gen_in) # [N_gen,hidden_size] + gen_pack = from_und_gen_splits(gen_norm.new_empty(0, gen_norm.shape[-1]), gen_norm, input) pack_attn_out, kv_to_store = self.self_attn( - gen_pack, - attention_mask, - gen_position_embeddings, - natten_metadata=natten_metadata, - memory_value=memory_value, + gen_pack, attention_mask, packed_position_embeddings, natten_metadata=natten_metadata, memory_value=memory_value ) - gen_attn_out = get_gen_seq(pack_attn_out) - # No residual_und here: the gen_only MLP branch below builds its own - # length-0 und sequence for ``mlp_out_und_seq``; carrying one through - # this branch is dead code. 
- residual_gen = get_gen_seq(input) + gen_attn_out + residual_gen = gen_in + get_gen_seq(pack_attn_out) # [N_gen,hidden_size] + + ln_out_gen = self.post_attention_layernorm_moe_gen(residual_gen) # [N_gen,hidden_size] + gen_len = pack_attn_out["_num_full_tokens"] + mlp_out_gen_unpadded, lbl_metadata_gen = _run_mlp(self.mlp_moe_gen, ln_out_gen[:gen_len]) + mlp_out_gen = torch.cat([mlp_out_gen_unpadded, ln_out_gen[gen_len:]], dim=0) # [N_gen,hidden_size] + if lbl_metadata_gen is not None: + lbl_metadata_dict["gen"] = lbl_metadata_gen + + mlp_out_und_seq = und_in # passthrough (understanding pathway untouched) + mlp_out_gen_seq = residual_gen + mlp_out_gen # [N_gen,hidden_size] else: - # STANDARD PATH: Process both und and gen tokens + # STANDARD (joint) PATH: process both und and gen tokens. This branch + # is numerically identical to the pre-offload implementation. + pack_norm_out = from_und_gen_splits( + self.input_layernorm(und_in), # [N_und,hidden_size] + self.input_layernorm_moe_gen(gen_in), # [N_gen,hidden_size] + input, + ) # [N_und+N_gen,hidden_size] + pack_attn_out, kv_to_store = self.self_attn( pack_norm_out, attention_mask, @@ -989,42 +1041,13 @@ def forward( natten_metadata=natten_metadata, memory_value=memory_value, ) - residual_und = get_und_seq(input) + get_und_seq(pack_attn_out) # [N_und,hidden_size] - residual_gen = get_gen_seq(input) + get_gen_seq(pack_attn_out) # [N_gen,hidden_size] - - # Pre-MLP layernorm and processing - lbl_metadata_dict: dict[str, LBLMetadata] = dict() - - if gen_only: - # gen_only: skip und, compute gen tokens only - ln_out_und = residual_gen.new_empty(0, residual_gen.shape[-1]) - ln_out_gen = self.post_attention_layernorm_moe_gen(residual_gen) + residual_und = und_in + get_und_seq(pack_attn_out) # [N_und,hidden_size] + residual_gen = gen_in + get_gen_seq(pack_attn_out) # [N_gen,hidden_size] - # UNPAD MLP INPUT (gen only) - gen_len = pack_attn_out["_num_full_tokens"] - ln_out_gen_unpadded = ln_out_gen[:gen_len] # 
[N_gen_unpadded,hidden_size] - - # Run MLP (gen only) - mlp_out_gen_unpadded, lbl_metadata_gen = _run_mlp(self.mlp_moe_gen, ln_out_gen_unpadded) - # mlp_out_gen_unpadded: [N_gen_unpadded,hidden_size] - - # PAD MLP OUTPUT (gen only) - mlp_out_gen = torch.cat([mlp_out_gen_unpadded, ln_out_gen[gen_len:]], dim=0) # [N_gen,hidden_size] - - # Build metadata dict (no und metadata in optimized path) - if lbl_metadata_gen is not None: - lbl_metadata_dict["gen"] = lbl_metadata_gen - - # Final output with residual (gen only) - mlp_out_und_seq = residual_gen.new_empty(0, residual_gen.shape[-1]) - mlp_out_gen_seq = residual_gen + mlp_out_gen - else: - # STANDARD PATH: Process both und and gen tokens ln_out_und = self.post_attention_layernorm(residual_und) # [N_und,hidden_size] ln_out_gen = self.post_attention_layernorm_moe_gen(residual_gen) # [N_gen,hidden_size] # UNPAD MLP INPUT =============== - # artificial expert inbalance due to routing padding tokens. gen_len = pack_attn_out["_num_full_tokens"] und_len = pack_attn_out["_num_causal_tokens"] diff --git a/cosmos_framework/model/vfm/omni_mot_model.py b/cosmos_framework/model/vfm/omni_mot_model.py index 7ecb564..0e4e171 100644 --- a/cosmos_framework/model/vfm/omni_mot_model.py +++ b/cosmos_framework/model/vfm/omni_mot_model.py @@ -4,6 +4,7 @@ from __future__ import annotations import collections +import contextvars import time from contextlib import contextmanager from typing import Any, Callable, Dict, Mapping, Optional, Tuple @@ -48,7 +49,7 @@ build_dense_sound_schedule, unwrap_and_densify, ) -from cosmos_framework.model.vfm.utils.memory import MemoryState +from cosmos_framework.model.vfm.utils.memory import MemoryState, ReasonerMemoryState from cosmos_framework.model.vfm.utils.safetensors_loader import load_language_model as load_language_model_safetensors from cosmos_framework.model.vfm.vlm.qwen3_vl.utils import tokenize_caption from cosmos_framework.model.vfm.tokenizers.interface import VideoTokenizerInterface @@ -59,6 
+60,66 @@ from cosmos_framework.utils.vfm.parallelism import ParallelDims +# Names of the network offload groups (from ``Cosmos3VFMNetwork.offload_module_groups``) +# whose weights should be materialized directly on CPU at load time (two-phase +# materialization), so they never occupy GPU memory. Set by the inference layer via +# ``cpu_offload_materialization`` around model construction; read inside ``build_net``. +_CPU_OFFLOAD_NET_PARTS: contextvars.ContextVar[tuple[str, ...]] = contextvars.ContextVar( + "_cpu_offload_net_parts", default=() +) + + +@contextmanager +def cpu_offload_materialization(net_parts: tuple[str, ...]): + """Materialize the named network offload groups on CPU during model construction. + + Wrap model loading (``from_pretrained_dcp`` / ``load_model_from_checkpoint``) with + this so the listed groups (e.g. ``("reasoner", "generator")``) are built directly on + CPU and the checkpoint shards load into CPU tensors — they never touch the GPU. Empty + ``net_parts`` is a no-op (default joint materialization, unchanged). + """ + token = _CPU_OFFLOAD_NET_PARTS.set(tuple(net_parts)) + try: + yield + finally: + _CPU_OFFLOAD_NET_PARTS.reset(token) + + +def _offloaded_tensor_ids(net: torch.nn.Module, net_parts: tuple[str, ...]) -> set[int]: + """Collect ``id()`` of every parameter/buffer belonging to the offloaded groups.""" + groups = net.offload_module_groups() + ids: set[int] = set() + for part in net_parts: + for module in groups.get(part, []): + for tensor in (*module.parameters(recurse=True), *module.buffers(recurse=True)): + ids.add(id(tensor)) + return ids + + +def _materialize_meta_tensors(net: torch.nn.Module, device: torch.device, skip_ids: set[int] | None) -> None: + """Materialize every still-``meta`` parameter/buffer on ``device`` (skipping ``skip_ids``). + + Allocates real empty tensors in place. 
``skip_ids`` leaves the offloaded tensors on + ``meta`` for the first (GPU) pass so they sidestep the wasted random init, then a + second pass with ``skip_ids=None`` lands them on CPU. + """ + for module in net.modules(): + for name, param in list(module._parameters.items()): + if param is None or not param.is_meta: + continue + if skip_ids is not None and id(param) in skip_ids: + continue + module._parameters[name] = torch.nn.Parameter( + torch.empty_like(param, device=device), requires_grad=param.requires_grad + ) + for name, buffer in list(module._buffers.items()): + if buffer is None or not buffer.is_meta: + continue + if skip_ids is not None and id(buffer) in skip_ids: + continue + module._buffers[name] = torch.empty_like(buffer, device=device) + + class OmniMoTModel(ImaginaireModel): """ Mixture of Transformers (MoT) model to be trained with the flow matching objective @@ -241,12 +302,27 @@ def build_net(self, dtype: torch.dtype): with misc.timer("meta to cuda and broadcast model states"): net = net.to(dtype=dtype) - net.to_empty(device=DEVICE) + cpu_offload_net_parts = _CPU_OFFLOAD_NET_PARTS.get() if DEVICE == Device.CUDA else () + if cpu_offload_net_parts: + # Single-GPU CPU offloading: materialize only the non-offloaded modules on + # the GPU and keep the offloaded towers on ``meta``, so ``init_weights`` + # skips their random init — pure waste here (they come entirely from the + # checkpoint and have no non-persistent buffers) and crippling on CPU. + _materialize_meta_tensors( + net, torch.device("cuda"), skip_ids=_offloaded_tensor_ids(net, cpu_offload_net_parts) + ) + else: + net.to_empty(device=DEVICE) + + # Weight init is only needed on CUDA (CPU/meta are for checkpoint conversion and + # smoke tests). It initializes the non-offloaded modules and their non-persistent + # buffers (e.g. RoPE); the offloaded towers stay on ``meta`` (init no-ops there). 
if DEVICE == Device.CUDA: - # Weight initialization is not needed for other devices (cpu, - # meta), since they are only for checkpoint conversion and smoke - # tests. net.init_weights(buffer_device=DEVICE) + if cpu_offload_net_parts: + # Land the offloaded towers on CPU; ``dcp.load`` fills them next, so they + # never occupy GPU memory. + _materialize_meta_tensors(net, torch.device("cpu"), skip_ids=None) if getattr(self.config, "lora_enabled", False): self._init_lora_weights_post_materialization(net) @@ -1784,6 +1860,7 @@ def _get_velocity( sequence_plans: list[SequencePlan], gen_data_clean: GenerationDataClean, skip_text_tokens: bool = False, + memory: MemoryState | None = None, ) -> list[torch.Tensor]: """ Compute velocity prediction for a single sampling step. @@ -1906,8 +1983,14 @@ def _get_velocity( fps_vision=gen_data_clean.fps_vision, fps_action=fps_action, fps_sound=fps_sound, + memory=memory, ) + if memory is not None and memory.is_und_only(): + # Reasoner prefill: the network populated the per-layer understanding K/V + # cache as a side effect (no generation output). Nothing to mask/return. + return [] + # --- Apply velocity masks --- # Zero out velocity for conditioned parts (they don't change during sampling) assert packed_sequence.vision is not None, "packed_sequence.vision must exist for velocity masking" @@ -2302,6 +2385,61 @@ def generate_samples_from_batch( _dp_shard_group = None _align_device = None + # --- Reasoner/generator split (single-GPU CPU offloading). Default off. --- + # When enabled (set by the inference layer via ``_reasoner_generator_split``), + # the understanding ("reasoner") pathway is computed once as a prefill that + # caches the per-layer understanding K/V; the diffusion denoise loop then runs + # generator-only and reuses that cache. 
``_offload_stage_fn`` (optional, also + # set by the inference layer) stages the reasoner/generator weight groups into + # the single GPU arena around the prefill / denoise; the model itself never + # imports the inference-side offload machinery. + _split_enabled = getattr(self, "_reasoner_generator_split", False) + _stage_fn = getattr(self, "_offload_stage_fn", None) + _mem_cond: ReasonerMemoryState | None = None + _mem_uncond: ReasonerMemoryState | None = None + if _split_enabled: + if self.parallel_dims is not None and ( + self.parallel_dims.cp_enabled + or self.parallel_dims.cfgp_enabled + or (self.parallel_dims.dp_shard_mesh is not None and self.parallel_dims.dp_shard_mesh.size() > 1) + ): + raise NotImplementedError( + "Reasoner/generator-split CPU offloading is single-GPU only " + "(no context-, CFG-, or FSDP-shard parallelism)." + ) + net_obj = net or self.net + num_layers = len(net_obj.language_model.model.layers) + # The reasoner prefill is independent of the diffusion timestep (the + # understanding pathway never reads the timestep embedding), so any value + # works; the noise values are likewise unused (vision encode is skipped). + prefill_timestep = torch.zeros((n_sample, 1)) + + def _prefill_reasoner(tokens: list[list[int]], skip_text_tokens: bool) -> ReasonerMemoryState: + mem = ReasonerMemoryState(num_layers) + mem.set_mode("prefill") + self._get_velocity( + net=net, + noise_x=initial_noise, + timestep=prefill_timestep, + text_tokens=tokens, + sequence_plans=sequence_plans, + gen_data_clean=gen_data_clean, + skip_text_tokens=skip_text_tokens, + memory=mem, + ) + mem.set_mode("gen") + return mem + + if _stage_fn is not None: + _stage_fn("reasoner") + _mem_cond = _prefill_reasoner(cond_tokens, skip_text_tokens=False) + if guidance != 1.0: + _mem_uncond = _prefill_reasoner(uncond_tokens, skip_text_tokens=skip_text_tokens_for_cfg) + + # Stage the generator for the whole denoise loop (reasoner now offloaded). 
+ if _stage_fn is not None: + _stage_fn("generator") + # Create a velocity function for a single sample (for use with self.sampler). def velocity_fn(noise_x: list[torch.Tensor], timestep: torch.Tensor) -> list[torch.Tensor]: @@ -2316,6 +2454,12 @@ def velocity_fn(noise_x: list[torch.Tensor], timestep: torch.Tensor) -> list[tor timestep = timestep.repeat(len(noise_x), 1) def _single_velocity_fn(tokens: list[list[int]], skip_text_tokens: bool): + # In the reasoner/generator split, route each CFG branch to its own + # prefilled understanding K/V cache (cond vs uncond). Identity check on + # the token-list object distinguishes the branches. + _mem = None + if _split_enabled: + _mem = _mem_cond if tokens is cond_tokens else _mem_uncond return self._get_velocity( net=net, noise_x=noise_x, @@ -2324,6 +2468,7 @@ def _single_velocity_fn(tokens: list[list[int]], skip_text_tokens: bool): sequence_plans=sequence_plans, gen_data_clean=gen_data_clean, skip_text_tokens=skip_text_tokens, + memory=_mem, ) # Local CFG decision for THIS rank, honoring guidance_interval. @@ -2717,6 +2862,11 @@ def get_data_and_condition(self, data_batch: dict[str, torch.Tensor], iteration: self._normalize_video_databatch_inplace(data_batch) self._augment_image_dim_inplace(data_batch) # converts each image tensor to (1, C, 1, H, W) raw_state_vision = data_batch[self.input_image_key if is_image_batch else self.input_video_key] + # Stage the VAE onto the GPU arena for conditioning encode when it is offloaded + # (no-op otherwise, and skipped entirely when there is nothing to encode, e.g. t2v). 
+ _stage_fn = getattr(self, "_offload_stage_fn", None) + if _stage_fn is not None and len(raw_state_vision) > 0: + _stage_fn("vae") x0_tokens_vision = [ self.encode(raw_state_vision_i).contiguous().float() for raw_state_vision_i in raw_state_vision ] @@ -3405,6 +3555,10 @@ def denoise( fps_sound=fps_sound, memory=memory, ) + if memory is not None and memory.is_und_only(): + # Reasoner prefill: the network only populated the understanding K/V cache + # (side effect on ``memory``) and returned no generation predictions. + return dict() output_dict = dict() output_dict["preds_vision"] = out_net["preds_vision"] if self.config.action_gen and "preds_action" in out_net: diff --git a/cosmos_framework/model/vfm/utils/memory.py b/cosmos_framework/model/vfm/utils/memory.py index 0487882..8392a69 100644 --- a/cosmos_framework/model/vfm/utils/memory.py +++ b/cosmos_framework/model/vfm/utils/memory.py @@ -93,6 +93,16 @@ def is_gen_only(self) -> bool: Used for autoregressive frame-by-frame generation of video. """ + def is_und_only(self) -> bool: + """Return ``True`` when only the understanding (reasoner) pathway should run. + + When ``True``, the decoder layer runs the reasoner prefill: it computes + the understanding pathway over the text tokens and caches the per-layer + K/V, skipping the generation pathway entirely. Defaults to ``False`` so + existing memory states (which only toggle gen-only) are unaffected. + """ + return False + @property def uses_rolling_kv_cache(self) -> bool: """Whether this memory uses the rolling KV-cache / compile-safe path. @@ -101,3 +111,87 @@ def uses_rolling_kv_cache(self) -> bool: temporal causality is handled inside three-way attention instead. """ return False + + +@dataclass +class ReasonerMemoryValue(MemoryValue): + """Read-only per-layer understanding (reasoner) K/V snapshot. 
+ + Carries the post-RoPE understanding-pathway keys/values cached during the + one-time reasoner prefill, so the generator-only denoise pass can attend to + them without recomputing the understanding pathway. Shapes follow the + ``KVToStore`` contract produced by ``PackedAttentionMoT.forward``: + ``[1, und_len, num_kv_heads, head_dim]``. + """ + + # ``None`` only during the prefill pass (cache not yet populated); the + # generator-only attention path asserts these are present before use. + und_k: torch.Tensor | None + und_v: torch.Tensor | None + + +class ReasonerMemoryState(MemoryState): + """Per-layer understanding-K/V cache for the reasoner/generator split. + + Single-sample inference only (the offloaded single-GPU path). Drives a + one-time understanding prefill followed by generator-only denoise steps via + a three-valued ``mode``: + + - ``"prefill"`` (``is_und_only()`` is ``True``): each decoder layer runs the + reasoner pathway only and ``write_for_layer`` stores the understanding K/V. + - ``"gen"`` (``is_gen_only()`` is ``True``): each decoder layer runs the + generation pathway only and ``read_for_layer`` returns the cached + understanding K/V for the gen->und cross-attention. + - ``None`` (default): both ``is_und_only()`` and ``is_gen_only()`` are + ``False`` so the decoder runs the joint forward. + + The understanding tokens are the (fixed) text prompt and the understanding + pathway attends causally over understanding tokens only, so its per-layer + K/V are independent of the diffusion timestep and are valid across all + denoise steps. 
+ """ + + _VALID_MODES = (None, "prefill", "gen") + + def __init__(self, num_layers: int) -> None: + self._num_layers = num_layers + self._und_k: list[torch.Tensor | None] = [None] * num_layers + self._und_v: list[torch.Tensor | None] = [None] * num_layers + self._mode: str | None = None + + def init(self, hidden_states: dict, device: torch.device) -> None: # noqa: D401 - see base + # Nothing to allocate up front; entries are filled by write_for_layer. + return None + + def set_mode(self, mode: str | None) -> None: + if mode not in self._VALID_MODES: + raise ValueError(f"Invalid ReasonerMemoryState mode {mode!r}; expected one of {self._VALID_MODES}.") + self._mode = mode + + def is_gen_only(self) -> bool: + return self._mode == "gen" + + def is_und_only(self) -> bool: + return self._mode == "prefill" + + @property + def is_initialized(self) -> bool: + return all(k is not None for k in self._und_k) + + def read_for_layer(self, layer_idx: int) -> ReasonerMemoryValue: + # Never raises: the decoder loop calls this on every layer, including + # during the prefill pass when the cache is still empty. During prefill + # the returned (None) K/V are ignored by the attention dispatch (the + # generation sequence is empty, so the und-only prefill path runs); only + # the generator-only path consumes the cached K/V (and asserts they exist). + return ReasonerMemoryValue(und_k=self._und_k[layer_idx], und_v=self._und_v[layer_idx]) + + def write_for_layer(self, layer_idx: int, kv_to_store: KVToStore) -> None: + # kv_to_store == (gen_k, gen_v, und_k, und_v); only the understanding + # K/V are persisted, and only during the prefill pass. On generator-only + # steps the understanding sequence is empty, so skip (keep the cache). 
+ if self._mode != "prefill": + return + _, _, und_k, und_v = kv_to_store + self._und_k[layer_idx] = und_k + self._und_v[layer_idx] = und_v diff --git a/docs/faq.md b/docs/faq.md index 4ddf47e..a1f5f4a 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -165,9 +165,10 @@ Try these in order: export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True ``` -2. **Increase `--dp-shard-size`** to shard model weights across more GPUs via FSDP. Inference auto-picks a value that fits the model at ~75% device memory (see `_get_dp_shard_size` in `cosmos_framework/inference/args.py`); passing a larger explicit value drops per-GPU memory at the cost of more all-gather traffic. Requires multi-GPU. -3. **Lower `--device-memory-utilization`** (default `0.75`). The auto-`dp_shard_size` formula is `ceil(model_memory / device_memory / utilization)`, so passing e.g. `--device-memory-utilization=0.5` forces auto-mode to pick a larger `dp_shard_size` and leaves more per-GPU headroom for activations / KV cache. Requires multi-GPU. -4. **Add `--offload-guardrail-models`** to move the text and video guardrail models to CPU. Frees the GPU memory they would otherwise hold for the full run, at the cost of some extra latency when guardrails are invoked. +2. **(Single-GPU) Add `--offload-stages reasoner generator vae`** to offload the transformer towers and VAE to pinned CPU memory, staging each back onto one reusable GPU arena only while in use. This is the biggest single-GPU lever (e.g. ~13 GiB lower peak for Cosmos3-Nano `text2video`). Incompatible with CUDA graphs. See [inference.md § CPU Offloading](./inference.md#cpu-offloading-single-gpu). +3. **Increase `--dp-shard-size`** to shard model weights across more GPUs via FSDP. Inference auto-picks a value that fits the model at ~75% device memory (see `_get_dp_shard_size` in `cosmos_framework/inference/args.py`); passing a larger explicit value drops per-GPU memory at the cost of more all-gather traffic. Requires multi-GPU. +4. 
**Lower `--device-memory-utilization`** (default `0.75`). The auto-`dp_shard_size` formula is `ceil(model_memory / device_memory / utilization)`, so passing e.g. `--device-memory-utilization=0.5` forces auto-mode to pick a larger `dp_shard_size` and leaves more per-GPU headroom for activations / KV cache. Requires multi-GPU. +5. **Add `--offload-guardrail-models`** to move the text and video guardrail models to CPU. Frees the GPU memory they would otherwise hold for the full run, at the cost of some extra latency when guardrails are invoked. See [inference.md#torch-cuda-out-of-memory-error](./inference.md#torch-cuda-out-of-memory-error) for the full troubleshooting section. diff --git a/docs/inference.md b/docs/inference.md index 1580147..687c064 100644 --- a/docs/inference.md +++ b/docs/inference.md @@ -16,6 +16,7 @@ ______________________________________________________________________ - [Models](#models) - [Modes](#modes) - [Parallelism Arguments](#parallelism-arguments) +- [CPU Offloading (single-GPU)](#cpu-offloading-single-gpu) - [Sample Arguments](#sample-arguments) - [Text](#text) - [Vision (Image/Video)](#vision-imagevideo) @@ -157,6 +158,39 @@ By default the model weights are sharded (FSDP) across **all** visible GPUs (`dp - `--dp-shard-size`: Number of ranks the model is sharded over (FSDP). Defaults to all ranks (`WORLD_SIZE`). - `--max-num-seqs`: Maximum number of samples batched together per replica. +## CPU Offloading (single-GPU) + +On a single GPU, the largest model components can be offloaded to pinned CPU memory and staged back onto **one** reusable GPU "arena" only while they are in use, cutting peak GPU memory at a small latency cost. This lets larger resolutions / longer videos (and smaller-VRAM GPUs) fit. 
Enable it with `--offload-stages`, naming the components to offload (space-separated): + +```shell +python -m cosmos_framework.scripts.inference \ + --parallelism-preset=latency \ + -i "inputs/omni/t2v.json" \ + -o outputs/omni_nano \ + --checkpoint-path Cosmos3-Nano \ + --offload-stages reasoner generator vae \ + --seed=0 +``` + +| Stage | Offloaded component | On GPU only during | +| ----------- | ------------------------------------------------------------ | ---------------------------- | +| `reasoner` | Understanding (reasoner) tower of the MoT | the reasoner prefill | +| `generator` | Generation (diffusion) tower + diffusion encode/decode heads | the denoise loop | +| `vae` | Vision tokenizer (Wan2.2 VAE) | conditioning encode + decode | + +(Guardrail offloading is controlled separately by [`--offload-guardrail-models`](#guardrails).) + +**How the reasoner/generator split works.** When `reasoner` and/or `generator` are offloaded, the understanding pathway is computed **once** as a prefill that caches the per-layer understanding K/V; the diffusion denoise loop then runs **generator-only**, reusing that cache, so the reasoner weights stay on CPU throughout denoising. The two towers (~half the transformer each) time-share one GPU arena sized to the larger tower, so peak transformer-weight residency is roughly halved. `vae` shares the same arena (staged around encode/decode). + +The split is numerically equivalent to the standard joint path — bit-identical without `torch.compile`, and within bf16 tolerance with it. Default off (no `--offload-stages`) leaves the joint path unchanged. + +**Measured** (Cosmos3-Nano, single GH200, `text2video` at 256p / 24 frames): peak GPU memory dropped from **34.0 GiB to 21.1 GiB** (≈13 GiB / ≈38% lower) with `--offload-stages reasoner generator`; adding `vae` lowers it further. 
+ +Constraints: + +- **Single-GPU only** — incompatible with multi-GPU sharding / context / CFG parallelism (`dp_shard_size * cp_size * cfgp_size` must be `1`). +- **Incompatible with CUDA graphs** (`--use-cuda-graphs`): staging rebinds weight tensors between calls, which breaks captured static addresses. + ## Sample Arguments Sample arguments are read from multiple sources (in priority order): @@ -248,6 +282,8 @@ Error: `torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate X MiB export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True ``` +On a single GPU, [CPU Offloading](#cpu-offloading-single-gpu) (`--offload-stages reasoner generator vae`) is the most effective lever — it offloads the transformer towers and VAE to CPU and stages them back only while in use. + If that's not enough, see [FAQ § OOM during inference](./faq.md#q-i-get-torchcudaoutofmemoryerror-during-inference) for the full ladder (`--dp-shard-size`, `--device-memory-utilization`, `--offload-guardrail-models`). ### NCCL Issue