diff --git a/cosmos_framework/inference/args.py b/cosmos_framework/inference/args.py index 44dcf76..5f27e7f 100644 --- a/cosmos_framework/inference/args.py +++ b/cosmos_framework/inference/args.py @@ -41,6 +41,27 @@ from cosmos_framework.inference.common.inference import Inference +def _load_transfer_prompt_path(path: str | Path) -> str: + """Load a transfer prompt from a ``.json`` or plain ``.txt`` file.""" + resolved = Path(path) + text = resolved.read_text() + if resolved.suffix.lower() == ".json": + return json.dumps(json.loads(text)) + return text.strip() + + +def _load_transfer_negative_prompt_file(path: str | Path) -> str: + """Load a JSON negative caption file for transfer inference.""" + candidate = Path(path) + if not candidate.is_file(): + defaults_path = PACKAGE_DIR / "defaults" / candidate.name + if defaults_path.is_file(): + candidate = defaults_path + else: + raise FileNotFoundError(f"Missing negative prompt file: {path} (also checked {defaults_path})") + return json.dumps(json.loads(candidate.read_text())) + + @cache def _load_modality_defaults(model_mode: str) -> dict[str, Any]: default_file = PACKAGE_DIR / f"defaults/{model_mode}/sample_args.json" @@ -59,12 +80,15 @@ def _load_modality_defaults(model_mode: str) -> dict[str, Any]: Guidance = Annotated[float, pydantic.Field(ge=0, le=7)] GuidanceInterval = tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat] PromptUpsamplerProbability = Annotated[float, pydantic.Field(ge=0, le=1)] +ControlGuidance = Annotated[float, pydantic.Field(ge=0, le=10)] class SamplingArgs(ArgsBase): num_steps: pydantic.PositiveInt guidance: Guidance guidance_interval: GuidanceInterval | None + control_guidance: ControlGuidance = 1.0 + control_guidance_interval: GuidanceInterval | None = None normalize_cfg: bool shift: float sigma_max: float @@ -79,6 +103,13 @@ class SamplingOverrides(OverridesBase): """Guidance scale for the diffusion model.""" guidance_interval: Training[GuidanceInterval | None] = None """Guidance 
class _TransferDataBase:
    # Mixin shared by the transfer args / overrides models: exposes the hints that are set.

    @property
    def transfer_hints(self) -> dict[TransferHintKey, TransferOverrides | TransferArgs]:
        """Return every hint config that is not ``None``, keyed by ``TransferHintKey``."""
        # Walk the enum rather than the JSON keys: the order is then deterministic and the
        # model sees a stable [ctrl_1, ..., ctrl_N] sequence.
        active: dict[TransferHintKey, TransferOverrides | TransferArgs] = {}
        for key in TransferHintKey:
            hint_config = getattr(self, key.value)
            if hint_config is not None:
                active[key] = hint_config
        return active


class TransferDataArgs(ArgsBase, _TransferDataBase):
    # One optional config per control hint, followed by the transfer-wide settings.
    edge: EdgeTransferArgs | None = None
    blur: BlurTransferArgs | None = None
    depth: TransferArgs | None = None
    seg: TransferArgs | None = None
    wsm: TransferArgs | None = None
    negative_prompt_file: str | None = None
    """JSON negative caption file for transfer specs (absolute path or filename under ``defaults/``)."""
    num_video_frames_per_chunk: pydantic.PositiveInt | None = None
    num_conditional_frames: pydantic.NonNegativeInt | None = None
    max_frames: pydantic.PositiveInt | None = None
    show_control_condition: bool | None = None
    show_input: bool | None = None
    num_first_chunk_conditional_frames: pydantic.NonNegativeInt | None = None
    share_vision_temporal_positions: bool | None = None


class TransferDataOverrides(OverridesBase, _TransferDataBase):
    """Transfer inference overrides — activated when at least one control hint is set."""

    edge: EdgeTransferOverrides | None = None
    blur: BlurTransferOverrides | None = None
    depth: TransferOverrides | None = None
    seg: TransferOverrides | None = None
    wsm: TransferOverrides | None = None
    negative_prompt_file: str | None = None
    """JSON negative caption file for transfer specs (absolute path or filename under ``defaults/``)."""
    num_video_frames_per_chunk: pydantic.PositiveInt | None = None
    num_conditional_frames: pydantic.NonNegativeInt | None = None
    max_frames: pydantic.PositiveInt | None = None
    show_control_condition: bool | None = None
    show_input: bool | None = None
    num_first_chunk_conditional_frames: pydantic.NonNegativeInt | None = None
    share_vision_temporal_positions: bool | None = None

    @pydantic.model_validator(mode="after")
    def _validate_transfer_hints(self) -> Self:
        """Reject specs that set a transfer-only field without any control hint."""
        if self.transfer_hints:
            return self

        hint_names = {key.value for key in TransferHintKey}
        declared_fields = type(self).model_fields
        for name in TransferDataOverrides.__annotations__:
            # Skip the hint fields themselves and annotations that are not model fields.
            if name in hint_names or name not in declared_fields:
                continue
            if getattr(self, name) is not None:
                raise ValueError(
                    f"transfer inference requires at least one control hint ({', '.join(k.value for k in TransferHintKey)})"
                )
        return self

    @override
    def download(self, output_dir: Path):
        """Run the base download, then the download of every active hint config."""
        super().download(output_dir)
        for hint_config in self.transfer_hints.values():
            assert isinstance(hint_config, TransferOverrides)
            hint_config.download(output_dir)

    _TRANSFER_SAMPLE_DEFAULTS: ClassVar[dict[str, Any]] = {
        "num_video_frames_per_chunk": 93,
        "num_conditional_frames": 1,
        "max_frames": 5000,
        "show_control_condition": False,
        "show_input": False,
        "num_first_chunk_conditional_frames": 0,
        "share_vision_temporal_positions": True,
    }
    _TRANSFER_HINT_DEFAULTS: ClassVar[dict[TransferHintKey, dict[str, Any]]] = {
        TransferHintKey.EDGE: {"preset_edge_threshold": PresetEdgeThreshold.MEDIUM},
        TransferHintKey.BLUR: {"preset_blur_strength": PresetBlurStrength.MEDIUM},
    }
    # Tuned guidance / control_guidance per transfer task. Applied when the input JSON omits
    # these fields (generic video2video sampling defaults otherwise apply).
    _TRANSFER_DEFAULTS: ClassVar[dict[TransferHintKey, dict[str, Any]]] = {
        TransferHintKey.EDGE: {"guidance": 3.0, "control_guidance": 1.5, "shift": 10.0},
        TransferHintKey.BLUR: {"guidance": 3.0, "control_guidance": 1.5, "shift": 10.0},
        TransferHintKey.DEPTH: {"guidance": 3.0, "control_guidance": 1.5, "shift": 10.0},
        TransferHintKey.SEG: {"guidance": 3.0, "control_guidance": 2.0, "shift": 10.0},
        TransferHintKey.WSM: {
            "guidance": 1.0,
            "control_guidance": 3.0,
            "shift": 10.0,
            "num_frames": 101,
            "fps": 10,
            "num_video_frames_per_chunk": 101,
        },
    }

    def _build_transfer_data(
        self,
        model_config: "OmniMoTModelConfig",
        sample_meta: SampleMeta,
        *,
        user_fields: frozenset[str] | None = None,
    ):
        """Fill in transfer defaults in place; a no-op when no control hint is set.

        ``user_fields`` names the fields the caller set explicitly; those are never
        replaced by the per-task tuned defaults. With ``None`` every tuned default is applied.
        """
        self = cast("SampleDataOverrides", self)
        hints = self.transfer_hints
        if not hints:
            return

        # An explicit negative prompt wins over the negative prompt file.
        if self.negative_prompt is None and self.negative_prompt_file is not None:
            self.negative_prompt = _load_transfer_negative_prompt_file(self.negative_prompt_file)

        for name, fallback in self._TRANSFER_SAMPLE_DEFAULTS.items():
            if getattr(self, name) is None:
                setattr(self, name, fallback)

        for hint_key, hint_config in hints.items():
            hint_defaults = self._TRANSFER_HINT_DEFAULTS.get(hint_key, {})
            for name, fallback in hint_defaults.items():
                if getattr(hint_config, name) is None:
                    setattr(hint_config, name, fallback)

            if self.vision_path is None and hint_config.control_path is None:
                raise ValueError(
                    f"transfer inference requires 'vision_path' or a pre-computed 'control_path' (hint: {hint_key})"
                )

        # The tuned per-task values are only applied for single-hint specs.
        if len(hints) != 1:
            return
        (only_key,) = hints
        for name, tuned in self._TRANSFER_DEFAULTS[only_key].items():
            if user_fields is not None and name in user_fields:
                continue
            setattr(self, name, tuned)
ActionDataOverrides, ReasonerDataOverrides, + TransferDataOverrides, ): """Sample data arguments for 'OmniMoTModel.generate_samples'.""" @@ -790,7 +959,8 @@ def build_sample(self, *, model_config: Any) -> OmniSampleArgs: else: defaults = _load_modality_defaults(sample_meta.model_mode) overrides = self.model_dump(exclude_none=True) - shift_configured = "shift" in overrides or defaults.get("shift") is not None + user_fields = frozenset(overrides) + shift_configured = "shift" in user_fields or defaults.get("shift") is not None merged_data = _deep_merge(defaults, overrides) merged_data = {k: v for k, v in merged_data.items() if k in type(self).model_fields} merged = type(self).model_validate(merged_data) @@ -811,6 +981,10 @@ def build_sample(self, *, model_config: Any) -> OmniSampleArgs: self._build_reasoner_data(model_config=model_config, sample_meta=sample_meta) + self._build_transfer_data( + model_config=model_config, sample_meta=sample_meta, user_fields=user_fields + ) + if not shift_configured and not sample_meta.model_mode.is_reasoner: model_size = self._VLM_MODEL_SIZE[model_config.vlm_config.model_name] key = (model_size, self.resolution) @@ -880,8 +1054,6 @@ def build_sample(self, *, model_config: Any) -> OmniSampleArgs: repository="nvidia/Cosmos3-Super-Image2Video", revision="main", ), - # Self-contained checkpoint: use its bundled processor instead of - # downloading the base Cosmos3-Super repo just for the tokenizer. vlm_processor_from_checkpoint=True, ), "Cosmos3-Super-Text2Image": CheckpointConfig( @@ -892,8 +1064,6 @@ def build_sample(self, *, model_config: Any) -> OmniSampleArgs: repository="nvidia/Cosmos3-Super-Text2Image", revision="main", ), - # Self-contained checkpoint: use its bundled processor instead of - # downloading the base Cosmos3-Super repo just for the tokenizer. 
vlm_processor_from_checkpoint=True, ), } diff --git a/cosmos_framework/inference/inference.py b/cosmos_framework/inference/inference.py index 1190889..9acb4b3 100644 --- a/cosmos_framework/inference/inference.py +++ b/cosmos_framework/inference/inference.py @@ -507,6 +507,9 @@ def get_sample_data( if sample_args.model_mode.is_reasoner: return _get_reasoner_sample_data(sample_args, model) + if sample_args.transfer_hints: + return {} + if sample_args.model_mode.is_action: from cosmos_framework.inference.action import get_action_sample_data @@ -1051,9 +1054,6 @@ def _create(cls, setup_args: SetupArgs, **kwargs: Any) -> Self: else: model_dict = setup_args.load_model_config_dict() if setup_args.vlm_processor_from_checkpoint: - # Source the VLM processor from the loaded checkpoint's own - # bundled files instead of the repository hardcoded in the - # model config. Drops the redundant base-model download. tokenizer_cfg = model_dict["config"]["vlm_config"]["tokenizer"] tokenizer_cfg.pop("repository", None) tokenizer_cfg.pop("revision", None) @@ -1244,13 +1244,6 @@ def create_batches( # --- Phase 4: pad with dummy batches so every replica calls # generate_batch the same number of times (prevents collective # deadlocks in context-parallel / CFG-parallel communication). - # Minimal-cost padding sample: the dummy batch only exists to keep the - # generate_batch call count aligned across replicas, and its output is - # discarded (output_dir=None). Force num_steps=1 / guidance=1.0 so it never - # raises the per-iteration align_num_steps MAX (which would make the dummy - # *and* real samples on peer ranks pad up). The per-step alignment still - # pads this dummy up to MAX(real samples), so collective alignment holds; - # we just stop inflating that MAX with the (arbitrary) global sample[0]. 
dummy_sa = sample_args_list[0].model_copy( update={"output_dir": None, "name": "padding", "num_steps": 1, "guidance": 1.0} ) @@ -1267,6 +1260,12 @@ def generate_batch( ) -> list[SampleOutputs]: assert all(isinstance(sa, OmniSampleArgs) for sa in sample_args_list) + transfer_flags = [bool(sa.transfer_hints) for sa in sample_args_list] + if any(transfer_flags): + assert all(transfer_flags), "Cannot mix transfer and non-transfer samples in a batch" + assert len(sample_args_list) == 1, "Batching is not supported for transfer inference" + return self._generate_transfer_batch(sample_args_list[0], warmup=warmup) + reasoner_flags = [cast(OmniSampleArgs, sa).model_mode.is_reasoner for sa in sample_args_list] if any(reasoner_flags): assert all(reasoner_flags), "Cannot mix reasoner and non-reasoner samples in a batch" @@ -1391,66 +1390,16 @@ def decode_vision(vision_latent: torch.Tensor) -> torch.Tensor: ) upsample_task = next(iter(distinct_upsample_tasks)) - # FSDP collective-sequence alignment (throughput-style inference where - # ranks hold different samples). Each per-step model forward issues a - # param all-gather over the FSDP-shard (dp_shard) group, so if dp_shard - # peers disagree on ``num_steps`` that group's collective stream - # desyncs and deadlocks NCCL at the watchdog timeout (observed: rank0 - # wedged at step 31/50 the instant its dp_shard peer finished 35). - # - # all_reduce(MAX) the local num_steps over the *dp_shard group* and - # pass it as ``align_num_steps``; ranks below the max pad with - # discarded dummy steps in generate_samples_from_batch. Scope = the - # dp_shard group (not world), because that keeps the reduction within a - # single modality: modality must already be homogeneous within any - # per-forward collective group (else the forward itself desyncs), and - # reasoner-only batches take an early return below and never reach this - # collective — a world reduction would deadlock against them. 
The - # per-step CP / CFGP collectives are also covered: cp/cfgp groups - # always sit inside one data-parallel replica (replica_id = - # rank // (cp*cfgp)), so when dp_shard and the replica block (cp*cfgp) - # nest, every cp/cfgp peer lands in a dp_shard group with the same MAX. - # The nesting precondition is asserted just below. - local_num_steps = _getattr(sample_args_list, "num_steps") - align_num_steps = local_num_steps - parallel_dims = getattr(self.model, "parallel_dims", None) - if ( - parallel_dims is not None - and parallel_dims.dp_shard_mesh is not None - and torch.distributed.is_initialized() - and parallel_dims.dp_shard_mesh.size() > 1 - ): - # Non-nesting CP/CFGP overlays (neither dp_shard nor the cp*cfgp - # replica block divides the other) let a cp/cfgp group straddle two - # dp_shard groups with different maxima, which a dp_shard-scoped - # reduction cannot align. Both presets nest (throughput: cp=cfgp=1; - # latency: single replica), so this only guards hand-built layouts. - replica_block = parallel_dims.cp * parallel_dims.cfgp - dp_shard_sz = parallel_dims.dp_shard - if replica_block > 1 and dp_shard_sz % replica_block != 0 and replica_block % dp_shard_sz != 0: - raise NotImplementedError( - "num_steps collective alignment requires dp_shard " - f"({dp_shard_sz}) and cp*cfgp ({replica_block}) to nest " - "(one must divide the other). Non-nesting CP/CFGP overlays " - "with divergent per-sample num_steps are unsupported." 
@torch.no_grad()
def _generate_transfer_batch(self, sample_args: OmniSampleArgs, *, warmup: bool = False) -> list[SampleOutputs]:
    """Handle transfer inference using the autoregressive generate_transfer_sample path.

    Args:
        sample_args: The single transfer sample to generate.
        warmup: When True, generation still runs but nothing is written and an empty
            list is returned.

    Returns:
        One ``SampleOutputs`` entry when ``self.should_process_sample(sample_args)`` is
        true (either the saved outputs or the result of ``_handle_sample_exception``),
        otherwise an empty list.
    """
    from cosmos_framework.inference.transfer import generate_transfer_sample

    # Phase 1: persist the resolved sample args next to the outputs (skipped during warmup
    # and when this process should not handle the sample).
    try:
        with sync_distributed_errors():
            if self.should_process_sample(sample_args) and not warmup:
                log.debug(f"{sample_args.__class__.__name__}({sample_args})")
                assert sample_args.output_dir is not None
                sample_args.output_dir.mkdir(parents=True, exist_ok=True)
                sample_args_file = sample_args.output_dir / "sample_args.json"
                sample_args_file.write_text(sample_args.model_dump_json())
                log.info(f"Saved sample args to '{sample_args_file}'", rank0_only=False)
    except Exception as e:
        # Generation is skipped entirely when phase 1 raised.
        if self.should_process_sample(sample_args) and not warmup:
            return [self._handle_sample_exception(sample_args, e)]
        return []

    # Phase 2: generation is called unconditionally here (not gated on
    # should_process_sample) and sits outside both try blocks, so its exceptions propagate.
    transfer_output = generate_transfer_sample(sample_args=sample_args, model=self.model)

    if warmup:
        return []

    # Phase 3: write the generated clip, the control videos and the JSON summary.
    sample_outputs: list[SampleOutputs] = []
    try:
        with sync_distributed_errors():
            if self.should_process_sample(sample_args):
                assert sample_args.output_dir is not None
                content: dict[str, Any] = {}
                files: list[Path] = []

                # Affine map x -> (x + 1) / 2, clamped to [0, 1].
                # NOTE(review): presumably output_video is in [-1, 1] with a leading batch
                # dim of size 1 — confirm against generate_transfer_sample.
                vision_cthw = ((1.0 + transfer_output.output_video.squeeze(0)) / 2).clamp(0, 1)

                # A size-1 second axis selects the image quality setting.
                # NOTE(review): the name suggests (C, T, H, W), i.e. axis 1 is time — confirm.
                if vision_cthw.shape[1] == 1:
                    quality = sample_args.image_save_quality
                else:
                    quality = sample_args.video_save_quality
                vision_file = sample_args.output_dir / f"vision{sample_args.vision_extension}"
                output_fps = transfer_output.fps
                # The suffix-less path is passed and the suffixed file is asserted afterwards;
                # presumably save_img_or_video appends the extension — confirm.
                save_img_or_video(vision_cthw, str(vision_file.with_suffix("")), fps=output_fps, quality=quality)
                assert vision_file.is_file(), vision_file
                files.append(vision_file)

                # Control videos reuse the fps and quality chosen for the main output.
                for hint_key, control_video in transfer_output.control_videos.items():
                    control_cthw = ((1.0 + control_video.squeeze(0)) / 2).clamp(0, 1)
                    control_file = sample_args.output_dir / f"control_{hint_key}{sample_args.vision_extension}"
                    save_img_or_video(
                        control_cthw, str(control_file.with_suffix("")), fps=output_fps, quality=quality
                    )
                    files.append(control_file)
                    log.info(f"Saved control video to '{control_file}'", rank0_only=False)

                # `content` is left empty on this path; only file paths are reported.
                sample_output = SampleOutputs(
                    args=sample_args.model_dump(mode="json"),
                    outputs=[SampleOutput(content=content, files=files)],
                )
                sample_outputs_file = sample_args.output_dir / "sample_outputs.json"
                sample_outputs_file.write_text(sample_output.model_dump_json())
                log.success(f"Saved transfer outputs to '{sample_outputs_file}'", rank0_only=False)

                sample_outputs.append(sample_output)

    except Exception as e:
        return [self._handle_sample_exception(sample_args, e)] if self.should_process_sample(sample_args) else []

    return sample_outputs
def _build_no_control_inference_state(
    self,
    sequence_plans: list[SequencePlan],
    gen_data_clean: GenerationDataClean,
) -> tuple[list[SequencePlan], GenerationDataClean, list[int]] | None:
    """Build inference state without control-map vision (for control-CFG).

    Transfer packs [control_map(s), target_clip] per sample. The no-control branch
    drops the control maps from the vision sequence; the text caption and target
    clip remain.

    Args:
        sequence_plans: Per-sample plans; each is copied with
            ``share_vision_temporal_positions`` forced to ``False``.
        gen_data_clean: Clean generation data whose vision lists are flattened over
            samples, ``num_vision_items_per_sample[i]`` consecutive items per sample.
            The last item of each sample is kept as the target clip.

    Returns:
        ``None`` when ``num_vision_items_per_sample`` is ``None`` or every sample has at
        most one vision item. Otherwise ``(plans, data, ctrl_dims_per_sample)`` where
        ``ctrl_dims_per_sample[i]`` is the total element count of sample ``i``'s control
        items, used to slice ``noise_x`` and blend velocities on the target suffix.
    """
    num_items_per_sample = gen_data_clean.num_vision_items_per_sample
    if num_items_per_sample is None or all(n <= 1 for n in num_items_per_sample):
        return None

    assert gen_data_clean.x0_tokens_vision is not None
    x0_tokens_vision = gen_data_clean.x0_tokens_vision
    raw_state_vision = gen_data_clean.raw_state_vision

    new_x0_tokens_vision: list[torch.Tensor] = []
    new_raw_state_vision: list[torch.Tensor] | None = [] if raw_state_vision is not None else None
    ctrl_dims_per_sample: list[int] = []
    vis_offset = 0
    # NOTE(review): a sample with zero vision items would make ``tgt_idx`` point at the
    # previous sample's last item — presumably callers never mix such samples in; confirm.
    for n_vis in num_items_per_sample:
        # Every item except the last one is a control map. ``numel()`` is the product of
        # the shape, obtained without building a temporary tensor.
        ctrl_dims_per_sample.append(sum(x0_tokens_vision[vis_offset + j].numel() for j in range(n_vis - 1)))
        tgt_idx = vis_offset + n_vis - 1
        new_x0_tokens_vision.append(x0_tokens_vision[tgt_idx])
        if new_raw_state_vision is not None:
            new_raw_state_vision.append(raw_state_vision[tgt_idx])  # type: ignore[index]
        vis_offset += n_vis

    # Only the vision lists change; ``num_vision_items_per_sample`` is reset to None
    # because each sample now carries exactly one vision item.
    gdc_nc = GenerationDataClean(
        batch_size=gen_data_clean.batch_size,
        is_image_batch=gen_data_clean.is_image_batch,
        raw_state_vision=new_raw_state_vision,
        x0_tokens_vision=new_x0_tokens_vision,
        fps_vision=gen_data_clean.fps_vision,
        num_vision_items_per_sample=None,
        raw_state_action=gen_data_clean.raw_state_action,
        x0_tokens_action=gen_data_clean.x0_tokens_action,
        action_domain_id=gen_data_clean.action_domain_id,
        fps_action=gen_data_clean.fps_action,
        raw_action_dim=gen_data_clean.raw_action_dim,
        raw_state_sound=gen_data_clean.raw_state_sound,
        x0_tokens_sound=gen_data_clean.x0_tokens_sound,
        fps_sound=gen_data_clean.fps_sound,
    )

    sp_nc = [
        SequencePlan(
            has_text=sp.has_text,
            has_vision=sp.has_vision,
            condition_frame_indexes_vision=sp.condition_frame_indexes_vision,
            share_vision_temporal_positions=False,
            has_action=sp.has_action,
            condition_frame_indexes_action=sp.condition_frame_indexes_action,
            action_start_frame_offset=sp.action_start_frame_offset,
            has_sound=sp.has_sound,
            condition_frame_indexes_sound=sp.condition_frame_indexes_sound,
        )
        for sp in sequence_plans
    ]

    return sp_nc, gdc_nc, ctrl_dims_per_sample
+ ) + # FSDP collective-sequence alignment (throughput-style inference). Each # FSDP-shard rank holds a different sample, and ``velocity_fn`` issues 1 # model forward when this rank skips CFG (guidance == 1.0, or a timestep @@ -2327,43 +2413,94 @@ def _single_velocity_fn(tokens: list[list[int]], skip_text_tokens: bool): ) # Local CFG decision for THIS rank, honoring guidance_interval. - _local_needs_cfg = guidance != 1.0 - if _local_needs_cfg and guidance_interval is not None: + _local_needs_text_cfg = guidance != 1.0 + if _local_needs_text_cfg and guidance_interval is not None: assert len(guidance_interval) == 2, f"guidance_interval must be [lo, hi], got {guidance_interval}" t_lo, t_hi = guidance_interval - _local_needs_cfg = t_lo < timestep[0].item() < t_hi + _local_needs_text_cfg = t_lo < timestep[0].item() < t_hi + + _local_needs_control_cfg = no_control_state is not None + if _local_needs_control_cfg and control_guidance_interval is not None: + assert len(control_guidance_interval) == 2, ( + f"control_guidance_interval must be [lo, hi], got {control_guidance_interval}" + ) + t_lo_c, t_hi_c = control_guidance_interval + _local_needs_control_cfg = t_lo_c < timestep[0].item() < t_hi_c - # FSDP alignment: if ANY rank in the shard group needs CFG this call, - # every rank computes both forwards (cheap 1-element all_reduce per + # FSDP alignment: if ANY rank in the shard group needs CFG or control-CFG this call, + # every rank computes the matching forwards (cheap 1-element all_reduce per # velocity_fn call). Forcing CFG always-on globally would instead # silently ignore the per-timestep guidance_interval gate. 
if _dp_shard_group is not None: - _cfg_t = torch.tensor([1 if _local_needs_cfg else 0], device=_align_device, dtype=torch.int32) + _cfg_t = torch.tensor( + [1 if _local_needs_text_cfg else 0], device=_align_device, dtype=torch.int32 + ) torch.distributed.all_reduce(_cfg_t, op=torch.distributed.ReduceOp.MAX, group=_dp_shard_group) - _any_needs_cfg = bool(_cfg_t.item()) + _any_needs_text_cfg = bool(_cfg_t.item()) + _ctrl_t = torch.tensor( + [1 if _local_needs_control_cfg else 0], device=_align_device, dtype=torch.int32 + ) + torch.distributed.all_reduce(_ctrl_t, op=torch.distributed.ReduceOp.MAX, group=_dp_shard_group) + _any_needs_control_cfg = bool(_ctrl_t.item()) else: - _any_needs_cfg = _local_needs_cfg + _any_needs_text_cfg = _local_needs_text_cfg + _any_needs_control_cfg = _local_needs_control_cfg - if not _any_needs_cfg: + if not _any_needs_text_cfg and not _any_needs_control_cfg: return _single_velocity_fn(cond_tokens, skip_text_tokens=False) - # Both forwards happen — needed for FSDP collective alignment - # across ranks even if THIS rank's local decision was "no CFG". 
- cond_v, uncond_v = self._run_classifier_free_guidance( - cond_tokens=cond_tokens, - uncond_tokens=uncond_tokens, - skip_text_tokens_for_cfg=skip_text_tokens_for_cfg, - single_velocity_fn=_single_velocity_fn, - ) + if _any_needs_control_cfg: + cond_v_full = _single_velocity_fn(cond_tokens, skip_text_tokens=False) + sp_nc, gdc_nc, ctrl_dims = no_control_state # type: ignore[misc] + noise_x_nc = [nx[ctrl_dim:] for nx, ctrl_dim in zip(noise_x, ctrl_dims)] + cond_v_nc = self._get_velocity( + net=net, + noise_x=noise_x_nc, + timestep=timestep, + text_tokens=cond_tokens, + sequence_plans=sp_nc, + gen_data_clean=gdc_nc, + skip_text_tokens=False, + ) + if _local_needs_control_cfg: + cond_v = [] + for v_full_i, v_nc_i, ctrl_dim_i in zip(cond_v_full, cond_v_nc, ctrl_dims): + suffix_full = v_full_i[ctrl_dim_i:] + assert suffix_full.shape == v_nc_i.shape, ( + f"shape mismatch in control-CFG mix: full suffix {suffix_full.shape} " + f"vs no-control {v_nc_i.shape}" + ) + mixed_suffix = v_nc_i + control_guidance * (suffix_full - v_nc_i) + cond_v.append(torch.cat([v_full_i[:ctrl_dim_i], mixed_suffix], dim=0)) + else: + cond_v = cond_v_full + + if not _any_needs_text_cfg: + return cond_v + + uncond_v = _single_velocity_fn(uncond_tokens, skip_text_tokens=skip_text_tokens_for_cfg) + if not _local_needs_text_cfg: + return cond_v + + v_pred = [u_i + guidance * (c_i - u_i) for c_i, u_i in zip(cond_v, uncond_v)] + else: + # Both forwards happen — needed for FSDP collective alignment + # across ranks even if THIS rank's local decision was "no CFG". + cond_v, uncond_v = self._run_classifier_free_guidance( + cond_tokens=cond_tokens, + uncond_tokens=uncond_tokens, + skip_text_tokens_for_cfg=skip_text_tokens_for_cfg, + single_velocity_fn=_single_velocity_fn, + ) - if not _local_needs_cfg: - # This rank didn't actually need CFG (guidance==1.0, or sigma - # outside guidance_interval). 
Return cond_v directly so the output - # is bit-identical to the no-CFG path; the uncond_v forward ran - # only to keep the FSDP all-gather sequence aligned with peers. - return cond_v + if not _local_needs_text_cfg: + # This rank didn't actually need CFG (guidance==1.0, or sigma + # outside guidance_interval). Return cond_v directly so the output + # is bit-identical to the no-CFG path; the uncond_v forward ran + # only to keep the FSDP all-gather sequence aligned with peers. + return cond_v - v_pred = [u_i + guidance * (c_i - u_i) for c_i, u_i in zip(cond_v, uncond_v)] + v_pred = [u_i + guidance * (c_i - u_i) for c_i, u_i in zip(cond_v, uncond_v)] if normalize_cfg: v_pred = [