From 83bed5128d7eeab60a8bb471a899075b35dca6d8 Mon Sep 17 00:00:00 2001 From: bluvoll Date: Mon, 9 Mar 2026 23:35:03 -0500 Subject: [PATCH] Initial Rectified Flow support. --- .gitignore | 4 +- Example.toml | 57 ++++++++++ mikazuki/app/models.py | 4 +- mikazuki/schema/dreambooth.ts | 34 ++++++ mikazuki/schema/lora-master.ts | 34 ++++++ mikazuki/utils/train_utils.py | 2 +- requirements.txt | 22 ++-- scripts/stable/library/train_util.py | 139 ++++++++++++++++++++--- scripts/stable/requirements.txt | 18 +-- scripts/stable/sdxl_train.py | 163 ++++++++++++++++++++++++++- scripts/stable/sdxl_train_network.py | 16 +++ scripts/stable/train_network.py | 138 ++++++++++++++++++++++- train.sh | 13 +++ 13 files changed, 599 insertions(+), 45 deletions(-) create mode 100644 Example.toml diff --git a/.gitignore b/.gitignore index 1beb4ea2..0c719dea 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,6 @@ config/presets/test*.toml tests/ huggingface/hub/models* -huggingface/hub/version_diffusers_cache.txt \ No newline at end of file +huggingface/hub/version_diffusers_cache.txt +huggingface/ +huggingface \ No newline at end of file diff --git a/Example.toml b/Example.toml new file mode 100644 index 00000000..5f78960f --- /dev/null +++ b/Example.toml @@ -0,0 +1,57 @@ +model_train_type = "sdxl-lora" +pretrained_model_name_or_path = "/home/bluvoll/stable-diffusion-webui-reForge/models/Stable-diffusion/Chenkin-RF-0.3-Final-000001.safetensors" +train_data_dir = "/home/bluvoll/Escritorio/anime/kanojos/" +prior_loss_weight = 1 +resolution = "1024,1024" +enable_bucket = true +min_bucket_reso = 256 +max_bucket_reso = 2048 +bucket_reso_steps = 64 +bucket_no_upscale = true +output_name = "Kanojos-Aki-trainer" +output_dir = "./output" +save_model_as = "safetensors" +save_precision = "bf16" +save_every_n_epochs = 1 +save_state = false +max_train_epochs = 30 +train_batch_size = 12 +gradient_checkpointing = true +gradient_accumulation_steps = 1 +network_train_unet_only = true +network_train_text_encoder_only = false +learning_rate = 0.00004 +unet_lr = 0.00004 +text_encoder_lr = 0 +lr_scheduler = "cosine_with_restarts" +lr_warmup_steps = 5 +loss_type = "l2" +lr_scheduler_num_cycles = 2 +optimizer_type = "pytorch_optimizer.CAME" +network_module = "lycoris.kohya" +network_dim = 32 +network_alpha = 32 +log_with = "tensorboard" +logging_dir = "./logs" +caption_extension = ".txt" +shuffle_caption = true +keep_tokens = 0 +max_token_length = 255 +caption_dropout_rate = 0.1 +caption_tag_dropout_rate = 0.1 +flow_model = true +flow_use_ot = true +flow_timestep_distribution = "uniform" +flow_uniform_static_ratio = 2 +contrastive_flow_matching = false +cfm_lambda = 0.05 +seed = 1337 +mixed_precision = "bf16" +full_bf16 = true +xformers = true +lowram = false +cache_latents = true +cache_latents_to_disk = true +persistent_data_loader_workers = true +network_args = [ "conv_dim=32", "conv_alpha=32", "dropout=0", "algo=locon" ] +optimizer_args = [ "weight_decay=0.03" ] diff --git a/mikazuki/app/models.py b/mikazuki/app/models.py index fe4c400d..d8cf066b 100644 --- a/mikazuki/app/models.py +++ b/mikazuki/app/models.py @@ -32,8 +32,8 @@ class TaggerInterrogateRequest(BaseModel): class APIResponse(BaseModel): status: str - message: Optional[str] - data: Optional[Dict] + message: Optional[str] = None + data: Optional[Dict] = None class APIResponseSuccess(APIResponse): diff --git a/mikazuki/schema/dreambooth.ts b/mikazuki/schema/dreambooth.ts index c6a71a31..0cfcbd56 100644 --- a/mikazuki/schema/dreambooth.ts +++ b/mikazuki/schema/dreambooth.ts @@ -184,6 +184,40 @@ Schema.intersect([ multires_noise_discount: Schema.number().step(0.1).description("多分辨率(金字塔)衰减率 推荐 0.3-0.8,须同时与上方参数 multires_noise_iterations 一同启用"), }).description("噪声设置"), + // Rectified Flow 设置 (仅 SDXL 微调) + Schema.union([ + Schema.object({ + model_train_type: Schema.const("sdxl-finetune").required(), + }).extend(Schema.intersect([ + Schema.object({ + flow_model: Schema.boolean().default(false).description("启用 Rectified Flow 训练目标(用于 RF 模型微调)"), + }).description("Rectified Flow 设置"), + + Schema.union([ + Schema.object({ + flow_model: Schema.const(true).required(), + flow_use_ot: Schema.boolean().default(false).description("使用余弦最优传输配对 latent 和噪声"), + flow_timestep_distribution: Schema.union(["logit_normal", "uniform"]).default("logit_normal").description("时间步采样分布"), + flow_uniform_static_ratio: Schema.number().step(0.1).description("固定的时间步偏移比率(例如 2),留空不使用"), + contrastive_flow_matching: Schema.boolean().default(false).description("启用对比流匹配 (ΔFM) 目标"), + cfm_lambda: Schema.number().step(0.01).default(0.05).description("ΔFM 损失中对比项的权重"), + }), + Schema.object({}), + ]), + + Schema.union([ + Schema.object({ + flow_model: Schema.const(true).required(), + flow_timestep_distribution: Schema.const("logit_normal").required(), + flow_logit_mean: Schema.number().step(0.1).default(0.0).description("logit-normal 分布的均值"), + flow_logit_std: Schema.number().step(0.1).default(1.0).description("logit-normal 分布的标准差"), + }), + Schema.object({}), + ]), + ])), + Schema.object({}), + ]), + Schema.object({ seed: Schema.number().default(1337).description("随机种子"), clip_skip: Schema.number().role("slider").min(0).max(12).step(1).default(2).description("CLIP 跳过层数 *玄学*"), diff --git a/mikazuki/schema/lora-master.ts b/mikazuki/schema/lora-master.ts index 50e731a8..85b391c1 100644 --- a/mikazuki/schema/lora-master.ts +++ b/mikazuki/schema/lora-master.ts @@ -82,6 +82,40 @@ Schema.intersect([ // 噪声设置 SHARED_SCHEMAS.NOISE_SETTINGS, + // Rectified Flow 设置 (仅 SDXL LoRA) + Schema.union([ + Schema.object({ + model_train_type: Schema.const("sdxl-lora").required(), + }).extend(Schema.intersect([ + Schema.object({ + flow_model: Schema.boolean().default(false).description("启用 Rectified Flow 训练目标(用于 RF 模型微调)"), + }).description("Rectified Flow 设置"), + + Schema.union([ + Schema.object({ + flow_model: Schema.const(true).required(), + flow_use_ot: Schema.boolean().default(false).description("使用余弦最优传输配对 latent 和噪声"), + flow_timestep_distribution: Schema.union(["logit_normal", "uniform"]).default("logit_normal").description("时间步采样分布"), + flow_uniform_static_ratio: Schema.number().step(0.1).description("固定的时间步偏移比率(例如 2),留空不使用"), + contrastive_flow_matching: Schema.boolean().default(false).description("启用对比流匹配 (ΔFM) 目标"), + cfm_lambda: Schema.number().step(0.01).default(0.05).description("ΔFM 损失中对比项的权重"), + }), + Schema.object({}), + ]), + + Schema.union([ + Schema.object({ + flow_model: Schema.const(true).required(), + flow_timestep_distribution: Schema.const("logit_normal").required(), + flow_logit_mean: Schema.number().step(0.1).default(0.0).description("logit-normal 分布的均值"), + flow_logit_std: Schema.number().step(0.1).default(1.0).description("logit-normal 分布的标准差"), + }), + Schema.object({}), + ]), + ])), + Schema.object({}), + ]), + // 数据增强 SHARED_SCHEMAS.DATA_ENCHANCEMENT, diff --git a/mikazuki/utils/train_utils.py b/mikazuki/utils/train_utils.py index bea4d6f3..dc43d5fb 100644 --- a/mikazuki/utils/train_utils.py +++ b/mikazuki/utils/train_utils.py @@ -252,7 +252,7 @@ def get_total_images(path, recursive=True): def fix_config_types(config: dict): - keep_float_params = ["guidance_scale", "sigmoid_scale", "discrete_flow_shift"] + keep_float_params = ["guidance_scale", "sigmoid_scale", "discrete_flow_shift", "flow_uniform_static_ratio", "flow_logit_mean", "flow_logit_std", "cfm_lambda"] for k in keep_float_params: if k in config: config[k] = float(config[k]) diff --git a/requirements.txt b/requirements.txt index 8b19f9a9..7fccbe08 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,26 +1,26 @@ -accelerate==0.33.0 -transformers==4.44.0 -diffusers[torch]==0.25.0 +accelerate==1.10.1 +transformers==4.55.2 +diffusers[torch]==0.35.1 ftfy==6.1.1 # albumentations==1.3.0 opencv-python==4.8.1.78 einops==0.7.0 pytorch-lightning==1.9.0 -bitsandbytes==0.46.0 +bitsandbytes==0.48.1 lion-pytorch==0.1.2 schedulefree==1.4 -pytorch-optimizer==3.7.0 +pytorch-optimizer==3.8.0 prodigy-plus-schedule-free==1.9.0 prodigyopt==1.1.2 -tensorboard==2.10.1 -safetensors==0.4.4 +tensorboard==2.20.0 +safetensors==0.6.2 prodigy-plus-schedule-free # gradio==3.16.2 altair==4.2.2 easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 -huggingface-hub==0.24.5 +huggingface-hub==0.34.4 # for Image utils imagesize==1.4.1 # for T5XXL tokenizer (SD3/FLUX) @@ -34,9 +34,9 @@ pillow numpy==1.26.4 # <=2.0.0 gradio==3.44.2 -fastapi==0.95.1 -uvicorn==0.22.0 -wandb==0.16.2 +fastapi==0.104.1 +uvicorn==0.34.0 +wandb==0.21.1 httpx==0.24.1 # extra open-clip-torch==2.20.0 diff --git a/scripts/stable/library/train_util.py b/scripts/stable/library/train_util.py index 100ef475..41854fbe 100644 --- a/scripts/stable/library/train_util.py +++ b/scripts/stable/library/train_util.py @@ -5195,8 +5195,49 @@ def save_sd_model_on_train_end_common( huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) -def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device): - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") +def cosine_optimal_transport(X: torch.Tensor, Y: torch.Tensor, backend: str = "auto"): + """Compute an optimal assignment under cosine distance.""" + + X_norm = X / torch.norm(X, dim=1, keepdim=True) + Y_norm = Y / torch.norm(Y, dim=1, keepdim=True) + cost = -torch.mm(X_norm, Y_norm.t()) + + if backend == "cuda": + return _cuda_assignment(cost) + if backend == "scipy": + return _scipy_assignment(cost) + + try: + return _cuda_assignment(cost) + except (ImportError, RuntimeError): + return _scipy_assignment(cost) + + +def _cuda_assignment(cost: torch.Tensor): + from torch_linear_assignment import assignment_to_indices, batch_linear_assignment # type: ignore + + assignment = batch_linear_assignment(cost.unsqueeze(0)) + row_idx, col_idx = assignment_to_indices(assignment) + return cost, (row_idx, col_idx) + + +def _scipy_assignment(cost: torch.Tensor): + from scipy.optimize import linear_sum_assignment # type: ignore + + cost_np = cost.to(torch.float32).detach().cpu().numpy() + row_ind, col_ind = linear_sum_assignment(cost_np) + row = torch.from_numpy(row_ind).to(cost.device, torch.long) + col = torch.from_numpy(col_ind).to(cost.device, torch.long) + return cost, (row, col) + + +def get_timesteps_and_huber_c( + args, min_timestep, max_timestep, noise_scheduler, b_size, device, timesteps_override=None +): + if timesteps_override is not None: + timesteps = timesteps_override.to("cpu") + else: + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") if args.loss_type == "huber" or args.loss_type == "smooth_l1": if args.huber_schedule == "exponential": @@ -5220,7 +5261,13 @@ def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, return timesteps, huber_c -def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): +def get_noise_noisy_latents_and_timesteps( + args, + noise_scheduler, + latents, + pre_sampled_timesteps=None, + pixel_counts=None, +): # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) if args.noise_offset: @@ -5234,23 +5281,83 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount ) - # Sample a random timestep for each image b_size = latents.shape[0] - min_timestep = 0 if args.min_timestep is None else args.min_timestep - max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep + flow_model_enabled = getattr(args, "flow_model", False) + + if flow_model_enabled: + timestep_max = noise_scheduler.config.num_train_timesteps + distribution = getattr(args, "flow_timestep_distribution", "logit_normal") + if distribution == "logit_normal": + logits = torch.normal( + mean=getattr(args, "flow_logit_mean", 0.0), + std=getattr(args, "flow_logit_std", 1.0), + size=(b_size,), + device=latents.device, + ) + sigmas = torch.sigmoid(logits) + elif distribution == "uniform": + sigmas = torch.rand((b_size,), device=latents.device) + else: + raise ValueError(f"Unknown flow_timestep_distribution: {distribution}") - timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device) + shift_requested = ( + getattr(args, "flow_uniform_shift", False) or getattr(args, "flow_uniform_static_ratio", None) is not None + ) + if sigmas is not None and shift_requested: + static_ratio = getattr(args, "flow_uniform_static_ratio", None) + if static_ratio is not None: + if static_ratio <= 0: + raise ValueError("`flow_uniform_static_ratio` must be positive when used.") + ratios = torch.full((b_size,), float(static_ratio), device=latents.device, dtype=torch.float32) + else: + if pixel_counts is None: + raise ValueError("Resolution-dependent Rectified Flow shift requires pixel_counts.") + base_pixels = getattr(args, "flow_uniform_base_pixels", None) + if base_pixels is None or base_pixels <= 0: + raise ValueError("`flow_uniform_base_pixels` must be positive when using flow_uniform_shift.") + ratios = torch.sqrt( + torch.as_tensor(pixel_counts, device=latents.device, dtype=torch.float32) / float(base_pixels) + ) - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - if args.ip_noise_gamma: - if args.ip_noise_gamma_random_strength: - strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma - else: - strength = args.ip_noise_gamma - noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps) + t_ref = sigmas + sigmas = ratios * t_ref / (1 + (ratios - 1) * t_ref) + + timesteps = torch.clamp((sigmas * timestep_max).long(), 0, timestep_max - 1) + _, huber_c = get_timesteps_and_huber_c( + args, + 0, + noise_scheduler.config.num_train_timesteps, + noise_scheduler, + b_size, + latents.device, + timesteps_override=timesteps, + ) else: - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + min_timestep = 0 if args.min_timestep is None else args.min_timestep + max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep + timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device) + + if flow_model_enabled: + if getattr(args, "flow_use_ot", False) and b_size > 1: + with torch.no_grad(): + lat_flat = latents.view(b_size, -1) + noise_flat = noise.view(b_size, -1) + _, (_, col_indices) = cosine_optimal_transport(lat_flat, noise_flat) + noise = noise[col_indices.squeeze(0)] + + sigmas_view = sigmas.view(-1, 1, 1, 1) + noisy_latents = sigmas_view * noise + (1.0 - sigmas_view) * latents + else: + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.ip_noise_gamma: + if args.ip_noise_gamma_random_strength: + strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma + else: + strength = args.ip_noise_gamma + noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps) + else: + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) return noise, noisy_latents, timesteps, huber_c diff --git a/scripts/stable/requirements.txt b/scripts/stable/requirements.txt index e6e1bf6f..9b1d7f77 100644 --- a/scripts/stable/requirements.txt +++ b/scripts/stable/requirements.txt @@ -1,22 +1,22 @@ -accelerate==0.30.0 -transformers==4.44.0 -diffusers[torch]==0.25.0 +accelerate==1.10.1 +transformers==4.55.2 +diffusers[torch]==0.35.1 ftfy==6.1.1 # albumentations==1.3.0 opencv-python==4.8.1.78 einops==0.7.0 pytorch-lightning==1.9.0 -bitsandbytes==0.44.0 +bitsandbytes==0.48.1 prodigyopt==1.0 lion-pytorch==0.0.6 -tensorboard -safetensors==0.4.2 +tensorboard==2.20.0 +safetensors==0.6.2 # gradio==3.16.2 altair==4.2.2 easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 -huggingface-hub==0.24.5 +huggingface-hub==0.34.4 # for Image utils imagesize==1.4.1 # for BLIP captioning @@ -32,10 +32,12 @@ imagesize==1.4.1 # for cuda 12.1(default 11.8) # onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ -# this is for onnx: +# this is for onnx: # protobuf==3.20.3 # open clip for SDXL # open-clip-torch==2.20.0 +# for Rectified Flow OT +scipy # For logging rich==13.7.0 # for kohya_ss library diff --git a/scripts/stable/sdxl_train.py b/scripts/stable/sdxl_train.py index b533b274..98c2c93a 100644 --- a/scripts/stable/sdxl_train.py +++ b/scripts/stable/sdxl_train.py @@ -110,6 +110,63 @@ def train(args): not args.train_text_encoder or not args.cache_text_encoder_outputs ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + if getattr(args, "flow_model", False): + logger.info("Using Rectified Flow training objective.") + if args.v_parameterization: + raise ValueError("`--flow_model` is incompatible with `--v_parameterization`; Rectified Flow already predicts velocity.") + if args.min_snr_gamma: + logger.warning("`--min_snr_gamma` is ignored when Rectified Flow is enabled.") + args.min_snr_gamma = None + if args.debiased_estimation_loss: + logger.warning("`--debiased_estimation_loss` is ignored when Rectified Flow is enabled.") + args.debiased_estimation_loss = False + if args.scale_v_pred_loss_like_noise_pred: + logger.warning("`--scale_v_pred_loss_like_noise_pred` is ignored when Rectified Flow is enabled.") + args.scale_v_pred_loss_like_noise_pred = False + if args.v_pred_like_loss: + logger.warning("`--v_pred_like_loss` is ignored when Rectified Flow is enabled.") + args.v_pred_like_loss = None + if getattr(args, "zero_terminal_snr", False): + logger.warning("`--zero_terminal_snr` is ignored when Rectified Flow is enabled (RF does not use the noise scheduler).") + args.zero_terminal_snr = False + if getattr(args, "ip_noise_gamma", None): + logger.warning("`--ip_noise_gamma` is ignored when Rectified Flow is enabled.") + args.ip_noise_gamma = None + if getattr(args, "noise_offset", None): + logger.warning("`--noise_offset` is ignored when Rectified Flow is enabled.") + args.noise_offset = None + if getattr(args, "multires_noise_iterations", None): + logger.warning("`--multires_noise_iterations` is ignored when Rectified Flow is enabled.") + args.multires_noise_iterations = None + if args.flow_use_ot: + logger.info("Using cosine optimal transport pairing for Rectified Flow batches.") + shift_enabled = args.flow_uniform_shift or args.flow_uniform_static_ratio is not None + if args.flow_timestep_distribution == "logit_normal": + if args.flow_logit_std <= 0: + raise ValueError("`--flow_logit_std` must be positive.") + logger.info( + "Rectified Flow timesteps sampled from logit-normal distribution with " + f"mean={args.flow_logit_mean}, std={args.flow_logit_std}." + ) + elif args.flow_timestep_distribution == "uniform": + logger.info("Rectified Flow timesteps sampled uniformly in [0, 1].") + else: + raise ValueError(f"Unknown Rectified Flow timestep distribution: {args.flow_timestep_distribution}") + if shift_enabled: + if args.flow_uniform_static_ratio is not None: + if args.flow_uniform_static_ratio <= 0: + raise ValueError("`--flow_uniform_static_ratio` must be positive.") + logger.info( + f"Applying Rectified Flow timestep shift with static ratio={args.flow_uniform_static_ratio}." + ) + else: + logger.info( + f"Applying resolution-dependent Rectified Flow timestep shift with base pixels={args.flow_uniform_base_pixels}." + ) + + if getattr(args, "contrastive_flow_matching", False) and not (args.v_parameterization or getattr(args, "flow_model", False)): + raise ValueError("`--contrastive_flow_matching` requires either v-parameterization or Rectified Flow.") + if args.block_lr: block_lrs = [float(lr) for lr in args.block_lr.split(",")] assert ( @@ -690,10 +747,24 @@ def optimizer_hook(parameter: torch.Tensor): vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + # Compute pixel counts for resolution-dependent RF shift + needs_dynamic_shift = ( + getattr(args, "flow_model", False) and getattr(args, "flow_uniform_shift", False) + and getattr(args, "flow_uniform_static_ratio", None) is None + ) + if needs_dynamic_shift: + if target_size is None: + raise ValueError( + "Resolution-dependent Rectified Flow shift requires target size information in the batch." + ) + pixel_counts = (target_size[:, 0] * target_size[:, 1]).to(latents.device, torch.float32) + else: + pixel_counts = None + # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents + args, noise_scheduler, latents, pixel_counts=pixel_counts ) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -702,7 +773,9 @@ def optimizer_hook(parameter: torch.Tensor): with accelerator.autocast(): noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) - if args.v_parameterization: + if getattr(args, "flow_model", False): + target = noise - latents + elif args.v_parameterization: # v-parameterization training target = noise_scheduler.get_velocity(latents, noise, timesteps) else: @@ -719,6 +792,18 @@ def optimizer_hook(parameter: torch.Tensor): loss = train_util.conditional_loss( noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) + if getattr(args, "contrastive_flow_matching", False) and latents.size(0) > 1: + negative_latents = latents.roll(1, 0) + negative_noise = noise.roll(1, 0) + with torch.no_grad(): + if getattr(args, "flow_model", False): + target_negative = negative_noise - negative_latents + else: + target_negative = noise_scheduler.get_velocity(negative_latents, negative_noise, timesteps) + loss_contrastive = torch.nn.functional.mse_loss( + noise_pred.float(), target_negative.float(), reduction="none" + ) + loss = loss - args.cfm_lambda * loss_contrastive if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -734,9 +819,22 @@ def optimizer_hook(parameter: torch.Tensor): loss = loss.mean() # mean over batch dimension else: - loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c + per_pixel_loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) + if getattr(args, "contrastive_flow_matching", False) and latents.size(0) > 1: + negative_latents = latents.roll(1, 0) + negative_noise = noise.roll(1, 0) + with torch.no_grad(): + if getattr(args, "flow_model", False): + target_negative = negative_noise - negative_latents + else: + target_negative = noise_scheduler.get_velocity(negative_latents, negative_noise, timesteps) + loss_contrastive = torch.nn.functional.mse_loss( + noise_pred.float(), target_negative.float(), reduction="none" + ) + per_pixel_loss = per_pixel_loss - args.cfm_lambda * loss_contrastive + loss = per_pixel_loss.mean() accelerator.backward(loss) @@ -939,6 +1037,63 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", ) + parser.add_argument( + "--flow_model", + action="store_true", + help="enable Rectified Flow training objective / 启用 Rectified Flow 训练目标", + ) + parser.add_argument( + "--flow_use_ot", + action="store_true", + help="pair latents and noise with cosine optimal transport / 使用余弦最优传输配对 latent 和噪声", + ) + parser.add_argument( + "--flow_timestep_distribution", + type=str, + default="logit_normal", + choices=["logit_normal", "uniform"], + help="sampling distribution over Rectified Flow sigmas / Rectified Flow 时间步采样分布", + ) + parser.add_argument( + "--flow_logit_mean", + type=float, + default=0.0, + help="mean of the logit-normal distribution / logit-normal 分布的均值", + ) + parser.add_argument( + "--flow_logit_std", + type=float, + default=1.0, + help="stddev of the logit-normal distribution / logit-normal 分布的标准差", + ) + parser.add_argument( + "--flow_uniform_shift", + action="store_true", + help="apply resolution-dependent shift to RF timesteps / 对 RF 时间步应用分辨率依赖偏移", + ) + parser.add_argument( + "--flow_uniform_base_pixels", + type=float, + default=1024.0 * 1024.0, + help="reference pixel count for resolution-dependent shift / 分辨率依赖偏移的基准像素数", + ) + parser.add_argument( + "--flow_uniform_static_ratio", + type=float, + default=None, + help="fixed shift ratio for RF timesteps (e.g. 2); overrides resolution-based shift / 固定的时间步偏移比率", + ) + parser.add_argument( + "--contrastive_flow_matching", + action="store_true", + help="enable Contrastive Flow Matching (ΔFM) objective / 启用对比流匹配目标", + ) + parser.add_argument( + "--cfm_lambda", + type=float, + default=0.05, + help="lambda weight for the contrastive term in ΔFM loss / ΔFM 损失中对比项的权重", + ) return parser diff --git a/scripts/stable/sdxl_train_network.py b/scripts/stable/sdxl_train_network.py index 83969bb1..ba5a9ed8 100644 --- a/scripts/stable/sdxl_train_network.py +++ b/scripts/stable/sdxl_train_network.py @@ -164,6 +164,22 @@ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_cond noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) return noise_pred + def get_flow_pixel_counts(self, args, batch, latents): + """SDXL: 计算像素数用于分辨率依赖的RF时间步偏移 / Calculate pixel counts for resolution-dependent RF timestep shift.""" + if ( + getattr(args, "flow_model", False) + and getattr(args, "flow_uniform_shift", False) + and getattr(args, "flow_uniform_static_ratio", None) is None + ): + target_size = batch.get("target_sizes_hw") + if target_size is None: + raise ValueError( + "分辨率依赖的Rectified Flow偏移需要batch中包含target size信息 / " + "Resolution-dependent Rectified Flow shift requires target size information in the batch." + ) + return (target_size[:, 0] * target_size[:, 1]).to(latents.device, torch.float32) + return None + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) diff --git a/scripts/stable/train_network.py b/scripts/stable/train_network.py index 7bf125dc..5b69c66a 100644 --- a/scripts/stable/train_network.py +++ b/scripts/stable/train_network.py @@ -134,6 +134,10 @@ def all_reduce_network(self, accelerator, network): def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) + def get_flow_pixel_counts(self, args, batch, latents): + """Calculate pixel counts for resolution-dependent RF timestep shift (SDXL override).""" + return None + def train(self, args): session_id = random.randint(0, 2**32) training_started_at = time.time() @@ -142,6 +146,61 @@ def train(self, args): deepspeed_utils.prepare_deepspeed_args(args) setup_logging(args, reset=True) + # Rectified Flow 验证 | Rectified Flow validation + if getattr(args, "flow_model", False): + logger.info("使用Rectified Flow训练目标 / Using Rectified Flow training objective.") + if args.v_parameterization: + raise ValueError("`--flow_model` 与 `--v_parameterization` 不兼容;Rectified Flow已预测速度 / `--flow_model` is incompatible with `--v_parameterization`; Rectified Flow already predicts velocity.") + if args.min_snr_gamma: + logger.warning("`--min_snr_gamma` 在Rectified Flow启用时被忽略 / `--min_snr_gamma` is ignored when Rectified Flow is enabled.") + args.min_snr_gamma = None + if args.debiased_estimation_loss: + logger.warning("`--debiased_estimation_loss` 在Rectified Flow启用时被忽略 / is ignored when Rectified Flow is enabled.") + args.debiased_estimation_loss = False + if args.scale_v_pred_loss_like_noise_pred: + logger.warning("`--scale_v_pred_loss_like_noise_pred` 在Rectified Flow启用时被忽略 / is ignored when Rectified Flow is enabled.") + args.scale_v_pred_loss_like_noise_pred = False + if getattr(args, "v_pred_like_loss", None): + logger.warning("`--v_pred_like_loss` 在Rectified Flow启用时被忽略 / is ignored when Rectified Flow is enabled.") + args.v_pred_like_loss = None + if getattr(args, "zero_terminal_snr", False): + logger.warning("`--zero_terminal_snr` 在Rectified Flow启用时被忽略(RF不使用噪声调度器) / is ignored when Rectified Flow is enabled.") + args.zero_terminal_snr = False + if getattr(args, "ip_noise_gamma", None): + logger.warning("`--ip_noise_gamma` 在Rectified Flow启用时被忽略 / is ignored when Rectified Flow is enabled.") + args.ip_noise_gamma = None + if getattr(args, "noise_offset", None): + logger.warning("`--noise_offset` 在Rectified Flow启用时被忽略 / is ignored when Rectified Flow is enabled.") + args.noise_offset = None + if getattr(args, "multires_noise_iterations", None): + logger.warning("`--multires_noise_iterations` 在Rectified Flow启用时被忽略 / is ignored when Rectified Flow is enabled.") + args.multires_noise_iterations = None + if getattr(args, "flow_use_ot", False): + logger.info("使用余弦最优传输配对 / Using cosine optimal transport pairing for Rectified Flow batches.") + distribution = getattr(args, "flow_timestep_distribution", "logit_normal") + if distribution == "logit_normal": + if args.flow_logit_std <= 0: + raise ValueError("`--flow_logit_std` 必须为正数 / must be positive.") + logger.info( + f"Rectified Flow时间步采样: logit-normal分布, mean={args.flow_logit_mean}, std={args.flow_logit_std}" + ) + elif distribution == "uniform": + logger.info("Rectified Flow时间步采样: 均匀分布 / sampled uniformly in [0, 1].") + shift_enabled = getattr(args, "flow_uniform_shift", False) or getattr(args, "flow_uniform_static_ratio", None) is not None + if shift_enabled: + static_ratio = getattr(args, "flow_uniform_static_ratio", None) + if static_ratio is not None: + if static_ratio <= 0: + raise ValueError("`--flow_uniform_static_ratio` 必须为正数 / must be positive.") + logger.info(f"应用固定比率时间步偏移: ratio={static_ratio} / Applying static timestep shift.") + else: + logger.info( + f"应用分辨率依赖的时间步偏移, base_pixels={getattr(args, 'flow_uniform_base_pixels', 1024*1024)} / Applying resolution-dependent timestep shift." + ) + + if getattr(args, "contrastive_flow_matching", False) and not (args.v_parameterization or getattr(args, "flow_model", False)): + raise ValueError("`--contrastive_flow_matching` 需要v-parameterization或Rectified Flow / requires either v-parameterization or Rectified Flow.") + cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None use_user_config = args.dataset_config is not None @@ -951,8 +1010,9 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified + pixel_counts = self.get_flow_pixel_counts(args, batch, latents) noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents + args, noise_scheduler, latents, pixel_counts=pixel_counts ) # ensure the hidden state will require grad @@ -975,7 +1035,9 @@ def remove_model(old_ckpt_name): weight_dtype, ) - if args.v_parameterization: + if getattr(args, "flow_model", False): + target = noise - latents # Rectified Flow: 速度目标 = 噪声 - latent + elif args.v_parameterization: # v-parameterization training target = noise_scheduler.get_velocity(latents, noise, timesteps) else: @@ -984,6 +1046,18 @@ def remove_model(old_ckpt_name): loss = train_util.conditional_loss( noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) + if getattr(args, "contrastive_flow_matching", False) and latents.size(0) > 1: + negative_latents = latents.roll(1, 0) + negative_noise = noise.roll(1, 0) + with torch.no_grad(): + if getattr(args, "flow_model", False): + target_negative = negative_noise - negative_latents + else: + target_negative = noise_scheduler.get_velocity(negative_latents, negative_noise, timesteps) + loss_contrastive = torch.nn.functional.mse_loss( + noise_pred.float(), target_negative.float(), reduction="none" + ) + loss = loss - args.cfm_lambda * loss_contrastive if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -1228,6 +1302,66 @@ def setup_parser() -> argparse.ArgumentParser: # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") + + # Rectified Flow 参数 | Rectified Flow arguments + parser.add_argument( + "--flow_model", + action="store_true", + help="启用Rectified Flow训练目标 / enable Rectified Flow training objective instead of standard diffusion", + ) + parser.add_argument( + "--flow_use_ot", + action="store_true", + help="使用余弦最优传输配对latent和噪声 / pair latents and noise with cosine optimal transport when using Rectified Flow", + ) + parser.add_argument( + "--flow_timestep_distribution", + type=str, + default="logit_normal", + choices=["logit_normal", "uniform"], + help="Rectified Flow的时间步采样分布 (默认: logit_normal) / sampling distribution over Rectified Flow sigmas", + ) + parser.add_argument( + "--flow_logit_mean", + type=float, + default=0.0, + help="logit-normal分布的均值 / mean of the logit-normal distribution when using Rectified Flow", + ) + parser.add_argument( + "--flow_logit_std", + type=float, + default=1.0, + help="logit-normal分布的标准差 / stddev of the logit-normal distribution when using Rectified Flow", + ) + parser.add_argument( + "--flow_uniform_shift", + action="store_true", + help="对Rectified Flow时间步应用分辨率依赖的偏移 / apply resolution-dependent shift to Rectified Flow timesteps (SD3-style)", + ) + parser.add_argument( + "--flow_uniform_base_pixels", + type=float, + default=1024.0 * 1024.0, + help="时间步偏移使用的基准像素数 / reference pixel count used for the resolution-dependent timestep shift", + ) + parser.add_argument( + "--flow_uniform_static_ratio", + type=float, + default=None, + help="使用固定的sqrt(m/n)比率进行时间步偏移 / use a fixed sqrt(m/n) ratio for Rectified Flow timestep shift; overrides resolution-based shift", + ) + parser.add_argument( + "--contrastive_flow_matching", + action="store_true", + help="启用对比流匹配(ΔFM)目标 / Enable Contrastive Flow Matching (ΔFM) objective. Works with v-parameterization or Rectified Flow.", + ) + parser.add_argument( + "--cfm_lambda", + type=float, + default=0.05, + help="ΔFM损失中对比项的权重 (默认: 0.05) / Lambda weight for the contrastive term in ΔFM loss", + ) + return parser diff --git a/train.sh b/train.sh index b18b9cc8..b93cccf5 100644 --- a/train.sh +++ b/train.sh @@ -29,6 +29,12 @@ noise_offset="0" # noise offset | 在训练中添加噪声偏移来改良生成 keep_tokens=0 # keep heading N tokens when shuffling caption tokens | 在随机打乱 tokens 时,保留前 N 个不变。 min_snr_gamma=0 # minimum signal-to-noise ratio (SNR) value for gamma-ray | 伽马射线事件的最小信噪比(SNR)值 默认为 0 +# Rectified Flow 设置 | Rectified Flow settings +use_flow_model=0 # 启用 Rectified Flow 训练目标 | enable Rectified Flow training objective +flow_use_ot=0 # 使用余弦最优传输配对 | use cosine optimal transport pairing +flow_timestep_distribution="logit_normal" # 时间步分布: logit_normal 或 uniform | timestep distribution +flow_uniform_static_ratio="" # 固定的时间步偏移比率 (例如 2) 留空不使用 | fixed timestep shift ratio (e.g. 2), leave empty to disable + # Learning rate | 学习率 lr="1e-4" # learning rate | 学习率,在分别设置下方 U-Net 和 文本编码器 的学习率时,该参数失效 unet_lr="1e-4" # U-Net learning rate | U-Net 学习率 @@ -123,6 +129,13 @@ if [[ $noise_offset != "0" ]]; then extArgs+=("--noise_offset $noise_offset"); f if [[ $min_snr_gamma -ne 0 ]]; then extArgs+=("--min_snr_gamma $min_snr_gamma"); fi +if [[ $use_flow_model == 1 ]]; then + extArgs+=("--flow_model") + extArgs+=("--flow_timestep_distribution $flow_timestep_distribution") + if [[ $flow_use_ot == 1 ]]; then extArgs+=("--flow_use_ot"); fi + if [[ $flow_uniform_static_ratio ]]; then extArgs+=("--flow_uniform_static_ratio $flow_uniform_static_ratio"); fi +fi + if [[ $use_wandb == 1 ]]; then extArgs+=("--log_with=all") if [[ $wandb_api_key ]]; then extArgs+=("--wandb_api_key $wandb_api_key"); fi