21 changes: 21 additions & 0 deletions README.md
@@ -37,6 +37,11 @@ This experimental feature leverages `diffusers`'s `transformer.set_attention_bac

<table>
<tr><th>Task</th><th>Model</th><th>Model Size</th><th>Model Type</th></tr>
<tr><td rowspan="4">Class-to-Image</td><td><a href="https://github.com/willisma/SiT">SiT-XL/2 (ImageNet-256)</a></td><td>675M</td><td>sit</td></tr>
<tr><td><a href="https://github.com/willisma/SiT">SiT-L/2, SiT-B/2, SiT-S/2</a></td><td>458M/130M/33M</td><td>sit</td></tr>
<tr><td><a href="https://github.com/facebookresearch/DiT">DiT-XL/2 (ImageNet-256)</a></td><td>675M</td><td>dit</td></tr>
<tr><td><a href="https://github.com/facebookresearch/DiT">DiT-L/2, DiT-B/2, DiT-S/2</a></td><td>458M/130M/33M</td><td>dit</td></tr>

<tr><td rowspan="6">Text-to-Image</td><td><a href="https://huggingface.co/collections/stabilityai/stable-diffusion-35">stable-diffusion-3.5-medium/large</a></td><td>2.5B/8.1B</td><td>sd3-5</td></tr>
<tr><td><a href="https://huggingface.co/black-forest-labs/FLUX.1-dev">FLUX.1-dev</a></td><td>13B</td><td>flux1</td></tr>
<tr><td><a href="https://huggingface.co/Tongyi-MAI/Z-Image-Turbo">Z-Image-Turbo</a></td><td>6B</td><td>z-image</td></tr>
@@ -158,6 +163,22 @@ The unified structure of the dataset is:

For text-to-image and text-to-video tasks, the only required input is the **prompt** in plain text format. Use `train.txt` and `test.txt` (optional) with the following format:

## Class-to-Image (SiT / DiT)

For class-conditional generation with SiT or DiT, each prompt is an **integer class index string** (e.g., ImageNet class label). Use `train.jsonl` with the following format:

```jsonl
{"prompt": "985"}
{"prompt": "207"}
{"prompt": "388"}
```

> Class indices follow the standard ImageNet-1K ordering (0–999). Example: 985 = daisy, 207 = golden retriever, 388 = giant panda.
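
A `train.jsonl` in this format can be generated with a few lines of Python. This is an illustrative sketch, not part of Flow-Factory; the helper name `write_class_prompts` is hypothetical:

```python
import json

# Write class indices as the {"prompt": "<index>"} records that train.jsonl
# expects. Indices follow the standard ImageNet-1K ordering (0-999).
def write_class_prompts(path, class_indices):
    with open(path, "w") as f:
        for idx in class_indices:
            f.write(json.dumps({"prompt": str(idx)}) + "\n")

write_class_prompts("train.jsonl", [985, 207, 388])
```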

The `model_name_or_path` directory must contain a `config.json` describing the model variant and VAE path. See [`examples/grpo/lora/sit_xl2.yaml`](examples/grpo/lora/sit_xl2.yaml) for a complete configuration example.

## Text-to-Image & Text-to-Video (original)

```
A hill in a sunset.
An astronaut riding a horse on Mars.
107 changes: 107 additions & 0 deletions examples/grpo/lora/sit_xl2.yaml
@@ -0,0 +1,107 @@
# Environment Configuration
launcher: "accelerate"
config_file: config/deepspeed/deepspeed_zero2.yaml
num_processes: 8
main_process_port: 29500
mixed_precision: "bf16"

run_name: null
project: "Flow-Factory"
logging_backend: "wandb"

# Data Configuration
# Dataset items should have "prompt" field containing the integer class label string,
# e.g. {"prompt": "985"} for ImageNet class 985 (daisy).
data:
dataset_dir: "dataset/imagenet_cls"
preprocessing_batch_size: 8
dataloader_num_workers: 16
force_reprocess: true
cache_dir: "~/.cache/flow_factory/datasets"
max_dataset_size: 1000

# Model Configuration
model:
finetune_type: 'lora'
lora_rank: 64
lora_alpha: 128
target_modules: "default"
# model_name_or_path must be a local directory containing:
# config.json — model config (see SiTDiTAdapter docstring)
# model.safetensors or pytorch_model.bin — transformer weights
model_name_or_path: "/path/to/sit_xl2_imagenet"
model_type: "sit" # or "dit" (both use the same SiTDiTAdapter)
target_components: ["transformer"]
resume_path: null
resume_type: null

log:
save_dir: "~/Flow-Factory"
save_freq: 20
save_model_only: true

# Training Configuration
train:
trainer_type: 'grpo'
advantage_aggregation: 'gdpo'
clip_range: 1.0e-4
adv_clip_range: 5.0
kl_type: 'v-based'
kl_beta: 0
ref_param_device: 'cuda'

resolution: 256 # SiT-XL/2 default resolution (256x256 images)
num_inference_steps: 20
guidance_scale: 4.0 # class-conditional CFG scale (typical: 1.5–5.0)

per_device_batch_size: 2
group_size: 16
global_std: false
unique_sample_num_per_epoch: 48
gradient_step_per_epoch: 2

learning_rate: 3.0e-4
adam_weight_decay: 1.0e-4
adam_betas: [0.9, 0.999]
adam_epsilon: 1.0e-8
max_grad_norm: 1.0

ema_decay: 0.9
ema_update_interval: 4
ema_device: "cuda"

enable_gradient_checkpointing: false
seed: 42

# Scheduler Configuration
scheduler:
dynamics_type: "Flow-SDE"
noise_level: 0.7
num_sde_steps: 1
sde_steps: [1, 2, 3]
seed: 42

# Evaluation settings
eval:
resolution: 256
per_device_batch_size: 1
guidance_scale: 4.0
num_inference_steps: 50
eval_freq: 20
seed: 42

# Reward Model Configuration
rewards:
- name: "pickscore_rank"
reward_model: "PickScore_Rank"
weight: 1.0
batch_size: 16
device: "cuda"
dtype: bfloat16

eval_rewards:
name: "pickscore"
reward_model: "PickScore"
batch_size: 16
device: "cuda"
dtype: bfloat16
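
Before launching an 8-process run, it can be worth sanity-checking the config fields. A minimal sketch, assuming PyYAML is available (the checks below are illustrative, not enforced by Flow-Factory):

```python
import yaml

# Parse an excerpt of the config above and verify a few field types/ranges
# before committing GPU time to a run.
excerpt = """
model:
  finetune_type: lora
  lora_rank: 64
  lora_alpha: 128
  model_type: sit
train:
  resolution: 256
  guidance_scale: 4.0
  learning_rate: 3.0e-4
"""
cfg = yaml.safe_load(excerpt)
assert cfg["model"]["model_type"] in ("sit", "dit")
assert cfg["train"]["resolution"] % 8 == 0  # VAE latents typically downsample by 8
assert isinstance(cfg["train"]["learning_rate"], float)
```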
2 changes: 2 additions & 0 deletions src/flow_factory/models/registry.py
@@ -38,6 +38,8 @@
'wan2_i2v': 'flow_factory.models.wan.wan2_i2v.Wan2_I2V_Adapter',
'wan2_t2v': 'flow_factory.models.wan.wan2_t2v.Wan2_T2V_Adapter',
'wan2_v2v': 'flow_factory.models.wan.wan2_v2v.Wan2_V2V_Adapter',
'sit': 'flow_factory.models.sit_dit.sit_dit.SiTDiTAdapter',
'dit': 'flow_factory.models.sit_dit.sit_dit.SiTDiTAdapter',
}

def get_model_adapter_class(identifier: str) -> Type:
3 changes: 3 additions & 0 deletions src/flow_factory/models/sit_dit/__init__.py
@@ -0,0 +1,3 @@
from .sit_dit import SiTDiTAdapter, SiTDiTSample

__all__ = ["SiTDiTAdapter", "SiTDiTSample"]