From 5bd1bceccf7f32afddec2463b4428905c5aa8afa Mon Sep 17 00:00:00 2001 From: black-yt <2667685528@qq.com> Date: Tue, 17 Mar 2026 10:32:54 +0800 Subject: [PATCH] feat: add SiT/DiT class-conditional image generation support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add SiTDiTAdapter supporting SiT and DiT model variants (model_type: "sit" or "dit", both use the same adapter) - Class conditioning via integer label strings ("0"–"999" for ImageNet) - Supports standard SD-VAE and ReaLS VAE latent normalization - Timestep conversion: t_sit = 1 - t_ff/1000 - Sign convention: noise_pred = -v_sit for FlowMatchEulerDiscreteSDEScheduler - CFG via null label index (= num_classes) - Add example config: examples/grpo/lora/sit_xl2.yaml - Update README: supported models table + dataset format docs --- README.md | 21 + examples/grpo/lora/sit_xl2.yaml | 107 ++++ src/flow_factory/models/registry.py | 2 + src/flow_factory/models/sit_dit/__init__.py | 3 + src/flow_factory/models/sit_dit/models.py | 299 ++++++++++ src/flow_factory/models/sit_dit/sit_dit.py | 591 ++++++++++++++++++++ 6 files changed, 1023 insertions(+) create mode 100644 examples/grpo/lora/sit_xl2.yaml create mode 100644 src/flow_factory/models/sit_dit/__init__.py create mode 100644 src/flow_factory/models/sit_dit/models.py create mode 100644 src/flow_factory/models/sit_dit/sit_dit.py diff --git a/README.md b/README.md index cfc47cb3..f0668b1b 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,11 @@ This experimental feature leverages `diffusers`'s `transformer.set_attention_bac + + + + + @@ -158,6 +163,22 @@ The unified structure of dataset is: For text-to-image and text-to-video tasks, the only required input is the **prompt** in plain text format. 
Use `train.txt` and `test.txt` (optional) with following format: +## Class-to-Image (SiT / DiT) + +For class-conditional generation with SiT or DiT, each prompt is an **integer class index string** (e.g., ImageNet class label). Use `train.jsonl` with the following format: + +```jsonl +{"prompt": "985"} +{"prompt": "207"} +{"prompt": "388"} +``` + +> Class indices follow the standard ImageNet-1K ordering (0–999). Example: 985 = daisy, 207 = golden retriever, 388 = giant panda. + +The `model_name_or_path` directory must contain a `config.json` describing the model variant and VAE path. See [`examples/grpo/lora/sit_xl2.yaml`](examples/grpo/lora/sit_xl2.yaml) for a complete configuration example. + +## Text-to-Image & Text-to-Video (original) + ``` A hill in a sunset. An astronaut riding a horse on Mars. diff --git a/examples/grpo/lora/sit_xl2.yaml b/examples/grpo/lora/sit_xl2.yaml new file mode 100644 index 00000000..c6bfaed2 --- /dev/null +++ b/examples/grpo/lora/sit_xl2.yaml @@ -0,0 +1,107 @@ +# Environment Configuration +launcher: "accelerate" +config_file: config/deepspeed/deepspeed_zero2.yaml +num_processes: 8 +main_process_port: 29500 +mixed_precision: "bf16" + +run_name: null +project: "Flow-Factory" +logging_backend: "wandb" + +# Data Configuration +# Dataset items should have "prompt" field containing the integer class label string, +# e.g. {"prompt": "985"} for ImageNet class 985 (daisy). 
+data: + dataset_dir: "dataset/imagenet_cls" + preprocessing_batch_size: 8 + dataloader_num_workers: 16 + force_reprocess: true + cache_dir: "~/.cache/flow_factory/datasets" + max_dataset_size: 1000 + +# Model Configuration +model: + finetune_type: 'lora' + lora_rank: 64 + lora_alpha: 128 + target_modules: "default" + # model_name_or_path must be a local directory containing: + # config.json — model config (see SiTDiTAdapter docstring) + # model.safetensors or pytorch_model.bin — transformer weights + model_name_or_path: "/path/to/sit_xl2_imagenet" + model_type: "sit" # or "dit" (both use the same SiTDiTAdapter) + target_components: ["transformer"] + resume_path: null + resume_type: null + +log: + save_dir: "~/Flow-Factory" + save_freq: 20 + save_model_only: true + +# Training Configuration +train: + trainer_type: 'grpo' + advantage_aggregation: 'gdpo' + clip_range: 1.0e-4 + adv_clip_range: 5.0 + kl_type: 'v-based' + kl_beta: 0 + ref_param_device: 'cuda' + + resolution: 256 # SiT-XL/2 default resolution (256x256 images) + num_inference_steps: 20 + guidance_scale: 4.0 # class-conditional CFG scale (typical: 1.5–5.0) + + per_device_batch_size: 2 + group_size: 16 + global_std: false + unique_sample_num_per_epoch: 48 + gradient_step_per_epoch: 2 + + learning_rate: 3.0e-4 + adam_weight_decay: 1.0e-4 + adam_betas: [0.9, 0.999] + adam_epsilon: 1.0e-8 + max_grad_norm: 1.0 + + ema_decay: 0.9 + ema_update_interval: 4 + ema_device: "cuda" + + enable_gradient_checkpointing: false + seed: 42 + +# Scheduler Configuration +scheduler: + dynamics_type: "Flow-SDE" + noise_level: 0.7 + num_sde_steps: 1 + sde_steps: [1, 2, 3] + seed: 42 + +# Evaluation settings +eval: + resolution: 256 + per_device_batch_size: 1 + guidance_scale: 4.0 + num_inference_steps: 50 + eval_freq: 20 + seed: 42 + +# Reward Model Configuration +rewards: + - name: "pickscore_rank" + reward_model: "PickScore_Rank" + weight: 1.0 + batch_size: 16 + device: "cuda" + dtype: bfloat16 + +eval_rewards: + name: 
"pickscore" + reward_model: "PickScore" + batch_size: 16 + device: "cuda" + dtype: bfloat16 diff --git a/src/flow_factory/models/registry.py b/src/flow_factory/models/registry.py index 339499a5..0f78ade9 100644 --- a/src/flow_factory/models/registry.py +++ b/src/flow_factory/models/registry.py @@ -38,6 +38,8 @@ 'wan2_i2v': 'flow_factory.models.wan.wan2_i2v.Wan2_I2V_Adapter', 'wan2_t2v': 'flow_factory.models.wan.wan2_t2v.Wan2_T2V_Adapter', 'wan2_v2v': 'flow_factory.models.wan.wan2_v2v.Wan2_V2V_Adapter', + 'sit': 'flow_factory.models.sit_dit.sit_dit.SiTDiTAdapter', + 'dit': 'flow_factory.models.sit_dit.sit_dit.SiTDiTAdapter', } def get_model_adapter_class(identifier: str) -> Type: diff --git a/src/flow_factory/models/sit_dit/__init__.py b/src/flow_factory/models/sit_dit/__init__.py new file mode 100644 index 00000000..e86c0cdd --- /dev/null +++ b/src/flow_factory/models/sit_dit/__init__.py @@ -0,0 +1,3 @@ +from .sit_dit import SiTDiTAdapter, SiTDiTSample + +__all__ = ["SiTDiTAdapter", "SiTDiTSample"] diff --git a/src/flow_factory/models/sit_dit/models.py b/src/flow_factory/models/sit_dit/models.py new file mode 100644 index 00000000..841e0fea --- /dev/null +++ b/src/flow_factory/models/sit_dit/models.py @@ -0,0 +1,299 @@ +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# -------------------------------------------------------- +# References: +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# SiT: https://github.com/willisma/SiT +# -------------------------------------------------------- + +import math +from typing import Dict, List, Optional + +import numpy as np +import torch +import torch.nn as nn +from timm.models.vision_transformer import PatchEmbed, Attention, Mlp + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +# --------------------------------------------------------------------------- +# Embedding layers +# --------------------------------------------------------------------------- + +class TimestepEmbedder(nn.Module): + """Embeds scalar timesteps into vector representations.""" + + def __init__(self, hidden_size: int, frequency_embedding_size: int = 256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t: torch.Tensor) -> torch.Tensor: + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + return self.mlp(t_freq) + + +class LabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. + Handles label dropout for classifier-free guidance. 
+ The null label index is num_classes (the extra embedding row). + """ + + def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels: torch.Tensor, force_drop_ids=None) -> torch.Tensor: + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels: torch.Tensor, train: bool, force_drop_ids=None) -> torch.Tensor: + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + return self.embedding_table(labels) + + +# --------------------------------------------------------------------------- +# SiT / DiT blocks +# --------------------------------------------------------------------------- + +class SiTBlock(nn.Module): + """SiT block with adaptive layer norm zero (adaLN-Zero) conditioning.""" + + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 6 * hidden_size, bias=True), + ) + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: 
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \ + self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FinalLayer(nn.Module): + """The final layer of SiT.""" + + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True), + ) + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + return self.linear(x) + + +# --------------------------------------------------------------------------- +# SiT / DiT model +# --------------------------------------------------------------------------- + +class SiT(nn.Module): + """ + Scalable Interpolant Transformer (SiT) / Diffusion Transformer (DiT). + + Both architectures share this backbone; the difference lies in the + training objective (flow-matching velocity for SiT, noise prediction + for DiT) which is handled outside the model. 
+ """ + + def __init__( + self, + input_size: int = 32, + patch_size: int = 2, + in_channels: int = 4, + hidden_size: int = 1152, + depth: int = 28, + num_heads: int = 16, + mlp_ratio: float = 4.0, + class_dropout_prob: float = 0.1, + num_classes: int = 1000, + learn_sigma: bool = True, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + num_patches = self.x_embedder.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + + self.blocks = nn.ModuleList([ + SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) + ]) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], + int(self.x_embedder.num_patches ** 0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + 
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x: torch.Tensor) -> torch.Tensor: + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + x = x.reshape(x.shape[0], h, w, p, p, c) + x = torch.einsum('nhwpqc->nchpwq', x) + return x.reshape(x.shape[0], c, h * p, h * p) + + def forward(self, x: torch.Tensor, t: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Args: + x: (B, C, H, W) noisy latents + t: (B,) timesteps in [0, 1] + y: (B,) integer class labels + Returns: + (B, C, H, W) predicted velocity (SiT) or noise (DiT) + """ + x = self.x_embedder(x) + self.pos_embed + t = self.t_embedder(t) + y = self.y_embedder(y, self.training) + c = t + y + for block in self.blocks: + x = block(x, c) + x = self.final_layer(x, c) + x = self.unpatchify(x) + if self.learn_sigma: + x, _ = x.chunk(2, dim=1) + return x + + +# --------------------------------------------------------------------------- +# Positional embedding helpers +# --------------------------------------------------------------------------- + +def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, + cls_token: bool = False, extra_tokens: int = 0) -> np.ndarray: + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray) -> np.ndarray: + assert embed_dim % 2 == 0 + emb_h = 
get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) + return np.concatenate([emb_h, emb_w], axis=1) + + +def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray) -> np.ndarray: + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000 ** omega + pos = pos.reshape(-1) + out = np.einsum('m,d->md', pos, omega) + emb = np.concatenate([np.sin(out), np.cos(out)], axis=1) + return emb + + +# --------------------------------------------------------------------------- +# Model configs +# --------------------------------------------------------------------------- + +def SiT_XL_2(**kw): return SiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kw) +def SiT_XL_4(**kw): return SiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kw) +def SiT_XL_8(**kw): return SiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kw) +def SiT_L_2(**kw): return SiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kw) +def SiT_L_4(**kw): return SiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kw) +def SiT_L_8(**kw): return SiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kw) +def SiT_B_2(**kw): return SiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kw) +def SiT_B_4(**kw): return SiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kw) +def SiT_B_8(**kw): return SiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kw) +def SiT_S_2(**kw): return SiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kw) +def SiT_S_4(**kw): return SiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kw) +def SiT_S_8(**kw): return SiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kw) + +SiT_models: Dict[str, type] = { + 'SiT-XL/2': SiT_XL_2, 'SiT-XL/4': SiT_XL_4, 'SiT-XL/8': SiT_XL_8, + 'SiT-L/2': SiT_L_2, 'SiT-L/4': SiT_L_4, 'SiT-L/8': SiT_L_8, 
+ 'SiT-B/2': SiT_B_2, 'SiT-B/4': SiT_B_4, 'SiT-B/8': SiT_B_8, + 'SiT-S/2': SiT_S_2, 'SiT-S/4': SiT_S_4, 'SiT-S/8': SiT_S_8, + # DiT aliases (same architecture, different training objective) + 'DiT-XL/2': SiT_XL_2, 'DiT-XL/4': SiT_XL_4, 'DiT-XL/8': SiT_XL_8, + 'DiT-L/2': SiT_L_2, 'DiT-L/4': SiT_L_4, 'DiT-L/8': SiT_L_8, + 'DiT-B/2': SiT_B_2, 'DiT-B/4': SiT_B_4, 'DiT-B/8': SiT_B_8, + 'DiT-S/2': SiT_S_2, 'DiT-S/4': SiT_S_4, 'DiT-S/8': SiT_S_8, +} diff --git a/src/flow_factory/models/sit_dit/sit_dit.py b/src/flow_factory/models/sit_dit/sit_dit.py new file mode 100644 index 00000000..a2adaf66 --- /dev/null +++ b/src/flow_factory/models/sit_dit/sit_dit.py @@ -0,0 +1,591 @@ +# Copyright 2026 Jayce-Ping +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# src/flow_factory/models/sit_dit/sit_dit.py +""" +SiT / DiT adapter for Flow-Factory. + +Architecture: class-conditional image generation on 256x256 (ImageNet-style). +The transformer is the SiT/DiT backbone; the VAE is an SD-VAE compatible encoder. 
+ +Timestep conventions +-------------------- +- Flow-Factory scheduler: t ∈ [0, 1000], t=1000 ↔ pure noise +- SiT model: t ∈ [0, 1], t=0 ↔ noise, t=1 ↔ data +- Conversion: t_sit = 1 - t_ff / 1000 + +Sign convention for scheduler +------------------------------ +Flow-Factory: next = x + noise_pred * dt, dt = sigma_next - sigma_curr < 0 +SiT velocity points toward data → noise_pred = -v_sit + +CFG for class conditioning +-------------------------- + v_pred = v_uncond + cfg_scale * (v_cond - v_uncond) + null label index = num_classes (extra embedding row) +""" +from __future__ import annotations + +import os +import json +from typing import Union, List, Dict, Any, Optional, Tuple, Literal, ClassVar + +import numpy as np +from dataclasses import dataclass +from PIL import Image + +import torch +import torch.nn as nn +from diffusers import AutoencoderKL +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from accelerate import Accelerator + +from ...samples import T2ISample +from ..abc import BaseAdapter +from ...hparams import * +from ...scheduler import ( + FlowMatchEulerDiscreteSDEScheduler, + FlowMatchEulerDiscreteSDESchedulerOutput, + SDESchedulerOutput, + set_scheduler_timesteps, +) +from ...utils.base import filter_kwargs +from ...utils.trajectory_collector import ( + TrajectoryCollector, + CallbackCollector, + TrajectoryIndicesType, + create_trajectory_collector, + create_callback_collector, +) +from ...utils.logger_utils import setup_logger +from .models import SiT_models + +logger = setup_logger(__name__) + +# SD-VAE scale factor (same as Stable Diffusion) +_SCALE_FACTOR = 0.18215 + + +# --------------------------------------------------------------------------- +# Pipeline container +# --------------------------------------------------------------------------- + +class SiTDiTPipeline: + """ + Minimal pipeline container for SiT / DiT models. + + Attributes + ---------- + transformer : SiT + The SiT / DiT transformer. 
+ vae : AutoencoderKL + Diffusers VAE for encoding / decoding latents. + scheduler : FlowMatchEulerDiscreteScheduler + Initial scheduler (replaced by the SDE variant during BaseAdapter init). + new_mean : Optional[float] + ReaLS latent normalization mean. None for standard SD-VAE latents. + new_std : Optional[float] + ReaLS latent normalization std. None for standard SD-VAE latents. + """ + + def __init__( + self, + transformer: nn.Module, + vae: AutoencoderKL, + scheduler: FlowMatchEulerDiscreteScheduler, + new_mean: Optional[float] = None, + new_std: Optional[float] = None, + ): + self.transformer = transformer + self.vae = vae + self.scheduler = scheduler + self.new_mean = new_mean + self.new_std = new_std + + def maybe_free_model_hooks(self): + """No-op — kept for API compatibility with BaseAdapter.""" + pass + + +# --------------------------------------------------------------------------- +# Sample dataclass +# --------------------------------------------------------------------------- + +@dataclass +class SiTDiTSample(T2ISample): + """Output sample for SiT / DiT generation.""" + # Instance variables + class_labels: Optional[torch.Tensor] = None # () scalar integer class label (per sample) + + +# --------------------------------------------------------------------------- +# Adapter +# --------------------------------------------------------------------------- + +class SiTDiTAdapter(BaseAdapter): + """ + Flow-Factory adapter for SiT (Scalable Interpolant Transformer) and + DiT (Diffusion Transformer) class-conditional image generation. 
+ + Config JSON format (``model_name_or_path/config.json``) + -------------------------------------------------------- + { + "model_name": "SiT-XL/2", // key in SiT_models dict + "num_classes": 1000, // number of classes (ImageNet: 1000) + "learn_sigma": false, + "input_size": 32, // spatial size of latent (image/8) + "in_channels": 4, // VAE latent channels + "vae_path": "stabilityai/sd-vae-ft-ema", // HF id or local path + "reals_vae": false, // true → use mean_std.json normalization + "mean_std_path": null // path to mean_std.json (if reals_vae=true) + } + + Weights are loaded from ``model_name_or_path/pytorch_model.bin`` or + ``model_name_or_path/model.safetensors``. + """ + + def __init__(self, config: Arguments, accelerator: Accelerator): + super().__init__(config, accelerator) + self.pipeline: SiTDiTPipeline + self.scheduler: FlowMatchEulerDiscreteSDEScheduler + + # ========================== Pipeline Loading ========================== + + def load_pipeline(self) -> SiTDiTPipeline: + path = self.model_args.model_name_or_path + + # -- Load model config ------------------------------------------------ + config_path = os.path.join(path, "config.json") + with open(config_path, "r") as f: + cfg = json.load(f) + + model_name = cfg.get("model_name", "SiT-XL/2") + num_classes = cfg.get("num_classes", 1000) + learn_sigma = cfg.get("learn_sigma", False) + input_size = cfg.get("input_size", 32) + in_channels = cfg.get("in_channels", 4) + vae_path = cfg.get("vae_path", "stabilityai/sd-vae-ft-ema") + reals_vae = cfg.get("reals_vae", False) + mean_std_path = cfg.get("mean_std_path", None) + + # Cache for decode_latents + self._num_classes = num_classes + self._reals_vae = reals_vae + + # -- Load SiT / DiT transformer --------------------------------------- + if model_name not in SiT_models: + raise ValueError( + f"Unknown model_name '{model_name}'. 
" + f"Available: {list(SiT_models.keys())}" + ) + transformer = SiT_models[model_name]( + input_size=input_size, + in_channels=in_channels, + num_classes=num_classes, + learn_sigma=learn_sigma, + ) + + # Try safetensors first, then pytorch_model.bin + weights_path = None + for fname in ("model.safetensors", "pytorch_model.bin"): + candidate = os.path.join(path, fname) + if os.path.isfile(candidate): + weights_path = candidate + break + + if weights_path is not None: + if weights_path.endswith(".safetensors"): + from safetensors.torch import load_file as load_safetensors + state_dict = load_safetensors(weights_path) + else: + state_dict = torch.load(weights_path, map_location="cpu") + + # Support bare state dict or nested {"ema": ..., "model": ...} + if isinstance(state_dict, dict): + if "ema" in state_dict: + state_dict = state_dict["ema"] + elif "model" in state_dict: + state_dict = state_dict["model"] + + missing, unexpected = transformer.load_state_dict(state_dict, strict=True) + if missing: + logger.warning(f"Missing keys when loading transformer: {missing[:5]}...") + if unexpected: + logger.warning(f"Unexpected keys when loading transformer: {unexpected[:5]}...") + logger.info(f"Loaded transformer weights from {weights_path}") + else: + logger.warning( + f"No transformer weights found in {path}. " + "Starting from random initialization." 
+ ) + + # -- Load VAE --------------------------------------------------------- + vae = AutoencoderKL.from_pretrained(vae_path) + + # -- Load ReaLS latent normalization stats ---------------------------- + new_mean = None + new_std = None + if reals_vae: + if mean_std_path is None: + mean_std_path = os.path.join(path, "mean_std.json") + if os.path.isfile(mean_std_path): + with open(mean_std_path, "r") as f: + stats = json.load(f) + new_mean = float(stats["mean"]) + new_std = float(stats["std"]) + logger.info( + f"Loaded ReaLS latent stats: mean={new_mean:.4f}, std={new_std:.4f}" + ) + else: + raise FileNotFoundError( + f"reals_vae=True but mean_std.json not found at {mean_std_path}. " + "Set 'mean_std_path' in config.json or place mean_std.json in " + "model_name_or_path/." + ) + + # -- Build scheduler (placeholder; replaced by load_scheduler) -------- + scheduler = FlowMatchEulerDiscreteScheduler( + num_train_timesteps=1000, + shift=1.0, + ) + + return SiTDiTPipeline( + transformer=transformer, + vae=vae, + scheduler=scheduler, + new_mean=new_mean, + new_std=new_std, + ) + + # ========================== Default Modules ========================== + + @property + def default_target_modules(self) -> List[str]: + """Default trainable modules for SiT / DiT (attention projections).""" + return [ + "attn.to_q", "attn.to_k", "attn.to_v", "attn.to_out.0", + "attn.qkv", + ] + + # ======================== Encoding & Decoding ======================== + + def encode_prompt( + self, + prompt: Union[str, List[str]], + **kwargs, + ) -> Dict[str, Any]: + """ + Parse class label(s) from string prompt(s). + + Each prompt is interpreted as an integer class index (e.g. "0" or "985"). + Returns dummy ``prompt_embeds`` / ``prompt_ids`` tensors so that the + BaseAdapter machinery that expects these fields does not crash. 
+ """ + if isinstance(prompt, str): + prompt = [prompt] + + device = self.device + labels = [] + for p in prompt: + try: + labels.append(int(p.strip())) + except ValueError: + logger.warning( + f"Could not parse class label from prompt '{p}'. Using 0." + ) + labels.append(0) + + class_labels = torch.tensor(labels, dtype=torch.long, device=device) + batch_size = len(labels) + + # Dummy tensors — kept for API compatibility; not used by forward() + dummy_embeds = torch.zeros(batch_size, 1, dtype=torch.float32, device=device) + dummy_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device) + + return { + "class_labels": class_labels, + "prompt_embeds": dummy_embeds, + "prompt_ids": dummy_ids, + } + + def encode_image(self, images) -> None: + """Not used for class-conditional generation.""" + pass + + def encode_video(self, videos) -> None: + """Not used for SiT / DiT.""" + pass + + def decode_latents( + self, + latents: torch.Tensor, + height: int, + width: int, + output_type: Literal["pil", "pt", "np"] = "pil", + ) -> Union[List[Image.Image], torch.Tensor, np.ndarray]: + """Decode latents to pixel images using the VAE.""" + vae = self.pipeline.vae + dtype = vae.dtype if hasattr(vae, 'dtype') else latents.dtype + latents = latents.to(dtype=dtype) + + if self._reals_vae and self.pipeline.new_mean is not None: + # ReaLS latent space: denormalize to raw VAE latents + new_mean = self.pipeline.new_mean + new_std = self.pipeline.new_std + z_vae = latents * new_std + new_mean + else: + # Standard SD-VAE: latents are scaled by SCALE_FACTOR during training + z_vae = latents / _SCALE_FACTOR + + images = vae.decode(z_vae, return_dict=False)[0] + + # Map from [-1, 1] to [0, 1] + images = (images / 2 + 0.5).clamp(0, 1) + + if output_type == "pt": + return images # (B, C, H, W) + + images_np = images.permute(0, 2, 3, 1).float().cpu().numpy() + if output_type == "np": + return images_np + + # PIL output + images_uint8 = (images_np * 255).round().astype(np.uint8) + return 
[Image.fromarray(img) for img in images_uint8] + + # ======================== Inference ======================== + + @torch.no_grad() + def inference( + self, + # Ordinary args + prompt: Optional[Union[str, List[str]]] = None, + height: int = 256, + width: int = 256, + num_inference_steps: int = 50, + guidance_scale: float = 4.0, + generator: Optional[torch.Generator] = None, + # Pre-encoded prompt + prompt_ids: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + class_labels: Optional[torch.Tensor] = None, + # Other args + compute_log_prob: bool = True, + extra_call_back_kwargs: List[str] = [], + trajectory_indices: TrajectoryIndicesType = "all", + ) -> List[SiTDiTSample]: + """Execute generation and return SiTDiTSample objects.""" + + device = self.device + transformer = self.transformer + dtype = self._inference_dtype + + # 1. Encode prompts if not provided + if class_labels is None: + encoded = self.encode_prompt(prompt) + class_labels = encoded["class_labels"] + prompt_ids = encoded["prompt_ids"] + prompt_embeds = encoded["prompt_embeds"] + else: + class_labels = class_labels.to(device) + + batch_size = class_labels.shape[0] + + # 2. Prepare latents + vae_scale_factor = 2 ** (len(self.pipeline.vae.config.block_out_channels) - 1) + latent_height = height // vae_scale_factor + latent_width = width // vae_scale_factor + in_channels = transformer.in_channels + + latents = torch.randn( + batch_size, in_channels, latent_height, latent_width, + dtype=dtype, device=device, generator=generator, + ) + + # 3. Set timesteps + timesteps = set_scheduler_timesteps( + scheduler=self.pipeline.scheduler, + num_inference_steps=num_inference_steps, + seq_len=latent_height * latent_width, # spatial token count (H*W) + device=device, + ) + + # 4. 
Denoising loop + latent_collector = create_trajectory_collector(trajectory_indices, num_inference_steps) + latents = self.cast_latents(latents, default_dtype=dtype) + latent_collector.collect(latents, step_idx=0) + if compute_log_prob: + log_prob_collector = create_trajectory_collector(trajectory_indices, num_inference_steps) + callback_collector = create_callback_collector(trajectory_indices, num_inference_steps) + + for i, t in enumerate(timesteps): + current_noise_level = self.scheduler.get_noise_level_for_timestep(t) + t_next = ( + timesteps[i + 1] + if i + 1 < len(timesteps) + else torch.tensor(0, device=device) + ) + return_kwargs = list( + set(["next_latents", "log_prob", "noise_pred"] + extra_call_back_kwargs) + ) + current_compute_log_prob = compute_log_prob and current_noise_level > 0 + + output = self.forward( + t=t, + t_next=t_next, + latents=latents, + class_labels=class_labels, + guidance_scale=guidance_scale, + compute_log_prob=current_compute_log_prob, + return_kwargs=return_kwargs, + noise_level=current_noise_level, + ) + + latents = self.cast_latents(output.next_latents, default_dtype=dtype) + latent_collector.collect(latents, i + 1) + if current_compute_log_prob: + log_prob_collector.collect(output.log_prob, i) + + callback_collector.collect_step( + step_idx=i, + output=output, + keys=extra_call_back_kwargs, + capturable={"noise_level": current_noise_level}, + ) + + # 5. Decode images + images = self.decode_latents(latents, height, width, output_type="pt") + + # 6. 
Build sample list + extra_call_back_res = callback_collector.get_result() + callback_index_map = callback_collector.get_index_map() + all_latents = latent_collector.get_result() + latent_index_map = latent_collector.get_index_map() + all_log_probs = log_prob_collector.get_result() if compute_log_prob else None + log_prob_index_map = log_prob_collector.get_index_map() if compute_log_prob else None + + samples = [ + SiTDiTSample( + # Denoising trajectory + timesteps=timesteps, + all_latents=( + torch.stack([lat[b] for lat in all_latents], dim=0) + if all_latents else None + ), + log_probs=( + torch.stack([lp[b] for lp in all_log_probs], dim=0) + if all_log_probs else None + ), + latent_index_map=latent_index_map, + log_prob_index_map=log_prob_index_map, + # Prompt / class label + prompt=prompt[b] if isinstance(prompt, list) else prompt, + prompt_ids=prompt_ids[b] if prompt_ids is not None else None, + prompt_embeds=prompt_embeds[b] if prompt_embeds is not None else None, + class_labels=class_labels[b], # scalar per-sample label + # Image & metadata + height=height, + width=width, + image=images[b], + # Extra kwargs + extra_kwargs={ + **{k: v[b] for k, v in extra_call_back_res.items()}, + "callback_index_map": callback_index_map, + }, + ) + for b in range(batch_size) + ] + + self.pipeline.maybe_free_model_hooks() + return samples + + # ======================== Forward (Training) ======================== + + def forward( + self, + t: torch.Tensor, + latents: torch.Tensor, + # Class conditioning (used instead of prompt_embeds) + class_labels: torch.Tensor, + # Next timestep + t_next: Optional[torch.Tensor] = None, + next_latents: Optional[torch.Tensor] = None, + # CFG + guidance_scale: Union[float, List[float]] = 4.0, + noise_level: Optional[float] = None, + # Other + compute_log_prob: bool = True, + return_kwargs: List[str] = [ + "noise_pred", "next_latents", "next_latents_mean", "std_dev_t", "dt", "log_prob" + ], + ) -> FlowMatchEulerDiscreteSDESchedulerOutput: + """ 
+ One denoising step. + + Converts the Flow-Factory timestep to SiT convention, runs the + transformer with optional CFG, negates the velocity to match the + Flow-Factory scheduler sign convention, then calls scheduler.step(). + """ + device = latents.device + dtype = latents.dtype + batch_size = latents.shape[0] + transformer = self.transformer + + # ------------------------------------------------------------------ + # 1. Timestep conversion: t_ff ∈ [0,1000] → t_sit ∈ [0,1] + # ------------------------------------------------------------------ + sigma_t = t.float() / 1000.0 # sigma: 1 = pure noise + t_sit = 1.0 - sigma_t # SiT: 0 = noise, 1 = data + t_sit_batch = t_sit.expand(batch_size).to(dtype=dtype, device=device) + + # ------------------------------------------------------------------ + # 2. Transformer forward pass (with optional CFG) + # ------------------------------------------------------------------ + do_cfg = (guidance_scale > 1.0) and (class_labels is not None) + + if do_cfg: + # Concatenate conditional and unconditional batches + null_labels = torch.full_like(class_labels, transformer.y_embedder.num_classes) + labels_double = torch.cat([class_labels, null_labels], dim=0) + latents_double = torch.cat([latents, latents], dim=0) + t_sit_double = t_sit_batch.repeat(2) + + v_double = transformer(latents_double, t_sit_double, labels_double) + v_cond, v_uncond = v_double.chunk(2, dim=0) + v_pred = v_uncond + guidance_scale * (v_cond - v_uncond) + else: + v_pred = transformer(latents, t_sit_batch, class_labels) + + # ------------------------------------------------------------------ + # 3. Sign flip: Flow-Factory uses next = x + noise_pred * dt + # with dt < 0 (sigma decreasing). + # SiT velocity points toward data, so we negate. + # ------------------------------------------------------------------ + noise_pred = -v_pred + + # ------------------------------------------------------------------ + # 4. 
Scheduler step + # ------------------------------------------------------------------ + output = self.scheduler.step( + noise_pred=noise_pred, + timestep=t, + latents=latents, + timestep_next=t_next, + next_latents=next_latents, + compute_log_prob=compute_log_prob, + return_dict=True, + return_kwargs=return_kwargs, + noise_level=noise_level, + ) + return output
<tr><th>Task</th><th>Model</th><th>Model Size</th><th>Model Type</th></tr>
<tr><td rowspan="4">Class-to-Image</td><td>SiT-XL/2 (ImageNet-256)</td><td>675M</td><td>sit</td></tr>
<tr><td>SiT-L/2, SiT-B/2, SiT-S/2</td><td>458M/130M/33M</td><td>sit</td></tr>
<tr><td>DiT-XL/2 (ImageNet-256)</td><td>675M</td><td>dit</td></tr>
<tr><td>DiT-L/2, DiT-B/2, DiT-S/2</td><td>458M/130M/33M</td><td>dit</td></tr>
<tr><td rowspan="3">Text-to-Image</td><td>stable-diffusion-3.5-medium/large</td><td>2.5B/8.1B</td><td>sd3-5</td></tr>
<tr><td>FLUX.1-dev</td><td>13B</td><td>flux1</td></tr>
<tr><td>Z-Image-Turbo</td><td>6B</td><td>z-image</td></tr>