From 5bd1bceccf7f32afddec2463b4428905c5aa8afa Mon Sep 17 00:00:00 2001
From: black-yt <2667685528@qq.com>
Date: Tue, 17 Mar 2026 10:32:54 +0800
Subject: [PATCH] feat: add SiT/DiT class-conditional image generation
support
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- Add SiTDiTAdapter supporting SiT and DiT model variants
(model_type: "sit" or "dit", both use the same adapter)
- Class conditioning via integer label strings ("0"–"999" for ImageNet)
- Supports standard SD-VAE and ReaLS VAE latent normalization
- Timestep conversion: t_sit = 1 - t_ff/1000
- Sign convention: noise_pred = -v_sit for FlowMatchEulerDiscreteSDEScheduler
- CFG via null label index (= num_classes)
- Add example config: examples/grpo/lora/sit_xl2.yaml
- Update README: supported models table + dataset format docs
---
README.md | 21 +
examples/grpo/lora/sit_xl2.yaml | 107 ++++
src/flow_factory/models/registry.py | 2 +
src/flow_factory/models/sit_dit/__init__.py | 3 +
src/flow_factory/models/sit_dit/models.py | 299 ++++++++++
src/flow_factory/models/sit_dit/sit_dit.py | 591 ++++++++++++++++++++
6 files changed, 1023 insertions(+)
create mode 100644 examples/grpo/lora/sit_xl2.yaml
create mode 100644 src/flow_factory/models/sit_dit/__init__.py
create mode 100644 src/flow_factory/models/sit_dit/models.py
create mode 100644 src/flow_factory/models/sit_dit/sit_dit.py
diff --git a/README.md b/README.md
index cfc47cb3..f0668b1b 100644
--- a/README.md
+++ b/README.md
@@ -37,6 +37,11 @@ This experimental feature leverages `diffusers`'s `transformer.set_attention_bac
| Task | Model | Model Size | Model Type |
+ | Class-to-Image | SiT-XL/2 (ImageNet-256) | 675M | sit |
+ | | SiT-L/2, SiT-B/2, SiT-S/2 | 458M/130M/33M | sit |
+ | | DiT-XL/2 (ImageNet-256) | 675M | dit |
+ | | DiT-L/2, DiT-B/2, DiT-S/2 | 458M/130M/33M | dit |
| Text-to-Image | stable-diffusion-3.5-medium/large | 2.5B/8.1B | sd3-5 |
| FLUX.1-dev | 13B | flux1 |
| Z-Image-Turbo | 6B | z-image |
@@ -158,6 +163,22 @@ The unified structure of dataset is:
For text-to-image and text-to-video tasks, the only required input is the **prompt** in plain text format. Use `train.txt` and `test.txt` (optional) with following format:
+## Class-to-Image (SiT / DiT)
+
+For class-conditional generation with SiT or DiT, each prompt is an **integer class index string** (e.g., ImageNet class label). Use `train.jsonl` with the following format:
+
+```jsonl
+{"prompt": "985"}
+{"prompt": "207"}
+{"prompt": "388"}
+```
+
+> Class indices follow the standard ImageNet-1K ordering (0–999). Example: 985 = daisy, 207 = golden retriever, 388 = giant panda.
+
+The `model_name_or_path` directory must contain a `config.json` describing the model variant and VAE path. See [`examples/grpo/lora/sit_xl2.yaml`](examples/grpo/lora/sit_xl2.yaml) for a complete configuration example.
+
+## Text-to-Image & Text-to-Video (original)
+
```
A hill in a sunset.
An astronaut riding a horse on Mars.
diff --git a/examples/grpo/lora/sit_xl2.yaml b/examples/grpo/lora/sit_xl2.yaml
new file mode 100644
index 00000000..c6bfaed2
--- /dev/null
+++ b/examples/grpo/lora/sit_xl2.yaml
@@ -0,0 +1,107 @@
+# Environment Configuration
+launcher: "accelerate"
+config_file: config/deepspeed/deepspeed_zero2.yaml
+num_processes: 8
+main_process_port: 29500
+mixed_precision: "bf16"
+
+run_name: null
+project: "Flow-Factory"
+logging_backend: "wandb"
+
+# Data Configuration
+# Dataset items should have "prompt" field containing the integer class label string,
+# e.g. {"prompt": "985"} for ImageNet class 985 (daisy).
+data:
+ dataset_dir: "dataset/imagenet_cls"
+ preprocessing_batch_size: 8
+ dataloader_num_workers: 16
+ force_reprocess: true
+ cache_dir: "~/.cache/flow_factory/datasets"
+ max_dataset_size: 1000
+
+# Model Configuration
+model:
+ finetune_type: 'lora'
+ lora_rank: 64
+ lora_alpha: 128
+ target_modules: "default"
+ # model_name_or_path must be a local directory containing:
+ # config.json — model config (see SiTDiTAdapter docstring)
+ # model.safetensors or pytorch_model.bin — transformer weights
+ model_name_or_path: "/path/to/sit_xl2_imagenet"
+ model_type: "sit" # or "dit" (both use the same SiTDiTAdapter)
+ target_components: ["transformer"]
+ resume_path: null
+ resume_type: null
+
+log:
+ save_dir: "~/Flow-Factory"
+ save_freq: 20
+ save_model_only: true
+
+# Training Configuration
+train:
+ trainer_type: 'grpo'
+ advantage_aggregation: 'gdpo'
+ clip_range: 1.0e-4
+ adv_clip_range: 5.0
+ kl_type: 'v-based'
+ kl_beta: 0
+ ref_param_device: 'cuda'
+
+ resolution: 256 # SiT-XL/2 default resolution (256x256 images)
+ num_inference_steps: 20
+ guidance_scale: 4.0 # class-conditional CFG scale (typical: 1.5–5.0)
+
+ per_device_batch_size: 2
+ group_size: 16
+ global_std: false
+ unique_sample_num_per_epoch: 48
+ gradient_step_per_epoch: 2
+
+ learning_rate: 3.0e-4
+ adam_weight_decay: 1.0e-4
+ adam_betas: [0.9, 0.999]
+ adam_epsilon: 1.0e-8
+ max_grad_norm: 1.0
+
+ ema_decay: 0.9
+ ema_update_interval: 4
+ ema_device: "cuda"
+
+ enable_gradient_checkpointing: false
+ seed: 42
+
+# Scheduler Configuration
+scheduler:
+ dynamics_type: "Flow-SDE"
+ noise_level: 0.7
+ num_sde_steps: 1
+ sde_steps: [1, 2, 3]
+ seed: 42
+
+# Evaluation settings
+eval:
+ resolution: 256
+ per_device_batch_size: 1
+ guidance_scale: 4.0
+ num_inference_steps: 50
+ eval_freq: 20
+ seed: 42
+
+# Reward Model Configuration
+rewards:
+ - name: "pickscore_rank"
+ reward_model: "PickScore_Rank"
+ weight: 1.0
+ batch_size: 16
+ device: "cuda"
+ dtype: bfloat16
+
+eval_rewards:
+ name: "pickscore"
+ reward_model: "PickScore"
+ batch_size: 16
+ device: "cuda"
+ dtype: bfloat16
diff --git a/src/flow_factory/models/registry.py b/src/flow_factory/models/registry.py
index 339499a5..0f78ade9 100644
--- a/src/flow_factory/models/registry.py
+++ b/src/flow_factory/models/registry.py
@@ -38,6 +38,8 @@
'wan2_i2v': 'flow_factory.models.wan.wan2_i2v.Wan2_I2V_Adapter',
'wan2_t2v': 'flow_factory.models.wan.wan2_t2v.Wan2_T2V_Adapter',
'wan2_v2v': 'flow_factory.models.wan.wan2_v2v.Wan2_V2V_Adapter',
+ 'sit': 'flow_factory.models.sit_dit.sit_dit.SiTDiTAdapter',
+ 'dit': 'flow_factory.models.sit_dit.sit_dit.SiTDiTAdapter',
}
def get_model_adapter_class(identifier: str) -> Type:
diff --git a/src/flow_factory/models/sit_dit/__init__.py b/src/flow_factory/models/sit_dit/__init__.py
new file mode 100644
index 00000000..e86c0cdd
--- /dev/null
+++ b/src/flow_factory/models/sit_dit/__init__.py
@@ -0,0 +1,3 @@
+from .sit_dit import SiTDiTAdapter, SiTDiTSample
+
+__all__ = ["SiTDiTAdapter", "SiTDiTSample"]
diff --git a/src/flow_factory/models/sit_dit/models.py b/src/flow_factory/models/sit_dit/models.py
new file mode 100644
index 00000000..841e0fea
--- /dev/null
+++ b/src/flow_factory/models/sit_dit/models.py
@@ -0,0 +1,299 @@
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# GLIDE: https://github.com/openai/glide-text2im
+# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
+# SiT: https://github.com/willisma/SiT
+# --------------------------------------------------------
+
+import math
+from typing import Dict, List, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
+
+
+def modulate(x, shift, scale):
+    """Apply adaLN modulation: shift/scale (B, D) are broadcast over the token dim of x (B, N, D)."""
+    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+# ---------------------------------------------------------------------------
+# Embedding layers
+# ---------------------------------------------------------------------------
+
+class TimestepEmbedder(nn.Module):
+    """Embeds scalar timesteps into vector representations."""
+
+    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
+        super().__init__()
+        # Two-layer MLP projecting the fixed sinusoidal embedding to hidden_size.
+        self.mlp = nn.Sequential(
+            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+            nn.SiLU(),
+            nn.Linear(hidden_size, hidden_size, bias=True),
+        )
+        self.frequency_embedding_size = frequency_embedding_size
+
+    @staticmethod
+    def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
+        """Return (len(t), dim) sinusoidal embeddings of scalar timesteps t.
+
+        Frequencies decay geometrically from 1 down to 1/max_period; cos
+        components fill the first half, sin the second.
+        """
+        half = dim // 2
+        freqs = torch.exp(
+            -math.log(max_period)
+            * torch.arange(start=0, end=half, dtype=torch.float32)
+            / half
+        ).to(device=t.device)
+        args = t[:, None].float() * freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if dim % 2:
+            # Odd dim: pad with a single zero column so the width is exactly dim.
+            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+        return embedding
+
+    def forward(self, t: torch.Tensor) -> torch.Tensor:
+        """Map timesteps (B,) to embeddings (B, hidden_size)."""
+        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+        return self.mlp(t_freq)
+
+
+class LabelEmbedder(nn.Module):
+    """
+    Embeds class labels into vector representations.
+    Handles label dropout for classifier-free guidance.
+    The null label index is num_classes (the extra embedding row).
+    """
+
+    def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float):
+        super().__init__()
+        # Reserve one extra embedding row for the CFG null label, but only
+        # when dropout is enabled (the bool contributes 0/1 to the row count).
+        use_cfg_embedding = dropout_prob > 0
+        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+        self.num_classes = num_classes
+        self.dropout_prob = dropout_prob
+
+    def token_drop(self, labels: torch.Tensor, force_drop_ids=None) -> torch.Tensor:
+        """Replace labels with the null index, randomly (p=dropout_prob) or where force_drop_ids == 1."""
+        if force_drop_ids is None:
+            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+        else:
+            drop_ids = force_drop_ids == 1
+        labels = torch.where(drop_ids, self.num_classes, labels)
+        return labels
+
+    def forward(self, labels: torch.Tensor, train: bool, force_drop_ids=None) -> torch.Tensor:
+        """Embed labels (B,) -> (B, hidden_size); dropout is applied only in training mode or when forced."""
+        use_dropout = self.dropout_prob > 0
+        if (train and use_dropout) or (force_drop_ids is not None):
+            labels = self.token_drop(labels, force_drop_ids)
+        return self.embedding_table(labels)
+
+
+# ---------------------------------------------------------------------------
+# SiT / DiT blocks
+# ---------------------------------------------------------------------------
+
+class SiTBlock(nn.Module):
+    """SiT block with adaptive layer norm zero (adaLN-Zero) conditioning."""
+
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, **block_kwargs):
+        super().__init__()
+        # Norms carry no affine params — scale/shift come from adaLN instead.
+        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
+        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        approx_gelu = lambda: nn.GELU(approximate="tanh")
+        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
+        # Produces six modulation vectors: shift/scale/gate for the attention
+        # branch and for the MLP branch.
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
+        )
+
+    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+        """x: (B, N, D) tokens; c: (B, D) conditioning. Returns (B, N, D)."""
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
+            self.adaLN_modulation(c).chunk(6, dim=1)
+        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
+        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+        return x
+
+
+class FinalLayer(nn.Module):
+    """The final layer of SiT: adaLN-modulated norm followed by a linear projection to flattened patches."""
+
+    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+        super().__init__()
+        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        # Projects each token to one flattened patch: p * p * out_channels values.
+        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+        # Only shift/scale here — no gate, unlike SiTBlock.
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
+        )
+
+    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+        """x: (B, N, D) tokens; c: (B, D) conditioning. Returns (B, N, p*p*out_channels)."""
+        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+        x = modulate(self.norm_final(x), shift, scale)
+        return self.linear(x)
+
+
+# ---------------------------------------------------------------------------
+# SiT / DiT model
+# ---------------------------------------------------------------------------
+
+class SiT(nn.Module):
+    """
+    Scalable Interpolant Transformer (SiT) / Diffusion Transformer (DiT).
+
+    Both architectures share this backbone; the difference lies in the
+    training objective (flow-matching velocity for SiT, noise prediction
+    for DiT) which is handled outside the model.
+    """
+
+    def __init__(
+        self,
+        input_size: int = 32,
+        patch_size: int = 2,
+        in_channels: int = 4,
+        hidden_size: int = 1152,
+        depth: int = 28,
+        num_heads: int = 16,
+        mlp_ratio: float = 4.0,
+        class_dropout_prob: float = 0.1,
+        num_classes: int = 1000,
+        learn_sigma: bool = True,
+    ):
+        super().__init__()
+        self.learn_sigma = learn_sigma
+        self.in_channels = in_channels
+        # With learn_sigma the head predicts mean and sigma stacked on the
+        # channel dim; forward() discards the sigma half before returning.
+        self.out_channels = in_channels * 2 if learn_sigma else in_channels
+        self.patch_size = patch_size
+        self.num_heads = num_heads
+
+        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
+        self.t_embedder = TimestepEmbedder(hidden_size)
+        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
+        num_patches = self.x_embedder.num_patches
+        # Fixed (non-learned) 2D sin-cos positional embedding; filled in
+        # initialize_weights().
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
+
+        self.blocks = nn.ModuleList([
+            SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
+        ])
+        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
+        self.initialize_weights()
+
+    def initialize_weights(self):
+        """Initialize all weights: xavier linears, sin-cos pos-embed, zeroed adaLN and output head."""
+        def _basic_init(module):
+            if isinstance(module, nn.Linear):
+                torch.nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    nn.init.constant_(module.bias, 0)
+        self.apply(_basic_init)
+
+        # Overwrite the zero-initialized pos_embed with fixed sin-cos values
+        # for a square patch grid.
+        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
+                                            int(self.x_embedder.num_patches ** 0.5))
+        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+        # Initialize the patch-embed conv like a flattened linear layer.
+        w = self.x_embedder.proj.weight.data
+        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+        nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+        # adaLN-Zero: zero the modulation layers so each residual branch
+        # starts as identity.
+        for block in self.blocks:
+            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+        # Zero-init the output head as well (model initially outputs zeros).
+        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+        nn.init.constant_(self.final_layer.linear.weight, 0)
+        nn.init.constant_(self.final_layer.linear.bias, 0)
+
+    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
+        """Reassemble (B, N, p*p*C) patch tokens into (B, C, H, W); assumes a square patch grid (h == w)."""
+        c = self.out_channels
+        p = self.x_embedder.patch_size[0]
+        h = w = int(x.shape[1] ** 0.5)
+        assert h * w == x.shape[1]
+        x = x.reshape(x.shape[0], h, w, p, p, c)
+        x = torch.einsum('nhwpqc->nchpwq', x)
+        return x.reshape(x.shape[0], c, h * p, h * p)
+
+    def forward(self, x: torch.Tensor, t: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x: (B, C, H, W) noisy latents
+            t: (B,) timesteps in [0, 1]
+            y: (B,) integer class labels
+        Returns:
+            (B, C, H, W) predicted velocity (SiT) or noise (DiT)
+        """
+        x = self.x_embedder(x) + self.pos_embed  # (B, N, D) tokens
+        t = self.t_embedder(t)                   # (B, D)
+        y = self.y_embedder(y, self.training)    # (B, D)
+        c = t + y                                # combined conditioning
+        for block in self.blocks:
+            x = block(x, c)
+        x = self.final_layer(x, c)
+        x = self.unpatchify(x)
+        if self.learn_sigma:
+            # Drop the predicted sigma channels; keep only the mean prediction.
+            x, _ = x.chunk(2, dim=1)
+        return x
+
+
+# ---------------------------------------------------------------------------
+# Positional embedding helpers
+# ---------------------------------------------------------------------------
+
+def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int,
+                            cls_token: bool = False, extra_tokens: int = 0) -> np.ndarray:
+    """Return (grid_size**2, embed_dim) 2D sin-cos embeddings; optionally prepend zero rows for extra tokens."""
+    grid_h = np.arange(grid_size, dtype=np.float32)
+    grid_w = np.arange(grid_size, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # w goes first here
+    grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token and extra_tokens > 0:
+        # Zero rows for [CLS]-style tokens that carry no spatial position.
+        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray) -> np.ndarray:
+    """Split embed_dim in half: the first half encodes grid[0] (h), the second half grid[1] (w)."""
+    assert embed_dim % 2 == 0
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
+    return np.concatenate([emb_h, emb_w], axis=1)
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray) -> np.ndarray:
+    """Return (len(pos), embed_dim) embeddings: sin features in the first half, cos in the second."""
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=np.float64)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000 ** omega  # geometric frequency spectrum
+    pos = pos.reshape(-1)
+    out = np.einsum('m,d->md', pos, omega)  # outer product: position x frequency
+    emb = np.concatenate([np.sin(out), np.cos(out)], axis=1)
+    return emb
+
+
+# ---------------------------------------------------------------------------
+# Model configs
+# ---------------------------------------------------------------------------
+
+# One factory per variant; the name suffix is the patch size (e.g. XL/2 -> patch_size=2).
+def SiT_XL_2(**kw): return SiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kw)
+def SiT_XL_4(**kw): return SiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kw)
+def SiT_XL_8(**kw): return SiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kw)
+def SiT_L_2(**kw): return SiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kw)
+def SiT_L_4(**kw): return SiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kw)
+def SiT_L_8(**kw): return SiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kw)
+def SiT_B_2(**kw): return SiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kw)
+def SiT_B_4(**kw): return SiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kw)
+def SiT_B_8(**kw): return SiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kw)
+def SiT_S_2(**kw): return SiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kw)
+def SiT_S_4(**kw): return SiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kw)
+def SiT_S_8(**kw): return SiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kw)
+
+# Registry mapping config.json "model_name" strings to factory callables.
+SiT_models: Dict[str, type] = {
+    'SiT-XL/2': SiT_XL_2, 'SiT-XL/4': SiT_XL_4, 'SiT-XL/8': SiT_XL_8,
+    'SiT-L/2': SiT_L_2, 'SiT-L/4': SiT_L_4, 'SiT-L/8': SiT_L_8,
+    'SiT-B/2': SiT_B_2, 'SiT-B/4': SiT_B_4, 'SiT-B/8': SiT_B_8,
+    'SiT-S/2': SiT_S_2, 'SiT-S/4': SiT_S_4, 'SiT-S/8': SiT_S_8,
+    # DiT aliases (same architecture, different training objective)
+    'DiT-XL/2': SiT_XL_2, 'DiT-XL/4': SiT_XL_4, 'DiT-XL/8': SiT_XL_8,
+    'DiT-L/2': SiT_L_2, 'DiT-L/4': SiT_L_4, 'DiT-L/8': SiT_L_8,
+    'DiT-B/2': SiT_B_2, 'DiT-B/4': SiT_B_4, 'DiT-B/8': SiT_B_8,
+    'DiT-S/2': SiT_S_2, 'DiT-S/4': SiT_S_4, 'DiT-S/8': SiT_S_8,
+}
diff --git a/src/flow_factory/models/sit_dit/sit_dit.py b/src/flow_factory/models/sit_dit/sit_dit.py
new file mode 100644
index 00000000..a2adaf66
--- /dev/null
+++ b/src/flow_factory/models/sit_dit/sit_dit.py
@@ -0,0 +1,591 @@
+# Copyright 2026 Jayce-Ping
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# src/flow_factory/models/sit_dit/sit_dit.py
+"""
+SiT / DiT adapter for Flow-Factory.
+
+Architecture: class-conditional image generation on 256x256 (ImageNet-style).
+The transformer is the SiT/DiT backbone; the VAE is an SD-VAE compatible encoder.
+
+Timestep conventions
+--------------------
+- Flow-Factory scheduler: t ∈ [0, 1000], t=1000 ↔ pure noise
+- SiT model: t ∈ [0, 1], t=0 ↔ noise, t=1 ↔ data
+- Conversion: t_sit = 1 - t_ff / 1000
+
+Sign convention for scheduler
+------------------------------
+Flow-Factory: next = x + noise_pred * dt, dt = sigma_next - sigma_curr < 0
+SiT velocity points toward data → noise_pred = -v_sit
+
+CFG for class conditioning
+--------------------------
+ v_pred = v_uncond + cfg_scale * (v_cond - v_uncond)
+ null label index = num_classes (extra embedding row)
+"""
+from __future__ import annotations
+
+import os
+import json
+from typing import Union, List, Dict, Any, Optional, Tuple, Literal, ClassVar
+
+import numpy as np
+from dataclasses import dataclass
+from PIL import Image
+
+import torch
+import torch.nn as nn
+from diffusers import AutoencoderKL
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from accelerate import Accelerator
+
+from ...samples import T2ISample
+from ..abc import BaseAdapter
+from ...hparams import *
+from ...scheduler import (
+ FlowMatchEulerDiscreteSDEScheduler,
+ FlowMatchEulerDiscreteSDESchedulerOutput,
+ SDESchedulerOutput,
+ set_scheduler_timesteps,
+)
+from ...utils.base import filter_kwargs
+from ...utils.trajectory_collector import (
+ TrajectoryCollector,
+ CallbackCollector,
+ TrajectoryIndicesType,
+ create_trajectory_collector,
+ create_callback_collector,
+)
+from ...utils.logger_utils import setup_logger
+from .models import SiT_models
+
+logger = setup_logger(__name__)
+
+# SD-VAE scale factor (same as Stable Diffusion)
+_SCALE_FACTOR = 0.18215
+
+
+# ---------------------------------------------------------------------------
+# Pipeline container
+# ---------------------------------------------------------------------------
+
+class SiTDiTPipeline:
+    """
+    Minimal pipeline container for SiT / DiT models.
+
+    Attributes
+    ----------
+    transformer : SiT
+        The SiT / DiT transformer.
+    vae : AutoencoderKL
+        Diffusers VAE for encoding / decoding latents.
+    scheduler : FlowMatchEulerDiscreteScheduler
+        Initial scheduler (replaced by the SDE variant during BaseAdapter init).
+    new_mean : Optional[float]
+        ReaLS latent normalization mean. None for standard SD-VAE latents.
+    new_std : Optional[float]
+        ReaLS latent normalization std. None for standard SD-VAE latents.
+    """
+
+    def __init__(
+        self,
+        transformer: nn.Module,
+        vae: AutoencoderKL,
+        scheduler: FlowMatchEulerDiscreteScheduler,
+        new_mean: Optional[float] = None,
+        new_std: Optional[float] = None,
+    ):
+        # Plain attribute container — no diffusers DiffusionPipeline machinery.
+        self.transformer = transformer
+        self.vae = vae
+        self.scheduler = scheduler
+        # new_mean/new_std are set together by load_pipeline when reals_vae=True.
+        self.new_mean = new_mean
+        self.new_std = new_std
+
+    def maybe_free_model_hooks(self):
+        """No-op — kept for API compatibility with BaseAdapter."""
+        pass
+
+
+# ---------------------------------------------------------------------------
+# Sample dataclass
+# ---------------------------------------------------------------------------
+
+@dataclass
+class SiTDiTSample(T2ISample):
+    """Output sample for SiT / DiT generation; extends T2ISample with the conditioning class label."""
+    # Instance variables
+    class_labels: Optional[torch.Tensor] = None  # per-sample scalar () integer class label
+
+
+# ---------------------------------------------------------------------------
+# Adapter
+# ---------------------------------------------------------------------------
+
+class SiTDiTAdapter(BaseAdapter):
+ """
+ Flow-Factory adapter for SiT (Scalable Interpolant Transformer) and
+ DiT (Diffusion Transformer) class-conditional image generation.
+
+ Config JSON format (``model_name_or_path/config.json``)
+ --------------------------------------------------------
+ {
+ "model_name": "SiT-XL/2", // key in SiT_models dict
+ "num_classes": 1000, // number of classes (ImageNet: 1000)
+ "learn_sigma": false,
+ "input_size": 32, // spatial size of latent (image/8)
+ "in_channels": 4, // VAE latent channels
+ "vae_path": "stabilityai/sd-vae-ft-ema", // HF id or local path
+ "reals_vae": false, // true → use mean_std.json normalization
+ "mean_std_path": null // path to mean_std.json (if reals_vae=true)
+ }
+
+ Weights are loaded from ``model_name_or_path/pytorch_model.bin`` or
+ ``model_name_or_path/model.safetensors``.
+ """
+
+    def __init__(self, config: Arguments, accelerator: Accelerator):
+        """Delegate setup to BaseAdapter; the bare annotations below narrow inherited attribute types for type-checkers."""
+        super().__init__(config, accelerator)
+        self.pipeline: SiTDiTPipeline
+        self.scheduler: FlowMatchEulerDiscreteSDEScheduler
+
+ # ========================== Pipeline Loading ==========================
+
+ def load_pipeline(self) -> SiTDiTPipeline:
+ path = self.model_args.model_name_or_path
+
+ # -- Load model config ------------------------------------------------
+ config_path = os.path.join(path, "config.json")
+ with open(config_path, "r") as f:
+ cfg = json.load(f)
+
+ model_name = cfg.get("model_name", "SiT-XL/2")
+ num_classes = cfg.get("num_classes", 1000)
+ learn_sigma = cfg.get("learn_sigma", False)
+ input_size = cfg.get("input_size", 32)
+ in_channels = cfg.get("in_channels", 4)
+ vae_path = cfg.get("vae_path", "stabilityai/sd-vae-ft-ema")
+ reals_vae = cfg.get("reals_vae", False)
+ mean_std_path = cfg.get("mean_std_path", None)
+
+ # Cache for decode_latents
+ self._num_classes = num_classes
+ self._reals_vae = reals_vae
+
+ # -- Load SiT / DiT transformer ---------------------------------------
+ if model_name not in SiT_models:
+ raise ValueError(
+ f"Unknown model_name '{model_name}'. "
+ f"Available: {list(SiT_models.keys())}"
+ )
+ transformer = SiT_models[model_name](
+ input_size=input_size,
+ in_channels=in_channels,
+ num_classes=num_classes,
+ learn_sigma=learn_sigma,
+ )
+
+ # Try safetensors first, then pytorch_model.bin
+ weights_path = None
+ for fname in ("model.safetensors", "pytorch_model.bin"):
+ candidate = os.path.join(path, fname)
+ if os.path.isfile(candidate):
+ weights_path = candidate
+ break
+
+ if weights_path is not None:
+ if weights_path.endswith(".safetensors"):
+ from safetensors.torch import load_file as load_safetensors
+ state_dict = load_safetensors(weights_path)
+ else:
+ state_dict = torch.load(weights_path, map_location="cpu")
+
+ # Support bare state dict or nested {"ema": ..., "model": ...}
+ if isinstance(state_dict, dict):
+ if "ema" in state_dict:
+ state_dict = state_dict["ema"]
+ elif "model" in state_dict:
+ state_dict = state_dict["model"]
+
+ missing, unexpected = transformer.load_state_dict(state_dict, strict=True)
+ if missing:
+ logger.warning(f"Missing keys when loading transformer: {missing[:5]}...")
+ if unexpected:
+ logger.warning(f"Unexpected keys when loading transformer: {unexpected[:5]}...")
+ logger.info(f"Loaded transformer weights from {weights_path}")
+ else:
+ logger.warning(
+ f"No transformer weights found in {path}. "
+ "Starting from random initialization."
+ )
+
+ # -- Load VAE ---------------------------------------------------------
+ vae = AutoencoderKL.from_pretrained(vae_path)
+
+ # -- Load ReaLS latent normalization stats ----------------------------
+ new_mean = None
+ new_std = None
+ if reals_vae:
+ if mean_std_path is None:
+ mean_std_path = os.path.join(path, "mean_std.json")
+ if os.path.isfile(mean_std_path):
+ with open(mean_std_path, "r") as f:
+ stats = json.load(f)
+ new_mean = float(stats["mean"])
+ new_std = float(stats["std"])
+ logger.info(
+ f"Loaded ReaLS latent stats: mean={new_mean:.4f}, std={new_std:.4f}"
+ )
+ else:
+ raise FileNotFoundError(
+ f"reals_vae=True but mean_std.json not found at {mean_std_path}. "
+ "Set 'mean_std_path' in config.json or place mean_std.json in "
+ "model_name_or_path/."
+ )
+
+ # -- Build scheduler (placeholder; replaced by load_scheduler) --------
+ scheduler = FlowMatchEulerDiscreteScheduler(
+ num_train_timesteps=1000,
+ shift=1.0,
+ )
+
+ return SiTDiTPipeline(
+ transformer=transformer,
+ vae=vae,
+ scheduler=scheduler,
+ new_mean=new_mean,
+ new_std=new_std,
+ )
+
+ # ========================== Default Modules ==========================
+
+    @property
+    def default_target_modules(self) -> List[str]:
+        """Default trainable modules for SiT / DiT (attention projections).
+
+        NOTE(review): the backbone's attention is timm's ``Attention``, whose
+        submodules are named ``qkv``/``proj`` — the diffusers-style
+        ``to_q``/``to_k``/``to_v``/``to_out.0`` entries likely match nothing
+        here; confirm, and consider adding ``attn.proj``.
+        """
+        return [
+            "attn.to_q", "attn.to_k", "attn.to_v", "attn.to_out.0",
+            "attn.qkv",
+        ]
+
+ # ======================== Encoding & Decoding ========================
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ **kwargs,
+ ) -> Dict[str, Any]:
+ """
+ Parse class label(s) from string prompt(s).
+
+ Each prompt is interpreted as an integer class index (e.g. "0" or "985").
+ Returns dummy ``prompt_embeds`` / ``prompt_ids`` tensors so that the
+ BaseAdapter machinery that expects these fields does not crash.
+ """
+ if isinstance(prompt, str):
+ prompt = [prompt]
+
+ device = self.device
+ labels = []
+ for p in prompt:
+ try:
+ labels.append(int(p.strip()))
+ except ValueError:
+ logger.warning(
+ f"Could not parse class label from prompt '{p}'. Using 0."
+ )
+ labels.append(0)
+
+ class_labels = torch.tensor(labels, dtype=torch.long, device=device)
+ batch_size = len(labels)
+
+ # Dummy tensors — kept for API compatibility; not used by forward()
+ dummy_embeds = torch.zeros(batch_size, 1, dtype=torch.float32, device=device)
+ dummy_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
+
+ return {
+ "class_labels": class_labels,
+ "prompt_embeds": dummy_embeds,
+ "prompt_ids": dummy_ids,
+ }
+
+    def encode_image(self, images) -> None:
+        """Not used for class-conditional generation; intentional no-op (conditioning comes solely from class labels)."""
+        pass
+
+    def encode_video(self, videos) -> None:
+        """Not used for SiT / DiT (image generation only); intentional no-op."""
+        pass
+
+    def decode_latents(
+        self,
+        latents: torch.Tensor,
+        height: int,
+        width: int,
+        output_type: Literal["pil", "pt", "np"] = "pil",
+    ) -> Union[List[Image.Image], torch.Tensor, np.ndarray]:
+        """Decode latents to pixel images using the VAE.
+
+        Args:
+            latents: (B, C, h, w) latents in the model's (possibly
+                ReaLS-normalized) latent space.
+            height, width: currently unused — the output size is determined
+                by the VAE's upsampling of the latents.
+            output_type: "pil" (list of PIL images), "pt" ((B, C, H, W)
+                tensor in [0, 1]), or "np" ((B, H, W, C) float array in [0, 1]).
+        """
+        vae = self.pipeline.vae
+        # Match the VAE's compute dtype; fall back to the latents' own dtype.
+        dtype = vae.dtype if hasattr(vae, 'dtype') else latents.dtype
+        latents = latents.to(dtype=dtype)
+
+        if self._reals_vae and self.pipeline.new_mean is not None:
+            # ReaLS latent space: denormalize to raw VAE latents.
+            # new_std is set together with new_mean in load_pipeline.
+            new_mean = self.pipeline.new_mean
+            new_std = self.pipeline.new_std
+            z_vae = latents * new_std + new_mean
+        else:
+            # Standard SD-VAE: latents are scaled by SCALE_FACTOR during training
+            z_vae = latents / _SCALE_FACTOR
+
+        images = vae.decode(z_vae, return_dict=False)[0]
+
+        # Map from [-1, 1] to [0, 1]
+        images = (images / 2 + 0.5).clamp(0, 1)
+
+        if output_type == "pt":
+            return images # (B, C, H, W)
+
+        images_np = images.permute(0, 2, 3, 1).float().cpu().numpy()
+        if output_type == "np":
+            return images_np
+
+        # PIL output
+        images_uint8 = (images_np * 255).round().astype(np.uint8)
+        return [Image.fromarray(img) for img in images_uint8]
+
+ # ======================== Inference ========================
+
@torch.no_grad()
def inference(
    self,
    # Ordinary args
    prompt: Optional[Union[str, List[str]]] = None,
    height: int = 256,
    width: int = 256,
    num_inference_steps: int = 50,
    guidance_scale: float = 4.0,
    generator: Optional[torch.Generator] = None,
    # Pre-encoded prompt
    prompt_ids: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    class_labels: Optional[torch.Tensor] = None,
    # Other args
    compute_log_prob: bool = True,
    extra_call_back_kwargs: Optional[List[str]] = None,
    trajectory_indices: TrajectoryIndicesType = "all",
) -> List[SiTDiTSample]:
    """Execute class-conditional generation and return SiTDiTSample objects.

    Args:
        prompt: Class-label string(s); only consulted (via ``encode_prompt``)
            when ``class_labels`` is not supplied.
        height: Output image height in pixels.
        width: Output image width in pixels.
        num_inference_steps: Number of scheduler denoising steps.
        guidance_scale: CFG scale forwarded to :meth:`forward`.
        generator: Optional RNG used to draw the initial latents.
        prompt_ids: Optional pre-encoded prompt ids, passed through to samples.
        prompt_embeds: Optional pre-encoded prompt embeddings, passed through.
        class_labels: Pre-encoded integer class labels of shape (B,); when
            given, ``encode_prompt`` is skipped entirely.
        compute_log_prob: Collect per-step log-probs (skipped for steps whose
            noise level is zero, where the log-prob is undefined).
        extra_call_back_kwargs: Extra scheduler-output keys to record each
            step. ``None`` means no extra keys (avoids a mutable default).
        trajectory_indices: Which step indices to keep in the trajectories.

    Returns:
        A list with one ``SiTDiTSample`` per batch element.
    """
    # Normalize the optional list argument (never use a mutable default).
    if extra_call_back_kwargs is None:
        extra_call_back_kwargs = []

    device = self.device
    transformer = self.transformer
    dtype = self._inference_dtype

    # 1. Encode prompts if labels were not provided directly.
    if class_labels is None:
        encoded = self.encode_prompt(prompt)
        class_labels = encoded["class_labels"]
        prompt_ids = encoded["prompt_ids"]
        prompt_embeds = encoded["prompt_embeds"]
    else:
        class_labels = class_labels.to(device)

    batch_size = class_labels.shape[0]

    # 2. Prepare initial noise latents in the VAE latent space.
    vae_scale_factor = 2 ** (len(self.pipeline.vae.config.block_out_channels) - 1)
    latent_height = height // vae_scale_factor
    latent_width = width // vae_scale_factor
    in_channels = transformer.in_channels

    latents = torch.randn(
        batch_size, in_channels, latent_height, latent_width,
        dtype=dtype, device=device, generator=generator,
    )

    # 3. Set scheduler timesteps.
    timesteps = set_scheduler_timesteps(
        scheduler=self.pipeline.scheduler,
        num_inference_steps=num_inference_steps,
        seq_len=latent_height * latent_width,  # spatial token count (H*W)
        device=device,
    )

    # 4. Denoising loop.
    latent_collector = create_trajectory_collector(trajectory_indices, num_inference_steps)
    latents = self.cast_latents(latents, default_dtype=dtype)
    latent_collector.collect(latents, step_idx=0)
    if compute_log_prob:
        log_prob_collector = create_trajectory_collector(trajectory_indices, num_inference_steps)
    callback_collector = create_callback_collector(trajectory_indices, num_inference_steps)

    # The requested scheduler outputs are loop-invariant; build the list once.
    return_kwargs = list(
        set(["next_latents", "log_prob", "noise_pred"] + extra_call_back_kwargs)
    )

    for i, t in enumerate(timesteps):
        current_noise_level = self.scheduler.get_noise_level_for_timestep(t)
        t_next = (
            timesteps[i + 1]
            if i + 1 < len(timesteps)
            else torch.tensor(0, device=device)
        )
        # Log-probs are only meaningful when the step injects noise.
        current_compute_log_prob = compute_log_prob and current_noise_level > 0

        output = self.forward(
            t=t,
            t_next=t_next,
            latents=latents,
            class_labels=class_labels,
            guidance_scale=guidance_scale,
            compute_log_prob=current_compute_log_prob,
            return_kwargs=return_kwargs,
            noise_level=current_noise_level,
        )

        latents = self.cast_latents(output.next_latents, default_dtype=dtype)
        latent_collector.collect(latents, i + 1)
        if current_compute_log_prob:
            log_prob_collector.collect(output.log_prob, i)

        callback_collector.collect_step(
            step_idx=i,
            output=output,
            keys=extra_call_back_kwargs,
            capturable={"noise_level": current_noise_level},
        )

    # 5. Decode final latents to images in [0, 1], tensor form.
    images = self.decode_latents(latents, height, width, output_type="pt")

    # 6. Build one sample per batch element from the collected trajectories.
    extra_call_back_res = callback_collector.get_result()
    callback_index_map = callback_collector.get_index_map()
    all_latents = latent_collector.get_result()
    latent_index_map = latent_collector.get_index_map()
    all_log_probs = log_prob_collector.get_result() if compute_log_prob else None
    log_prob_index_map = log_prob_collector.get_index_map() if compute_log_prob else None

    samples = [
        SiTDiTSample(
            # Denoising trajectory
            timesteps=timesteps,
            all_latents=(
                torch.stack([lat[b] for lat in all_latents], dim=0)
                if all_latents else None
            ),
            log_probs=(
                torch.stack([lp[b] for lp in all_log_probs], dim=0)
                if all_log_probs else None
            ),
            latent_index_map=latent_index_map,
            log_prob_index_map=log_prob_index_map,
            # Prompt / class label
            prompt=prompt[b] if isinstance(prompt, list) else prompt,
            prompt_ids=prompt_ids[b] if prompt_ids is not None else None,
            prompt_embeds=prompt_embeds[b] if prompt_embeds is not None else None,
            class_labels=class_labels[b],  # scalar per-sample label
            # Image & metadata
            height=height,
            width=width,
            image=images[b],
            # Extra kwargs
            extra_kwargs={
                **{k: v[b] for k, v in extra_call_back_res.items()},
                "callback_index_map": callback_index_map,
            },
        )
        for b in range(batch_size)
    ]

    self.pipeline.maybe_free_model_hooks()
    return samples
+
+ # ======================== Forward (Training) ========================
+
def forward(
    self,
    t: torch.Tensor,
    latents: torch.Tensor,
    # Class conditioning (used instead of prompt_embeds)
    class_labels: torch.Tensor,
    # Next timestep
    t_next: Optional[torch.Tensor] = None,
    next_latents: Optional[torch.Tensor] = None,
    # CFG
    guidance_scale: Union[float, List[float]] = 4.0,
    noise_level: Optional[float] = None,
    # Other
    compute_log_prob: bool = True,
    return_kwargs: Optional[List[str]] = None,
) -> FlowMatchEulerDiscreteSDESchedulerOutput:
    """
    One denoising step.

    Converts the Flow-Factory timestep to SiT convention, runs the
    transformer with optional CFG, negates the velocity to match the
    Flow-Factory scheduler sign convention, then calls scheduler.step().

    Args:
        t: Current Flow-Factory timestep in [0, 1000] (scalar tensor).
        latents: Current latents of shape (B, C, H, W).
        class_labels: Integer class labels of shape (B,).
        t_next: Optional next timestep, forwarded to the scheduler.
        next_latents: Optional fixed next latents (e.g. log-prob evaluation).
        guidance_scale: CFG scale — a scalar, or a per-sample list of
            length B (the list form previously crashed on ``> 1.0``).
        noise_level: Optional noise level forwarded to the scheduler.
        compute_log_prob: Whether the scheduler computes the step log-prob.
        return_kwargs: Scheduler output keys to return; ``None`` selects the
            default set (avoids a shared mutable default list).

    Returns:
        FlowMatchEulerDiscreteSDESchedulerOutput from ``scheduler.step()``.
    """
    if return_kwargs is None:
        return_kwargs = [
            "noise_pred", "next_latents", "next_latents_mean",
            "std_dev_t", "dt", "log_prob",
        ]

    device = latents.device
    dtype = latents.dtype
    batch_size = latents.shape[0]
    transformer = self.transformer

    # ------------------------------------------------------------------
    # 1. Timestep conversion: t_ff ∈ [0,1000] → t_sit ∈ [0,1]
    # ------------------------------------------------------------------
    sigma_t = t.float() / 1000.0  # sigma: 1 = pure noise
    t_sit = 1.0 - sigma_t         # SiT: 0 = noise, 1 = data
    t_sit_batch = t_sit.expand(batch_size).to(dtype=dtype, device=device)

    # ------------------------------------------------------------------
    # 2. Transformer forward pass (with optional CFG).
    #    guidance_scale may be per-sample; normalize it to a form that
    #    broadcasts over (B, C, H, W).
    # ------------------------------------------------------------------
    if isinstance(guidance_scale, (list, tuple)):
        cfg_active = any(g > 1.0 for g in guidance_scale)
        guidance = torch.tensor(
            guidance_scale, dtype=dtype, device=device
        ).view(-1, 1, 1, 1)
    else:
        cfg_active = guidance_scale > 1.0
        guidance = guidance_scale
    do_cfg = cfg_active and (class_labels is not None)

    if do_cfg:
        # Concatenate conditional and unconditional batches; the null
        # (unconditional) label index equals num_classes by convention.
        null_labels = torch.full_like(class_labels, transformer.y_embedder.num_classes)
        labels_double = torch.cat([class_labels, null_labels], dim=0)
        latents_double = torch.cat([latents, latents], dim=0)
        t_sit_double = t_sit_batch.repeat(2)

        v_double = transformer(latents_double, t_sit_double, labels_double)
        v_cond, v_uncond = v_double.chunk(2, dim=0)
        v_pred = v_uncond + guidance * (v_cond - v_uncond)
    else:
        v_pred = transformer(latents, t_sit_batch, class_labels)

    # ------------------------------------------------------------------
    # 3. Sign flip: Flow-Factory uses next = x + noise_pred * dt
    #    with dt < 0 (sigma decreasing).
    #    SiT velocity points toward data, so we negate.
    # ------------------------------------------------------------------
    noise_pred = -v_pred

    # ------------------------------------------------------------------
    # 4. Scheduler step
    # ------------------------------------------------------------------
    return self.scheduler.step(
        noise_pred=noise_pred,
        timestep=t,
        latents=latents,
        timestep_next=t_next,
        next_latents=next_latents,
        compute_log_prob=compute_log_prob,
        return_dict=True,
        return_kwargs=return_kwargs,
        noise_level=noise_level,
    )