diff --git a/README.md b/README.md index b988f9b3..cefd1087 100755 --- a/README.md +++ b/README.md @@ -11,13 +11,14 @@ Wan-Fun: English | [简体中文](./README_zh-CN.md) | [日本語](./README_ja-JP.md) # Table of Contents -- [Table of Contents](#table-of-contents) - [Introduction](#introduction) - [Quick Start](#quick-start) - [Video Result](#video-result) -- [How to use](#how-to-use) +- [How to Use](#how-to-use) - [Model zoo](#model-zoo) - [Reference](#reference) +- [Citation](#citation) +- [Limitations and Risks](#limitations-and-risks) - [License](#license) # Introduction @@ -699,6 +700,31 @@ V1.1: - ComfyUI-CameraCtrl-Wrapper: https://github.com/chaojie/ComfyUI-CameraCtrl-Wrapper - CameraCtrl: https://github.com/hehao13/CameraCtrl +# Citation + +If you use VideoX-Fun in your research or project, please cite it as follows: + +```bibtex +@misc{aigc_apps_VideoX_Fun_2026, + author = {aigc-apps}, + title = {VideoX-Fun: A Video Generation Pipeline for Diffusion Transformer}, + year = {2026}, + publisher = {GitHub}, + url = {https://github.com/aigc-apps/VideoX-Fun} +} +``` + +# Limitations and Risks + +- Generated videos may have artifacts or quality issues, especially in complex scenes. +- The model may struggle with fine details, text rendering, or specific artistic styles. +- Performance varies with input prompt quality, resolution, and other parameters. +- The technology could be misused to create misleading content (e.g., deepfakes). Users are responsible for ethical use. +- The model may reflect biases present in the training data. +- Users should respect privacy and copyright when using real people's images or videos. + +We encourage responsible use and recommend implementing safeguards in production environments. + # License This project is licensed under the [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE). 
diff --git a/README_ja-JP.md b/README_ja-JP.md index 86d17cde..f5b224bf 100755 --- a/README_ja-JP.md +++ b/README_ja-JP.md @@ -11,13 +11,14 @@ Wan-Fun: [English](./README.md) | [简体中文](./README_zh-CN.md) | 日本語 # 目次 -- [目次](#目次) - [紹介](#紹介) - [クイックスタート](#クイックスタート) - [ビデオ結果](#ビデオ結果) - [使用方法](#使用方法) - [モデルの場所](#モデルの場所) - [参考文献](#参考文献) +- [引用](#引用) +- [制限とリスク](#制限とリスク) - [ライセンス](#ライセンス) # 紹介 @@ -699,6 +700,31 @@ V1.1: - ComfyUI-CameraCtrl-Wrapper: https://github.com/chaojie/ComfyUI-CameraCtrl-Wrapper - CameraCtrl: https://github.com/hehao13/CameraCtrl +# 引用 + +研究やプロジェクトでVideoX-Funを使用する場合は、以下の形式で引用してください: + +```bibtex +@misc{aigc_apps_VideoX_Fun_2026, + author = {aigc-apps}, + title = {VideoX-Fun: A Video Generation Pipeline for Diffusion Transformer}, + year = {2026}, + publisher = {GitHub}, + url = {https://github.com/aigc-apps/VideoX-Fun} +} +``` + +# 制限とリスク + +- 生成された動画には、特に複雑なシーンでアーティファクトや品質の問題がある場合があります。 +- モデルは、細かい詳細、テキストのレンダリング、または特定の芸術スタイルで苦労する場合があります。 +- パフォーマンスは、入力プロンプトの品質、解像度、その他のパラメータによって異なります。 +- この技術は、誤解を招くコンテンツ(例:ディープフェイク)を作成するために悪用される可能性があります。ユーザーは倫理的な使用に責任を持ちます。 +- モデルは、トレーニングデータに存在するバイアスを反映する可能性があります。 +- ユーザーは、実在の人物の画像や動画を使用する際、プライバシーと著作権を尊重する必要があります。 + +責任ある使用を推奨し、本番環境でのセーフガードの実装をお勧めします。 + # ライセンス このプロジェクトは[Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE)の下でライセンスされています。 diff --git a/README_zh-CN.md b/README_zh-CN.md index 045ecca1..c3e98f1c 100755 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -11,13 +11,14 @@ Wan-Fun: [English](./README.md) | 简体中文 | [日本語](./README_ja-JP.md) # 目录 -- [目录](#目录) - [简介](#简介) - [快速启动](#快速启动) - [视频作品](#视频作品) - [如何使用](#如何使用) - [模型地址](#模型地址) - [参考文献](#参考文献) +- [引用](#引用) +- [限制与风险](#限制与风险) - [许可证](#许可证) # 简介 @@ -689,6 +690,31 @@ V1.1: - ComfyUI-CameraCtrl-Wrapper: https://github.com/chaojie/ComfyUI-CameraCtrl-Wrapper - CameraCtrl: https://github.com/hehao13/CameraCtrl +# 引用 + +如果您在研究或项目中使用了 VideoX-Fun,请按以下格式引用: + +```bibtex +@misc{aigc_apps_VideoX_Fun_2026, + author = {aigc-apps}, + title = {VideoX-Fun: A Video Generation Pipeline for Diffusion Transformer}, + year = {2026}, + publisher = {GitHub}, + url = {https://github.com/aigc-apps/VideoX-Fun} +} +``` + +# 限制与风险 + +- 生成的视频可能存在伪影或质量问题,尤其在复杂场景中。 +- 模型在处理精细细节、文字渲染或特定艺术风格时可能有困难。 +- 性能因输入提示词质量、分辨率等参数而异。 +- 该技术可能被滥用于创建误导性内容(如深度伪造)。用户需对道德使用负责。 +- 模型可能反映训练数据中存在的偏见。 +- 用户在使用真人图片或视频时应尊重隐私和版权。 + +我们鼓励负责任地使用该技术,并建议在生产环境中实施安全措施。 + # 许可证 本项目采用 [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE). diff --git a/examples/flux/predict_t2i.py b/examples/flux/predict_t2i.py index f3f867cf..fe7cf93a 100644 --- a/examples/flux/predict_t2i.py +++ b/examples/flux/predict_t2i.py @@ -70,7 +70,7 @@ # ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16 weight_dtype = torch.bfloat16 prompt = "1girl, black_hair, brown_eyes, earrings, freckles, grey_background, jewelry, lips, long_hair, looking_at_viewer, nose, piercing, realistic, red_lips, solo, upper_body" -negative_prompt = "The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. 
" +negative_prompt = " " guidance_scale = 1.0 seed = 43 num_inference_steps = 50 diff --git a/examples/ltx2.3/predict_i2v.py b/examples/ltx2.3/predict_i2v.py new file mode 100644 index 00000000..87dbb5e8 --- /dev/null +++ b/examples/ltx2.3/predict_i2v.py @@ -0,0 +1,305 @@ +import os +import sys + +import numpy as np +import torch +from diffusers import FlowMatchEulerDiscreteScheduler +from PIL import Image + +current_file_path = os.path.abspath(__file__) +project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))] +for project_root in project_roots: + sys.path.insert(0, project_root) if project_root not in sys.path else None + +from videox_fun.models import (AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, + Gemma3ForConditionalGeneration, + GemmaTokenizerFast, LTX2TextConnectors, Gemma3Processor, + LTX2VideoTransformer3DModel, LTX2VocoderWithBWE) +from videox_fun.pipeline import LTX2I2VPipeline +from videox_fun.utils import (register_auto_device_hook, + safe_enable_group_offload) +from videox_fun.dist import set_multi_gpus_devices, shard_model +from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler +from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from videox_fun.utils.fp8_optimization import (convert_model_weight_to_float8, + convert_weight_dtype_wrapper, + replace_parameters_by_name) +from videox_fun.utils.lora_utils import merge_lora, unmerge_lora +from videox_fun.utils.utils import (filter_kwargs, get_image_to_video_latent, + save_videos_grid, + save_videos_with_audio_grid) + +# GPU memory mode, which can be chosen in [model_full_load, model_full_load_and_qfloat8, model_cpu_offload, model_cpu_offload_and_qfloat8, sequential_cpu_offload]. +# model_full_load means that the entire model will be moved to the GPU. +# +# model_full_load_and_qfloat8 means that the entire model will be moved to the GPU, +# and the transformer model has been quantized to float8, which can save more GPU memory. +# +# model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory. +# +# model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use, +# and the transformer model has been quantized to float8, which can save more GPU memory. +# +# model_group_offload transfers internal layer groups between CPU/CUDA, +# balancing memory efficiency and speed between full-module and leaf-level offloading methods. +# +# sequential_cpu_offload means that each layer of the model will be moved to the CPU after use, +# resulting in slower speeds but saving a large amount of GPU memory. +GPU_memory_mode = "model_group_offload" +# Multi GPUs config +# Please ensure that the product of ulysses_degree and ring_degree equals the number of GPUs used. +# For example, if you are using 8 GPUs, you can set ulysses_degree = 2 and ring_degree = 4. +# If you are using 1 GPU, you can set ulysses_degree = 1 and ring_degree = 1. +ulysses_degree = 1 +ring_degree = 1 +# Use FSDP to save more GPU memory in multi gpus. +fsdp_dit = False +fsdp_text_encoder = False +# Compile will give a speedup in fixed resolution and need a little GPU memory. +# The compile_dit is not compatible with sequential_cpu_offload. 
+compile_dit = False + +# model path +model_name = "models/Diffusion_Transformer/LTX-2.3-Diffusers" +# Choose the sampler in "Flow", "Flow_Unipc", "Flow_DPM++" +sampler_name = "Flow" + +# Load pretrained model if need +transformer_path = None +vae_path = None +lora_path = None + +# Other params +sample_size = [512, 768] +video_length = 121 +fps = 24 + +# Use torch.float16 if GPU does not support torch.bfloat16 +# ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16 +weight_dtype = torch.bfloat16 +# If you want to generate from text, please set the validation_image_start = None and validation_image_end = None +validation_image_start = "asset/1.png" + +# prompts +prompt = "A brown dog barks on a sofa, sitting on a light-colored couch in a cozy room. Behind the dog, there is a framed painting on a shelf, surrounded by pink flowers. " +negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted, static, low quality, artifacts" +# CFG guidance scale for video and audio modality +guidance_scale = 3.0 +audio_guidance_scale = 7.0 +# Spatio-Temporal Guidance (STG) scale for video and audio +stg_scale = 1.0 +audio_stg_scale = 1.0 +# Modality isolation guidance scale for video and audio +modality_scale = 3.0 +audio_modality_scale = 3.0 +# Guidance rescale factor for video and audio to prevent overexposure +guidance_rescale = 0.7 +audio_guidance_rescale = 0.7 +spatio_temporal_guidance_blocks = [28] +seed = 43 +num_inference_steps = 50 +lora_weight = 0.55 +save_path = "samples/ltx2-videos-i2v" + +# Audio sample rate will be read from vocoder config +audio_sample_rate = 24000 + +device = set_multi_gpus_devices(ulysses_degree, ring_degree) + +# Transformer +transformer = LTX2VideoTransformer3DModel.from_pretrained( + model_name, + subfolder="transformer", + low_cpu_mem_usage=True, + torch_dtype=weight_dtype, +) + +if transformer_path is not None: + print(f"From checkpoint: {transformer_path}") + if transformer_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(transformer_path) + else: + state_dict = torch.load(transformer_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = transformer.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + +# Video VAE +vae = AutoencoderKLLTX2Video.from_pretrained( + model_name, + subfolder="vae", + torch_dtype=weight_dtype, +) + +if vae_path is not None: + print(f"From checkpoint: {vae_path}") + if vae_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(vae_path) + else: + state_dict = torch.load(vae_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = vae.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + +# Audio VAE +audio_vae = AutoencoderKLLTX2Audio.from_pretrained( + model_name, + subfolder="audio_vae", + torch_dtype=weight_dtype, +) + +# Get Processor +processor = Gemma3Processor.from_pretrained( + model_name, + subfolder="processor", +) + +# Get Tokenizer +tokenizer = processor.tokenizer + +# Get Text encoder +text_encoder = Gemma3ForConditionalGeneration.from_pretrained( + model_name, + subfolder="text_encoder", + low_cpu_mem_usage=True, + torch_dtype=weight_dtype, +) +text_encoder = text_encoder.eval() + +# Connectors +connectors = LTX2TextConnectors.from_pretrained( 
+ model_name, + subfolder="connectors", + torch_dtype=weight_dtype, +) + +# Vocoder +vocoder = LTX2VocoderWithBWE.from_pretrained( + model_name, + subfolder="vocoder", + torch_dtype=weight_dtype, +) + +# Get Scheduler +Chosen_Scheduler = { + "Flow": FlowMatchEulerDiscreteScheduler, + "Flow_Unipc": FlowUniPCMultistepScheduler, + "Flow_DPM++": FlowDPMSolverMultistepScheduler, +}[sampler_name] +scheduler = Chosen_Scheduler.from_pretrained( + model_name, + subfolder="scheduler" +) + +pipeline = LTX2I2VPipeline( + scheduler=scheduler, + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + processor=processor, + connectors=connectors, + transformer=transformer, + vocoder=vocoder, +) + +if ulysses_degree > 1 or ring_degree > 1: + from functools import partial + transformer.enable_multi_gpus_inference() + if fsdp_dit: + shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype, + module_to_wrapper=list(transformer.transformer_blocks)) + pipeline.transformer = shard_fn(pipeline.transformer) + print("Add FSDP DIT") + if fsdp_text_encoder: + shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype, + module_to_wrapper=text_encoder.language_model.layers) + text_encoder = shard_fn(text_encoder) + print("Add FSDP TEXT ENCODER") + +if compile_dit: + for i in range(len(pipeline.transformer.transformer_blocks)): + pipeline.transformer.transformer_blocks[i] = torch.compile(pipeline.transformer.transformer_blocks[i]) + print("Add Compile") + +if GPU_memory_mode == "sequential_cpu_offload": + pipeline.enable_sequential_cpu_offload(device=device) +elif GPU_memory_mode == "model_group_offload": + register_auto_device_hook(pipeline.transformer) + safe_enable_group_offload(pipeline, onload_device=device, offload_device="cpu", offload_type="leaf_level", use_stream=True) +elif GPU_memory_mode == "model_cpu_offload_and_qfloat8": + convert_model_weight_to_float8(transformer, exclude_module_name=["scale_shift_table", "audio_scale_shift_table", "video_a2v_cross_attn_scale_shift_table", "audio_a2v_cross_attn_scale_shift_table", ""], device=device) + convert_weight_dtype_wrapper(transformer, weight_dtype) + pipeline.enable_model_cpu_offload(device=device) +elif GPU_memory_mode == "model_cpu_offload": + pipeline.enable_model_cpu_offload(device=device) +elif GPU_memory_mode == "model_full_load_and_qfloat8": + convert_model_weight_to_float8(transformer, exclude_module_name=["scale_shift_table", "audio_scale_shift_table", "video_a2v_cross_attn_scale_shift_table", "audio_a2v_cross_attn_scale_shift_table", ""], device=device) + convert_weight_dtype_wrapper(transformer, weight_dtype) + pipeline.to(device=device) +else: + pipeline.to(device=device) + +generator = torch.Generator(device=device).manual_seed(seed) + +if lora_path is not None: + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) + +with torch.no_grad(): + output = pipeline( + image=Image.open(validation_image_start), + prompt=prompt, + negative_prompt=negative_prompt, + height=sample_size[0], + width=sample_size[1], + num_frames=video_length, + frame_rate=fps, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + stg_scale=stg_scale, + modality_scale=modality_scale, + guidance_rescale=guidance_rescale, + audio_guidance_scale=audio_guidance_scale, + audio_stg_scale=audio_stg_scale, + audio_modality_scale=audio_modality_scale, + audio_guidance_rescale=audio_guidance_rescale, + 
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + generator=generator, + output_type="pt", + ) + +if lora_path is not None: + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) + +sample = output.videos +audio = output.audio + +def save_results(): + if not os.path.exists(save_path): + os.makedirs(save_path, exist_ok=True) + + index = len([path for path in os.listdir(save_path)]) + 1 + prefix = str(index).zfill(8) + if video_length == 1: + video_path = os.path.join(save_path, prefix + ".png") + + image = sample[0, :, 0] + image = image.transpose(0, 1).transpose(1, 2) + image = (image * 255).numpy().astype(np.uint8) + image = Image.fromarray(image) + image.save(video_path) + else: + video_path = os.path.join(save_path, prefix + ".mp4") + sr = getattr(pipeline.vocoder.config, "output_sampling_rate", audio_sample_rate) + save_videos_with_audio_grid(sample, audio, video_path, fps=fps, audio_sample_rate=sr) + +if ulysses_degree * ring_degree > 1: + import torch.distributed as dist + if dist.get_rank() == 0: + save_results() +else: + save_results() \ No newline at end of file diff --git a/examples/ltx2.3/predict_t2v.py b/examples/ltx2.3/predict_t2v.py new file mode 100644 index 00000000..ae05f069 --- /dev/null +++ b/examples/ltx2.3/predict_t2v.py @@ -0,0 +1,300 @@ +import os +import sys + +import numpy as np +import torch +from diffusers import FlowMatchEulerDiscreteScheduler +from PIL import Image + +current_file_path = os.path.abspath(__file__) +project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))] +for project_root in project_roots: + sys.path.insert(0, project_root) if project_root not in sys.path else None + +from videox_fun.models import (AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, + Gemma3ForConditionalGeneration, + GemmaTokenizerFast, LTX2TextConnectors, Gemma3Processor, + LTX2VideoTransformer3DModel, LTX2VocoderWithBWE) +from videox_fun.pipeline import LTX2Pipeline +from videox_fun.utils import (register_auto_device_hook, + safe_enable_group_offload) +from videox_fun.dist import set_multi_gpus_devices, shard_model +from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler +from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from videox_fun.utils.fp8_optimization import (convert_model_weight_to_float8, + convert_weight_dtype_wrapper, + replace_parameters_by_name) +from videox_fun.utils.lora_utils import merge_lora, unmerge_lora +from videox_fun.utils.utils import (filter_kwargs, get_image_to_video_latent, + save_videos_grid, + save_videos_with_audio_grid) + +# GPU memory mode, which can be chosen in [model_full_load, model_full_load_and_qfloat8, model_cpu_offload, model_cpu_offload_and_qfloat8, sequential_cpu_offload]. +# model_full_load means that the entire model will be moved to the GPU. +# +# model_full_load_and_qfloat8 means that the entire model will be moved to the GPU, +# and the transformer model has been quantized to float8, which can save more GPU memory. +# +# model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory. +# +# model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use, +# and the transformer model has been quantized to float8, which can save more GPU memory. 
+# +# model_group_offload transfers internal layer groups between CPU/CUDA, +# balancing memory efficiency and speed between full-module and leaf-level offloading methods. +# +# sequential_cpu_offload means that each layer of the model will be moved to the CPU after use, +# resulting in slower speeds but saving a large amount of GPU memory. +GPU_memory_mode = "model_group_offload" +# Multi GPUs config +# Please ensure that the product of ulysses_degree and ring_degree equals the number of GPUs used. +# For example, if you are using 8 GPUs, you can set ulysses_degree = 2 and ring_degree = 4. +# If you are using 1 GPU, you can set ulysses_degree = 1 and ring_degree = 1. +ulysses_degree = 1 +ring_degree = 1 +# Use FSDP to save more GPU memory in multi gpus. +fsdp_dit = False +fsdp_text_encoder = False +# Compile will give a speedup in fixed resolution and need a little GPU memory. +# The compile_dit is not compatible with sequential_cpu_offload. +compile_dit = False + +# model path +model_name = "models/Diffusion_Transformer/LTX-2.3-Diffusers" +# Choose the sampler in "Flow", "Flow_Unipc", "Flow_DPM++" +sampler_name = "Flow" + +# Load pretrained model if need +transformer_path = None +vae_path = None +lora_path = None + +# Other params +sample_size = [512, 768] +video_length = 121 +fps = 24 + +# Use torch.float16 if GPU does not support torch.bfloat16 +# ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16 +weight_dtype = torch.bfloat16 +prompt = "A brown dog barks on a sofa, sitting on a light-colored couch in a cozy room. Behind the dog, there is a framed painting on a shelf, surrounded by pink flowers. " +negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted, static, low quality, artifacts" +# CFG guidance scale for video and audio modality +guidance_scale = 3.0 +audio_guidance_scale = 7.0 +# Spatio-Temporal Guidance (STG) scale for video and audio +stg_scale = 1.0 +audio_stg_scale = 1.0 +# Modality isolation guidance scale for video and audio +modality_scale = 3.0 +audio_modality_scale = 3.0 +# Guidance rescale factor for video and audio to prevent overexposure +guidance_rescale = 0.7 +audio_guidance_rescale = 0.7 +spatio_temporal_guidance_blocks = [28] +seed = 43 +num_inference_steps = 50 +lora_weight = 0.55 +save_path = "samples/ltx2-videos-t2v" + +# Audio sample rate will be read from vocoder config +audio_sample_rate = 24000 + +device = set_multi_gpus_devices(ulysses_degree, ring_degree) + +# Transformer +transformer = LTX2VideoTransformer3DModel.from_pretrained( + model_name, + subfolder="transformer", + low_cpu_mem_usage=True, + torch_dtype=weight_dtype, +) + +if transformer_path is not None: + print(f"From checkpoint: {transformer_path}") + if transformer_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(transformer_path) + else: + state_dict = torch.load(transformer_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = transformer.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + +# Video VAE +vae = AutoencoderKLLTX2Video.from_pretrained( + model_name, + subfolder="vae", + torch_dtype=weight_dtype, +) + +if vae_path is not None: + print(f"From checkpoint: {vae_path}") + if vae_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(vae_path) + else: + state_dict = torch.load(vae_path, map_location="cpu") 
+ state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = vae.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + +# Audio VAE +audio_vae = AutoencoderKLLTX2Audio.from_pretrained( + model_name, + subfolder="audio_vae", + torch_dtype=weight_dtype, +) + +# Get Processor +processor = Gemma3Processor.from_pretrained( + model_name, + subfolder="processor", +) + +# Get Tokenizer +tokenizer = processor.tokenizer + +# Get Text encoder +text_encoder = Gemma3ForConditionalGeneration.from_pretrained( + model_name, + subfolder="text_encoder", + low_cpu_mem_usage=True, + torch_dtype=weight_dtype, +) +text_encoder = text_encoder.eval() + +# Connectors +connectors = LTX2TextConnectors.from_pretrained( + model_name, + subfolder="connectors", + torch_dtype=weight_dtype, +) + +# Vocoder +vocoder = LTX2VocoderWithBWE.from_pretrained( + model_name, + subfolder="vocoder", + torch_dtype=weight_dtype, +) + +# Get Scheduler +Chosen_Scheduler = { + "Flow": FlowMatchEulerDiscreteScheduler, + "Flow_Unipc": FlowUniPCMultistepScheduler, + "Flow_DPM++": FlowDPMSolverMultistepScheduler, +}[sampler_name] +scheduler = Chosen_Scheduler.from_pretrained( + model_name, + subfolder="scheduler" +) + +pipeline = LTX2Pipeline( + scheduler=scheduler, + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + processor=processor, + connectors=connectors, + transformer=transformer, + vocoder=vocoder, +) + +if ulysses_degree > 1 or ring_degree > 1: + from functools import partial + transformer.enable_multi_gpus_inference() + if fsdp_dit: + shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype, + module_to_wrapper=list(transformer.transformer_blocks)) + pipeline.transformer = shard_fn(pipeline.transformer) + print("Add FSDP DIT") + if fsdp_text_encoder: + shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype, + module_to_wrapper=text_encoder.language_model.layers) + text_encoder = shard_fn(text_encoder) + print("Add FSDP TEXT ENCODER") + +if compile_dit: + for i in range(len(pipeline.transformer.transformer_blocks)): + pipeline.transformer.transformer_blocks[i] = torch.compile(pipeline.transformer.transformer_blocks[i]) + print("Add Compile") + +if GPU_memory_mode == "sequential_cpu_offload": + pipeline.enable_sequential_cpu_offload(device=device) +elif GPU_memory_mode == "model_group_offload": + register_auto_device_hook(pipeline.transformer) + safe_enable_group_offload(pipeline, onload_device=device, offload_device="cpu", offload_type="leaf_level", use_stream=True) +elif GPU_memory_mode == "model_cpu_offload_and_qfloat8": + convert_model_weight_to_float8(transformer, exclude_module_name=["scale_shift_table", "audio_scale_shift_table", "video_a2v_cross_attn_scale_shift_table", "audio_a2v_cross_attn_scale_shift_table", ""], device=device) + convert_weight_dtype_wrapper(transformer, weight_dtype) + pipeline.enable_model_cpu_offload(device=device) +elif GPU_memory_mode == "model_cpu_offload": + pipeline.enable_model_cpu_offload(device=device) +elif GPU_memory_mode == "model_full_load_and_qfloat8": + convert_model_weight_to_float8(transformer, exclude_module_name=["scale_shift_table", "audio_scale_shift_table", "video_a2v_cross_attn_scale_shift_table", "audio_a2v_cross_attn_scale_shift_table", ""], device=device) + convert_weight_dtype_wrapper(transformer, weight_dtype) + pipeline.to(device=device) +else: + pipeline.to(device=device) + +generator = 
torch.Generator(device=device).manual_seed(seed) + +if lora_path is not None: + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) + +with torch.no_grad(): + output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + height=sample_size[0], + width=sample_size[1], + num_frames=video_length, + frame_rate=fps, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + stg_scale=stg_scale, + modality_scale=modality_scale, + guidance_rescale=guidance_rescale, + audio_guidance_scale=audio_guidance_scale, + audio_stg_scale=audio_stg_scale, + audio_modality_scale=audio_modality_scale, + audio_guidance_rescale=audio_guidance_rescale, + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + generator=generator, + output_type="pt", + ) + +if lora_path is not None: + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) + +sample = output.videos +audio = output.audio + +def save_results(): + if not os.path.exists(save_path): + os.makedirs(save_path, exist_ok=True) + + index = len([path for path in os.listdir(save_path)]) + 1 + prefix = str(index).zfill(8) + if video_length == 1: + video_path = os.path.join(save_path, prefix + ".png") + + image = sample[0, :, 0] + image = image.transpose(0, 1).transpose(1, 2) + image = (image * 255).numpy().astype(np.uint8) + image = Image.fromarray(image) + image.save(video_path) + else: + video_path = os.path.join(save_path, prefix + ".mp4") + sr = getattr(pipeline.vocoder.config, "output_sampling_rate", audio_sample_rate) + save_videos_with_audio_grid(sample, audio, video_path, fps=fps, audio_sample_rate=sr) + +if ulysses_degree * ring_degree > 1: + import torch.distributed as dist + if dist.get_rank() == 0: + save_results() +else: + save_results() \ No newline at end of file diff --git a/examples/ltx2/predict_i2v.py b/examples/ltx2/predict_i2v.py index 49066fbf..ac1fc215 100644 --- a/examples/ltx2/predict_i2v.py +++ b/examples/ltx2/predict_i2v.py @@ -18,6 +18,7 @@ from videox_fun.pipeline import LTX2I2VPipeline from videox_fun.utils import (register_auto_device_hook, safe_enable_group_offload) +from videox_fun.dist import set_multi_gpus_devices, shard_model from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler from videox_fun.utils.fp8_optimization import (convert_model_weight_to_float8, @@ -45,6 +46,15 @@ # sequential_cpu_offload means that each layer of the model will be moved to the CPU after use, # resulting in slower speeds but saving a large amount of GPU memory. GPU_memory_mode = "sequential_cpu_offload" +# Multi GPUs config +# Please ensure that the product of ulysses_degree and ring_degree equals the number of GPUs used. +# For example, if you are using 8 GPUs, you can set ulysses_degree = 2 and ring_degree = 4. +# If you are using 1 GPU, you can set ulysses_degree = 1 and ring_degree = 1. +ulysses_degree = 1 +ring_degree = 1 +# Use FSDP to save more GPU memory in multi gpus. +fsdp_dit = False +fsdp_text_encoder = False # Compile will give a speedup in fixed resolution and need a little GPU memory. # The compile_dit is not compatible with sequential_cpu_offload. 
compile_dit = False @@ -82,7 +92,7 @@ # Audio sample rate will be read from vocoder config audio_sample_rate = 24000 -device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") +device = set_multi_gpus_devices(ulysses_degree, ring_degree) # Transformer transformer = LTX2VideoTransformer3DModel.from_pretrained( @@ -181,6 +191,20 @@ vocoder=vocoder, ) +if ulysses_degree > 1 or ring_degree > 1: + from functools import partial + transformer.enable_multi_gpus_inference() + if fsdp_dit: + shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype, + module_to_wrapper=list(transformer.transformer_blocks)) + pipeline.transformer = shard_fn(pipeline.transformer) + print("Add FSDP DIT") + if fsdp_text_encoder: + shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype, + module_to_wrapper=text_encoder.language_model.layers) + text_encoder = shard_fn(text_encoder) + print("Add FSDP TEXT ENCODER") + if compile_dit: for i in range(len(pipeline.transformer.transformer_blocks)): pipeline.transformer.transformer_blocks[i] = torch.compile(pipeline.transformer.transformer_blocks[i]) @@ -249,4 +273,9 @@ def save_results(): sr = getattr(pipeline.vocoder.config, "output_sampling_rate", audio_sample_rate) save_videos_with_audio_grid(sample, audio, video_path, fps=fps, audio_sample_rate=sr) -save_results() \ No newline at end of file +if ulysses_degree * ring_degree > 1: + import torch.distributed as dist + if dist.get_rank() == 0: + save_results() +else: + save_results() \ No newline at end of file diff --git a/examples/ltx2/predict_t2v.py b/examples/ltx2/predict_t2v.py index 9b7a1d06..03c8ee82 100644 --- a/examples/ltx2/predict_t2v.py +++ b/examples/ltx2/predict_t2v.py @@ -18,6 +18,7 @@ from videox_fun.pipeline import LTX2Pipeline from videox_fun.utils import (register_auto_device_hook, safe_enable_group_offload) +from videox_fun.dist import set_multi_gpus_devices, shard_model from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler from videox_fun.utils.fp8_optimization import (convert_model_weight_to_float8, @@ -45,6 +46,15 @@ # sequential_cpu_offload means that each layer of the model will be moved to the CPU after use, # resulting in slower speeds but saving a large amount of GPU memory. GPU_memory_mode = "sequential_cpu_offload" +# Multi GPUs config +# Please ensure that the product of ulysses_degree and ring_degree equals the number of GPUs used. +# For example, if you are using 8 GPUs, you can set ulysses_degree = 2 and ring_degree = 4. +# If you are using 1 GPU, you can set ulysses_degree = 1 and ring_degree = 1. +ulysses_degree = 1 +ring_degree = 1 +# Use FSDP to save more GPU memory in multi gpus. +fsdp_dit = False +fsdp_text_encoder = False # Compile will give a speedup in fixed resolution and need a little GPU memory. # The compile_dit is not compatible with sequential_cpu_offload. 
compile_dit = False @@ -78,7 +88,7 @@ # Audio sample rate will be read from vocoder config audio_sample_rate = 24000 -device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") +device = set_multi_gpus_devices(ulysses_degree, ring_degree) # Transformer transformer = LTX2VideoTransformer3DModel.from_pretrained( @@ -177,6 +187,20 @@ vocoder=vocoder, ) +if ulysses_degree > 1 or ring_degree > 1: + from functools import partial + transformer.enable_multi_gpus_inference() + if fsdp_dit: + shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype, + module_to_wrapper=list(transformer.transformer_blocks)) + pipeline.transformer = shard_fn(pipeline.transformer) + print("Add FSDP DIT") + if fsdp_text_encoder: + shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype, + module_to_wrapper=text_encoder.language_model.layers) + text_encoder = shard_fn(text_encoder) + print("Add FSDP TEXT ENCODER") + if compile_dit: for i in range(len(pipeline.transformer.transformer_blocks)): pipeline.transformer.transformer_blocks[i] = torch.compile(pipeline.transformer.transformer_blocks[i]) @@ -244,4 +268,9 @@ def save_results(): sr = getattr(pipeline.vocoder.config, "output_sampling_rate", audio_sample_rate) save_videos_with_audio_grid(sample, audio, video_path, fps=fps, audio_sample_rate=sr) -save_results() \ No newline at end of file +if ulysses_degree * ring_degree > 1: + import torch.distributed as dist + if dist.get_rank() == 0: + save_results() +else: + save_results() \ No newline at end of file diff --git a/scripts/flux/README_TRAIN.md b/scripts/flux/README_TRAIN.md index 3308bff7..69d4fd1d 100755 --- a/scripts/flux/README_TRAIN.md +++ b/scripts/flux/README_TRAIN.md @@ -1,28 +1,242 @@ -## Training Code +# FLUX.1 Full Parameter Training Guide -We can choose whether to use deepspeed or fsdp in flux, which can save a lot of video memory. +This document provides a complete workflow for full parameter training of FLUX.1 Diffusion Transformer, including environment configuration, data preparation, distributed training, and inference testing. -Some parameters in the sh file can be confusing, and they are explained in this document: +--- -- `enable_bucket` is used to enable bucket training. When enabled, the model does not crop the images at the center, but instead, it trains the entire images after grouping them into buckets based on resolution. -- `random_hw_adapt` is used to enable automatic height and width scaling for images. When `random_hw_adapt` is enabled, the training images will have their height and width set to `image_sample_size` as the maximum and `512` as the minimum. - - For example, when `random_hw_adapt` is enabled, `image_sample_size=1024`, the resolution of image inputs for training is `512x512` to `1024x1024` -- `resume_from_checkpoint` is used to set the training should be resumed from a previous checkpoint. Use a path or `"latest"` to automatically select the last available checkpoint. +## Table of Contents +- [1. Environment Configuration](#1-environment-configuration) +- [2. Data Preparation](#2-data-preparation) + - [2.1 Quick Test Dataset](#21-quick-test-dataset) + - [2.2 Dataset Structure](#22-dataset-structure) + - [2.3 metadata.json Format](#23-metadatajson-format) + - [2.4 Relative vs Absolute Path Usage](#24-relative-vs-absolute-path-usage) +- [3. 
Full Parameter Training](#3-full-parameter-training) + - [3.1 Download Pretrained Model](#31-download-pretrained-model) + - [3.2 Quick Start (DeepSpeed-Zero-2)](#32-quick-start-deepspeed-zero-2) + - [3.3 Common Training Parameters](#33-common-training-parameters) + - [3.4 Training with FSDP](#34-training-with-fsdp) + - [3.5 Other Backends](#35-other-backends) + - [3.6 Multi-Machine Distributed Training](#36-multi-machine-distributed-training) +- [4. Inference Testing](#4-inference-testing) + - [4.1 Inference Parameters](#41-inference-parameters) + - [4.2 Single GPU Inference](#42-single-gpu-inference) + - [4.3 Multi-GPU Parallel Inference](#43-multi-gpu-parallel-inference) +- [5. Additional Resources](#5-additional-resources) -When train model with multi machines, please set the params as follows: -```sh -export MASTER_ADDR="your master address" -export MASTER_PORT=10086 -export WORLD_SIZE=1 # The number of machines -export NUM_PROCESS=8 # The number of processes, such as WORLD_SIZE * 8 -export RANK=0 # The rank of this machine +--- + +## 1. Environment Configuration + +**Method 1: Using requirements.txt** + +```bash +pip install -r requirements.txt +``` -accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK scripts/xxx/xxx.py +**Method 2: Manual Dependency Installation** + +```bash +pip install Pillow einops safetensors timm tomesd librosa "torch>=2.1.2" torchdiffeq torchsde decord datasets numpy scikit-image +pip install omegaconf SentencePiece imageio[ffmpeg] imageio[pyav] tensorboard beautifulsoup4 ftfy func_timeout onnxruntime +pip install "peft>=0.17.0" "accelerate>=0.25.0" "gradio>=3.41.2" "diffusers>=0.30.1" "transformers>=4.46.2" +pip install yunchang xfuser modelscope openpyxl deepspeed==0.17.0 numpy==1.26.4 +pip uninstall opencv-python opencv-contrib-python opencv-python-headless -y +pip install opencv-python-headless +``` + +**Method 3: Using Docker** + +When using Docker, please ensure that the GPU driver and CUDA environment are correctly installed on your machine, then execute the following commands: + +``` +# pull image +docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun + +# enter image +docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun +``` + +--- + +## 2. Data Preparation + +### 2.1 Quick Test Dataset + +We provide a test dataset containing several training samples. + +```bash +# Download official demo dataset +modelscope download --dataset PAI/X-Fun-Images-Demo --local_dir ./datasets/X-Fun-Images-Demo +``` + +### 2.2 Dataset Structure + +``` +📦 datasets/ +├── 📂 my_dataset/ +│ ├── 📂 train/ +│ │ ├── 📄 image001.jpg +│ │ ├── 📄 image002.png +│ │ └── 📄 ... 
+│ └── 📄 metadata.json ``` -Without deepspeed: +### 2.3 metadata.json Format + +**Relative Path Format** (example): +```json +[ + { + "file_path": "train/image001.jpg", + "text": "A beautiful sunset over the ocean, golden hour lighting", + "width": 1024, + "height": 1024 + }, + { + "file_path": "train/image002.png", + "text": "Portrait of a young woman, studio lighting, high quality", + "width": 1328, + "height": 1328 + } +] +``` + +**Absolute Path Format**: +```json +[ + { + "file_path": "/mnt/data/images/sunset.jpg", + "text": "A beautiful sunset over the ocean", + "width": 1024, + "height": 1024 + } +] +``` + +**Key Fields Description**: +- `file_path`: Image path (relative or absolute) +- `text`: Image description (English prompt) +- `width` / `height`: Image dimensions (**recommended** to provide for bucket training; if not provided, they will be automatically read during training, which may slow down training when data is stored on slow systems like OSS) + - You can use `scripts/process_json_add_width_and_height.py` to add width and height fields to JSON files without these fields, supporting both images and videos + - Usage: `python scripts/process_json_add_width_and_height.py --input_file datasets/X-Fun-Images-Demo/metadata.json --output_file datasets/X-Fun-Images-Demo/metadata_add_width_height.json` + +### 2.4 Relative vs Absolute Path Usage + +**Relative Paths**: + +If your data uses relative paths, configure the training script as follows: + +```bash +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +``` + +**Absolute Paths**: + +If your data uses absolute paths, configure the training script as follows: + +```bash +export DATASET_NAME="" +export DATASET_META_NAME="/mnt/data/metadata.json" +``` + +> 💡 **Recommendation**: If the dataset is small and stored locally, use relative paths. If the dataset is stored on external storage (e.g., NAS, OSS) or shared across multiple machines, use absolute paths. + +--- + +## 3. Full Parameter Training + +### 3.1 Download Pretrained Model + +```bash +# Create model directory +mkdir -p models/Diffusion_Transformer + +# Download FLUX.1 official weights +modelscope download --model black-forest-labs/FLUX.1-dev --local_dir models/Diffusion_Transformer/FLUX.1-dev +``` + +### 3.2 Quick Start (DeepSpeed-Zero-2) + +If you have downloaded the data as per **2.1 Quick Test Dataset** and the weights as per **3.1 Download Pretrained Model**, you can directly copy and run the quick start command. + +DeepSpeed-Zero-2 and FSDP are recommended for training. Here we use DeepSpeed-Zero-2 as an example. + +The difference between DeepSpeed-Zero-2 and FSDP lies in whether the model weights are sharded. **If VRAM is insufficient when using multiple GPUs with DeepSpeed-Zero-2**, you can switch to FSDP. + +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/flux/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1024 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_flux" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` + +### 3.3 Common Training Parameters + +**Key Parameters Description**: + +| Parameter | Description | Example Value | +|-----|------|-------| +| `--pretrained_model_name_or_path` | Pretrained model path | `models/Diffusion_Transformer/FLUX.1-dev` | +| `--train_data_dir` | Training data directory | `datasets/internal_datasets/` | +| `--train_data_meta` | Training data metadata file | `datasets/internal_datasets/metadata.json` | +| `--train_batch_size` | Batch size per device | 1 | +| `--image_sample_size` | Maximum training resolution (auto bucketing) | 1024 | +| `--gradient_accumulation_steps` | Gradient accumulation steps (effective batch size increase) | 1 | +| `--dataloader_num_workers` | DataLoader subprocess count | 8 | +| `--num_train_epochs` | Number of training epochs | 100 | +| `--checkpointing_steps` | Save checkpoint every N steps | 50 | +| `--learning_rate` | Initial learning rate | 2e-05 | +| `--lr_scheduler` | Learning rate scheduler | `constant_with_warmup` | +| `--lr_warmup_steps` | Learning rate warmup steps | 100 | +| `--seed` | Random seed | 42 | +| `--output_dir` | Output directory | `output_dir_flux` | +| `--gradient_checkpointing` | Enable gradient checkpointing | - | +| `--mixed_precision` | Mixed precision: `fp16/bf16` | `bf16` | +| `--adam_weight_decay` | AdamW weight decay | 3e-2 | +| `--adam_epsilon` | AdamW epsilon value | 1e-10 | +| `--vae_mini_batch` | Mini batch size for VAE encoding | 1 | +| `--max_grad_norm` | Gradient clipping threshold | 0.05 | +| `--enable_bucket` | Enable bucket training (no center crop, train full images grouped by resolution) | - | +| `--random_hw_adapt` | Auto-scale images to random sizes in `[512, image_sample_size]` range | - | +| `--resume_from_checkpoint` | Resume training path, use `"latest"` to auto-select latest checkpoint | None | +| `--uniform_sampling` | Uniform timestep sampling | - | +| `--trainable_modules` | Trainable modules (`"."` means all modules) | `"."` | + + +### 3.4 Training with FSDP + +**If VRAM is insufficient when using multiple GPUs with DeepSpeed-Zero-2**, you can switch to FSDP. -Training flux without DeepSpeed may result in insufficient GPU memory. 
```sh export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" export DATASET_NAME="datasets/internal_datasets/" @@ -32,7 +246,7 @@ export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --mixed_precision="bf16" scripts/flux/train.py \ +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap FluxSingleTransformerBlock,FluxTransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/flux/train.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ @@ -46,7 +260,7 @@ accelerate launch --mixed_precision="bf16" scripts/flux/train.py \ --lr_scheduler="constant_with_warmup" \ --lr_warmup_steps=100 \ --seed=42 \ - --output_dir="output_dir" \ + --output_dir="output_dir_flux" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -58,8 +272,21 @@ accelerate launch --mixed_precision="bf16" scripts/flux/train.py \ --trainable_modules "." ``` -With Deepspeed Zero-2: +### 3.5 Other Backends +#### 3.5.1 Training with DeepSpeed-Zero-3 + +DeepSpeed Zero-3 is not currently recommended. In this repository, FSDP has fewer errors and is more stable. + +DeepSpeed Zero-3: + +After training, you can use the following command to get the final model: + +```sh +python scripts/zero_to_bf16.py output_dir/checkpoint-{our-num-steps} output_dir/checkpoint-{your-num-steps}-outputs --max_shard_size 80GB --safe_serialization +``` + +Execution command: ```sh export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" export DATASET_NAME="datasets/internal_datasets/" @@ -69,7 +296,7 @@ export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/flux/train.py \ +accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag true --use_deepspeed --deepspeed_config_file config/zero_stage3_config.json --deepspeed_multinode_launcher standard scripts/flux/train.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ @@ -83,7 +310,7 @@ accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_con --lr_scheduler="constant_with_warmup" \ --lr_warmup_steps=100 \ --seed=42 \ - --output_dir="output_dir" \ + --output_dir="output_dir_flux" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -95,7 +322,9 @@ accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_con --trainable_modules "." ``` -With FSDP: +#### 3.5.2 Training without DeepSpeed or FSDP + +**This approach is not recommended as there is no memory-saving backend, which may cause insufficient VRAM.** This is provided for reference only. 
```sh export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" @@ -106,7 +335,7 @@ export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap FluxSingleTransformerBlock,FluxTransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/flux/train.py \ +accelerate launch --mixed_precision="bf16" scripts/flux/train.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ @@ -120,7 +349,7 @@ accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TR --lr_scheduler="constant_with_warmup" \ --lr_warmup_steps=100 \ --seed=42 \ - --output_dir="output_dir" \ + --output_dir="output_dir_flux" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -130,4 +359,189 @@ accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TR --enable_bucket \ --uniform_sampling \ --trainable_modules "." -``` \ No newline at end of file +``` + +### 3.6 Multi-Machine Distributed Training + +**Suitable for**: Large-scale datasets, faster training speed + +#### 3.6.1 Environment Configuration + +Assume 2 machines, each with 8 GPUs: + +**Machine 0 (Master)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Master machine IP +export MASTER_PORT=10086 +export WORLD_SIZE=2 # Total number of machines +export NUM_PROCESS=16 # Total processes = machines × 8 +export RANK=0 # Current machine rank (0 or 1) +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/flux/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1024 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_flux" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." 
+``` + +**Machine 1 (Worker)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Same as Master +export MASTER_PORT=10086 +export WORLD_SIZE=2 +export NUM_PROCESS=16 +export RANK=1 # Note this is 1 +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +# Use the same accelerate launch command as machine 0 +``` + +#### 3.6.2 Multi-Machine Training Notes + +- **Network Requirements**: + - RDMA/InfiniBand recommended (high performance) + - Without RDMA, add environment variables: + ```bash + export NCCL_IB_DISABLE=1 + export NCCL_P2P_DISABLE=1 + ``` + +- **Data Synchronization**: All machines must have access to the same data paths (NFS/shared storage) + +## 4. Inference Testing + +### 4.1 Inference Parameters + +**Key Parameters Description**: + +| Parameter | Description | Example Value | +|------|------|-------| +| `GPU_memory_mode` | GPU memory management mode, see table below | `model_cpu_offload_and_qfloat8` | +| `ulysses_degree` | Head dimension parallelism degree, 1 for single GPU | 1 | +| `ring_degree` | Sequence dimension parallelism degree, 1 for single GPU | 1 | +| `fsdp_dit` | Use FSDP for Transformer in multi-GPU inference to save VRAM | `False` | +| `fsdp_text_encoder` | Use FSDP for text encoder in multi-GPU inference | `False` | +| `compile_dit` | Compile Transformer for faster inference (effective at fixed resolution) | `False` | +| `model_name` | Model path | `models/Diffusion_Transformer/FLUX.1-dev` | +| `sampler_name` | Sampler type: `Flow`, `Flow_Unipc`, `Flow_DPM++` | `Flow` | +| `transformer_path` | Path to trained Transformer weights | `None` | +| `vae_path` | Path to trained VAE weights | `None` | +| `lora_path` | LoRA weights path | `None` | +| `sample_size` | Generated image resolution `[height, width]` | `[1344, 768]` | +| `weight_dtype` | Model weight precision, use `torch.float16` for GPUs without bf16 support | `torch.bfloat16` | +| `prompt` | Positive prompt describing the content | `"1girl, black_hair..."` | +| `negative_prompt` | Negative prompt for content to avoid | `" "` | +| `guidance_scale` | Guidance strength | 1.0 | +| `seed` | Random seed for reproducibility | 43 | +| `num_inference_steps` | Inference steps | 50 | +| `lora_weight` | LoRA weight strength | 0.55 | +| `save_path` | Generated image save path | `samples/flux-t2i` | + +**GPU Memory Management Modes**: + +| Mode | Description | VRAM Usage | +|------|------|---------| +| `model_full_load` | Load entire model to GPU | Highest | +| `model_full_load_and_qfloat8` | Full load + FP8 quantization | High | +| `model_cpu_offload` | Offload model to CPU after use | Medium | +| `model_cpu_offload_and_qfloat8` | CPU offload + FP8 quantization | Medium-Low | +| `model_group_offload` | Layer groups switch between CPU/CUDA | Low | +| `sequential_cpu_offload` | Layer-by-layer offload (slowest) | Lowest | + +### 4.2 Single GPU Inference + +Run single GPU inference with the following command: + +```bash +python examples/flux/predict_t2i.py +``` + +Edit `examples/flux/predict_t2i.py` according to your needs. For first-time inference, focus on these parameters. For other parameters, see the inference parameter section above. 
+ +```python +# Choose based on GPU VRAM +GPU_memory_mode = "model_cpu_offload_and_qfloat8" +# Based on actual model path +model_name = "models/Diffusion_Transformer/FLUX.1-dev" +# Trained weights path, e.g., "output_dir_flux/checkpoint-xxx/diffusion_pytorch_model.safetensors" +transformer_path = None +# Write based on content to generate +prompt = "1girl, black_hair, brown_eyes, earrings, freckles, grey_background, jewelry, lips, long_hair, looking_at_viewer, nose, piercing, realistic, red_lips, solo, upper_body" +# ... +``` + +### 4.3 Multi-GPU Parallel Inference + +**Suitable for**: High-resolution generation, faster inference + +#### Install Parallel Inference Dependencies + +```bash +pip install xfuser==0.4.2 yunchang==0.6.2 +``` + +#### Configure Parallel Strategy + +Edit `examples/flux/predict_t2i.py`: + +```python +# Ensure ulysses_degree × ring_degree = number of GPUs +# For example, using 2 GPUs: +ulysses_degree = 2 # Head dimension parallelism +ring_degree = 1 # Sequence dimension parallelism +``` + +**Configuration Principles**: +- `ulysses_degree` must divide the model's head count evenly. +- `ring_degree` splits on sequence dimension, affecting communication overhead. Avoid using it when head count can be divided. + +**Example Configurations**: + +| GPU Count | ulysses_degree | ring_degree | Description | +|---------|---------------|-------------|------| +| 1 | 1 | 1 | Single GPU | +| 4 | 4 | 1 | Head parallelism | +| 8 | 8 | 1 | Head parallelism | +| 8 | 4 | 2 | Hybrid parallelism | + +#### Run Multi-GPU Inference + +```bash +torchrun --nproc-per-node=2 examples/flux/predict_t2i.py +``` + +## 5. Additional Resources + +- **Official GitHub**: https://github.com/aigc-apps/VideoX-Fun diff --git a/scripts/flux/README_TRAIN_LORA.md b/scripts/flux/README_TRAIN_LORA.md index 75f98a41..d171cef3 100755 --- a/scripts/flux/README_TRAIN_LORA.md +++ b/scripts/flux/README_TRAIN_LORA.md @@ -1,32 +1,242 @@ -## Lora Training Code +# FLUX.1 LoRA Fine-Tuning Training Guide -We can choose whether to use deepspeed or fsdp in flux, which can save a lot of video memory. +This document provides a complete workflow for FLUX.1 LoRA fine-tuning training, including environment configuration, data preparation, multiple distributed training strategies, and inference testing. -Some parameters in the sh file can be confusing, and they are explained in this document: +--- -- `enable_bucket` is used to enable bucket training. When enabled, the model does not crop the images at the center, but instead, it trains the entire images after grouping them into buckets based on resolution. -- `random_hw_adapt` is used to enable automatic height and width scaling for images. When `random_hw_adapt` is enabled, the training images will have their height and width set to `image_sample_size` as the maximum and `512` as the minimum. - - For example, when `random_hw_adapt` is enabled, `image_sample_size=1024`, the resolution of image inputs for training is `512x512` to `1024x1024` -- `resume_from_checkpoint` is used to set the training should be resumed from a previous checkpoint. Use a path or `"latest"` to automatically select the last available checkpoint. -- `target_name` represents the components/modules to which LoRA will be applied, separated by commas. -- `use_peft_lora` indicates whether to use the PEFT module for adding LoRA. Using this module will be more memory-efficient. -- `rank` means the dimension of the LoRA update matrices. -- `network_alpha` means the scale of the LoRA update matrices. 
+## Table of Contents +- [1. Environment Configuration](#1-environment-configuration) +- [2. Data Preparation](#2-data-preparation) + - [2.1 Quick Test Dataset](#21-quick-test-dataset) + - [2.2 Dataset Structure](#22-dataset-structure) + - [2.3 metadata.json Format](#23-metadatajson-format) + - [2.4 Relative vs Absolute Path Usage](#24-relative-vs-absolute-path-usage) +- [3. LoRA Training](#3-lora-training) + - [3.1 Download Pretrained Model](#31-download-pretrained-model) + - [3.2 Quick Start (DeepSpeed-Zero-2)](#32-quick-start-deepspeed-zero-2) + - [3.3 LoRA-Specific Parameters](#33-lora-specific-parameters) + - [3.4 Training with FSDP](#34-training-with-fsdp) + - [3.5 Other Backends](#35-other-backends) + - [3.6 Multi-Machine Distributed Training](#36-multi-machine-distributed-training) +- [4. Inference Testing](#4-inference-testing) + - [4.1 Inference Parameter Parsing](#41-inference-parameter-parsing) + - [4.2 Single GPU Inference](#42-single-gpu-inference) + - [4.3 Multi-GPU Parallel Inference](#43-multi-gpu-parallel-inference) +- [5. Additional Resources](#5-additional-resources) -When train model with multi machines, please set the params as follows: -```sh -export MASTER_ADDR="your master address" -export MASTER_PORT=10086 -export WORLD_SIZE=1 # The number of machines -export NUM_PROCESS=8 # The number of processes, such as WORLD_SIZE * 8 -export RANK=0 # The rank of this machine +--- + +## 1. Environment Configuration + +**Method 1: Using requirements.txt** + +```bash +pip install -r requirements.txt +``` + +**Method 2: Manual Dependency Installation** + +```bash +pip install Pillow einops safetensors timm tomesd librosa "torch>=2.1.2" torchdiffeq torchsde decord datasets numpy scikit-image +pip install omegaconf SentencePiece imageio[ffmpeg] imageio[pyav] tensorboard beautifulsoup4 ftfy func_timeout onnxruntime +pip install "peft>=0.17.0" "accelerate>=0.25.0" "gradio>=3.41.2" "diffusers>=0.30.1" "transformers>=4.46.2" +pip install yunchang xfuser modelscope openpyxl deepspeed==0.17.0 numpy==1.26.4 +pip uninstall opencv-python opencv-contrib-python opencv-python-headless -y +pip install opencv-python-headless +``` + +**Method 3: Using Docker** + +When using Docker, please ensure that the GPU driver and CUDA environment are correctly installed on your machine, then execute the following commands: + +``` +# pull image +docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun + +# enter image +docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun +``` + +--- + +## 2. Data Preparation + +### 2.1 Quick Test Dataset + +We provide a test dataset containing several training samples. + +```bash +# Download official demo dataset +modelscope download --dataset PAI/X-Fun-Images-Demo --local_dir ./datasets/X-Fun-Images-Demo +``` + +### 2.2 Dataset Structure + +``` +📦 datasets/ +├── 📂 my_dataset/ +│ ├── 📂 train/ +│ │ ├── 📄 image001.jpg +│ │ ├── 📄 image002.png +│ │ └── 📄 ... 
+│ └── 📄 metadata.json +``` + +### 2.3 metadata.json Format -accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK scripts/xxx/xxx.py +**Relative Path Format** (example): +```json +[ + { + "file_path": "train/image001.jpg", + "text": "A beautiful sunset over the ocean, golden hour lighting", + "width": 1024, + "height": 1024 + }, + { + "file_path": "train/image002.png", + "text": "Portrait of a young woman, studio lighting, high quality", + "width": 1328, + "height": 1328 + } +] ``` -Without deepspeed: +**Absolute Path Format**: +```json +[ + { + "file_path": "/mnt/data/images/sunset.jpg", + "text": "A beautiful sunset over the ocean", + "width": 1024, + "height": 1024 + } +] +``` + +**Key Fields Description**: +- `file_path`: Image path (relative or absolute) +- `text`: Image description (English prompt) +- `width` / `height`: Image dimensions (**recommended** to provide for bucket training; if not provided, they will be automatically read during training, which may slow down training when data is stored on slow systems like OSS) + - You can use `scripts/process_json_add_width_and_height.py` to add width and height fields to JSON files without these fields, supporting both images and videos + - Usage: `python scripts/process_json_add_width_and_height.py --input_file datasets/X-Fun-Images-Demo/metadata.json --output_file datasets/X-Fun-Images-Demo/metadata_add_width_height.json` + +### 2.4 Relative vs Absolute Path Usage + +**Relative Paths**: + +If your data uses relative paths, configure the training script as follows: + +```bash +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +``` + +**Absolute Paths**: + +If your data uses absolute paths, configure the training script as follows: + +```bash +export DATASET_NAME="" +export DATASET_META_NAME="/mnt/data/metadata.json" +``` + +> 💡 **Recommendation**: If the dataset is small and stored locally, use relative paths. If the dataset is stored on external storage (e.g., NAS, OSS) or shared across multiple machines, use absolute paths. + +--- + +## 3. LoRA Training + +### 3.1 Download Pretrained Model + +```bash +# Create model directory +mkdir -p models/Diffusion_Transformer + +# Download FLUX.1 official weights +modelscope download --model black-forest-labs/FLUX.1-dev --local_dir models/Diffusion_Transformer/FLUX.1-dev +``` + +### 3.2 Quick Start (DeepSpeed-Zero-2) + +If you have downloaded the data as per **2.1 Quick Test Dataset** and the weights as per **3.1 Download Pretrained Model**, you can directly copy and run the quick start command. + +DeepSpeed-Zero-2 and FSDP are recommended for training. Here we use DeepSpeed-Zero-2 as an example. + +The difference between DeepSpeed-Zero-2 and FSDP lies in whether the model weights are sharded. **If VRAM is insufficient when using multiple GPUs with DeepSpeed-Zero-2**, you can switch to FSDP. + +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
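+# For this single-machine quick start they can stay commented out; multi-machine setups are covered in section 3.6.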
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/flux/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1024 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_flux_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=64 \ + --network_alpha=32 \ + --target_name="to_q,to_k,to_v,ff.0,ff.2,ff_context.0,ff_context.2" \ + --use_peft_lora \ + --uniform_sampling +``` + +### 3.3 LoRA-Specific Parameters + +**LoRA Key Parameters Description**: + +| Parameter | Description | Example Value | +|-----|------|-------| +| `--pretrained_model_name_or_path` | Pretrained model path | `models/Diffusion_Transformer/FLUX.1-dev` | +| `--train_data_dir` | Training data directory | `datasets/internal_datasets/` | +| `--train_data_meta` | Training data metadata file | `datasets/internal_datasets/metadata.json` | +| `--train_batch_size` | Batch size per device | 1 | +| `--image_sample_size` | Maximum training resolution (auto bucketing) | 1024 | +| `--gradient_accumulation_steps` | Gradient accumulation steps (effective batch size increase) | 1 | +| `--dataloader_num_workers` | DataLoader subprocess count | 8 | +| `--num_train_epochs` | Number of training epochs | 100 | +| `--checkpointing_steps` | Save checkpoint every N steps | 50 | +| `learning_rate` | Initial learning rate (recommended for LoRA) | 1e-04 | +| `--lr_scheduler` | Learning rate scheduler | `constant_with_warmup` | +| `--lr_warmup_steps` | Learning rate warmup steps | 100 | +| `--seed` | Random seed (reproducible training) | 42 | +| `--output_dir` | Output directory | `output_dir_flux_lora` | +| `--gradient_checkpointing` | Enable gradient checkpointing | - | +| `--mixed_precision` | Mixed precision: `fp16/bf16` | `bf16` | +| `--enable_bucket` | Enable bucket training (no center crop, train full images grouped by resolution) | - | +| `--uniform_sampling` | Uniform timestep sampling (recommended) | - | +| `--resume_from_checkpoint` | Resume training path, use `"latest"` to auto-select latest checkpoint | None | +| `--rank` | LoRA update matrix dimension (higher rank = more expressive but more VRAM) | 64 | +| `--network_alpha` | LoRA update matrix scaling coefficient (typically half of rank or same) | 32 | +| `--target_name` | Components/modules to apply LoRA, comma-separated | `to_q,to_k,to_v,ff.0,ff.2,ff_context.0,ff_context.2` | +| `--use_peft_lora` | Use PEFT module to add LoRA (more memory efficient) | - | + +### 3.4 Training with FSDP + +**If VRAM is insufficient when using multiple GPUs with DeepSpeed-Zero-2**, you can switch to FSDP. + +> ✅ **Recommended**: FSDP has been thoroughly tested in this repository with fewer errors and more stability. -Training flux without DeepSpeed may result in insufficient GPU memory. 
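+
+For reference, a rough reading of the FSDP flags used in the command below (these are standard Accelerate/PyTorch FSDP options; the notes describe their general semantics, not repository-specific behavior):
+
+```bash
+--fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP           # wrap each listed transformer block class as its own FSDP unit
+--fsdp_transformer_layer_cls_to_wrap FluxSingleTransformerBlock,FluxTransformerBlock
+--fsdp_sharding_strategy "FULL_SHARD"                    # shard parameters, gradients and optimizer state across GPUs
+--fsdp_state_dict_type=SHARDED_STATE_DICT                # each rank saves its own shard of the checkpoint
+--fsdp_backward_prefetch "BACKWARD_PRE"                  # prefetch the next unit's parameters during the backward pass
+--fsdp_cpu_ram_efficient_loading False                   # every rank loads the full weights (uses more CPU RAM)
+```
+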
```sh export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" export DATASET_NAME="datasets/internal_datasets/" @@ -36,7 +246,7 @@ export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --mixed_precision="bf16" scripts/flux/train_lora.py \ +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap FluxSingleTransformerBlock,FluxTransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/flux/train_lora.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ @@ -48,7 +258,7 @@ accelerate launch --mixed_precision="bf16" scripts/flux/train_lora.py \ --checkpointing_steps=50 \ --learning_rate=1e-04 \ --seed=42 \ - --output_dir="output_dir" \ + --output_dir="output_dir_flux_lora" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -63,8 +273,21 @@ accelerate launch --mixed_precision="bf16" scripts/flux/train_lora.py \ --uniform_sampling ``` -With Deepspeed Zero-2: +### 3.5 Other Backends + +#### 3.5.1 Training with DeepSpeed-Zero-3 +DeepSpeed Zero-3 is not currently recommended. In this repository, FSDP has fewer errors and is more stable. + +DeepSpeed Zero-3: + +After training, you can use the following command to get the final model: + +```sh +python scripts/zero_to_bf16.py output_dir/checkpoint-{our-num-steps} output_dir/checkpoint-{your-num-steps}-outputs --max_shard_size 80GB --safe_serialization +``` + +Execution command: ```sh export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" export DATASET_NAME="datasets/internal_datasets/" @@ -74,7 +297,7 @@ export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/flux/train_lora.py \ +accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag true --use_deepspeed --deepspeed_config_file config/zero_stage3_config.json --deepspeed_multinode_launcher standard scripts/flux/train_lora.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ @@ -86,7 +309,7 @@ accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_con --checkpointing_steps=50 \ --learning_rate=1e-04 \ --seed=42 \ - --output_dir="output_dir" \ + --output_dir="output_dir_flux_lora" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -101,7 +324,9 @@ accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_con --uniform_sampling ``` -With FSDP: +#### 3.5.2 Training without DeepSpeed or FSDP + +**This approach is not recommended as there is no memory-saving backend, which may cause insufficient VRAM.** This is provided for reference only. 
```sh export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" @@ -112,7 +337,57 @@ export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap FluxSingleTransformerBlock,FluxTransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/flux/train_lora.py \ +accelerate launch --mixed_precision="bf16" scripts/flux/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1024 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_flux_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=64 \ + --network_alpha=32 \ + --target_name="to_q,to_k,to_v,ff.0,ff.2,ff_context.0,ff_context.2" \ + --use_peft_lora \ + --uniform_sampling +``` + +### 3.6 Multi-Machine Distributed Training + +**Suitable for**: Large-scale datasets, faster training speed + +#### 3.6.1 Environment Configuration + +Assume 2 machines, each with 8 GPUs: + +**Machine 0 (Master)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Master machine IP +export MASTER_PORT=10086 +export WORLD_SIZE=2 # Total number of machines +export NUM_PROCESS=16 # Total processes = machines × 8 +export RANK=0 # Current machine rank (0 or 1) +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
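+# Note: WORLD_SIZE here counts machines (it feeds --num_machines below), while NUM_PROCESS is the total process count across all machines.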
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/flux/train_lora.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ @@ -124,7 +399,7 @@ accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TR --checkpointing_steps=50 \ --learning_rate=1e-04 \ --seed=42 \ - --output_dir="output_dir" \ + --output_dir="output_dir_flux_lora" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -137,4 +412,144 @@ accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TR --target_name="to_q,to_k,to_v,ff.0,ff.2,ff_context.0,ff_context.2" \ --use_peft_lora \ --uniform_sampling -``` \ No newline at end of file +``` + +**Machine 1 (Worker)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Same as Master +export MASTER_PORT=10086 +export WORLD_SIZE=2 +export NUM_PROCESS=16 +export RANK=1 # Note this is 1 +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +# Use the same accelerate launch command as machine 0 +``` + +#### 3.6.2 Multi-Machine Training Notes + +- **Network Requirements**: + - RDMA/InfiniBand recommended (high performance) + - Without RDMA, add environment variables: + ```bash + export NCCL_IB_DISABLE=1 + export NCCL_P2P_DISABLE=1 + ``` + +- **Data Synchronization**: All machines must have access to the same data paths (NFS/shared storage) + +--- + +## 4. 
Inference Testing + +### 4.1 Inference Parameter Parsing + +**Key Parameters Description**: + +| Parameter | Description | Example Value | +|------|------|-------| +| `GPU_memory_mode` | GPU memory management mode, see table below | `model_cpu_offload_and_qfloat8` | +| `ulysses_degree` | Head dimension parallelism degree, 1 for single GPU | 1 | +| `ring_degree` | Sequence dimension parallelism degree, 1 for single GPU | 1 | +| `fsdp_dit` | Use FSDP for Transformer in multi-GPU inference to save VRAM | `False` | +| `fsdp_text_encoder` | Use FSDP for text encoder in multi-GPU inference | `False` | +| `compile_dit` | Compile Transformer for faster inference (effective at fixed resolution) | `False` | +| `model_name` | Model path | `models/Diffusion_Transformer/FLUX.1-dev` | +| `sampler_name` | Sampler type: `Flow`, `Flow_Unipc`, `Flow_DPM++` | `Flow` | +| `transformer_path` | Path to trained Transformer weights | `None` | +| `vae_path` | Path to trained VAE weights | `None` | +| `lora_path` | LoRA weights path | `None` | +| `sample_size` | Generated image resolution `[height, width]` | `[1344, 768]` | +| `weight_dtype` | Model weight precision, use `torch.float16` for GPUs without bf16 support | `torch.bfloat16` | +| `prompt` | Positive prompt describing the content | `"1girl, black_hair..."` | +| `negative_prompt` | Negative prompt for content to avoid | `" "` | +| `guidance_scale` | Guidance strength | 1.0 | +| `seed` | Random seed for reproducibility | 43 | +| `num_inference_steps` | Inference steps | 50 | +| `lora_weight` | LoRA weight strength | 0.55 | +| `save_path` | Generated image save path | `samples/flux-t2i` | + +**GPU Memory Management Modes**: + +| Mode | Description | VRAM Usage | +|------|------|---------| +| `model_full_load` | Load entire model to GPU | Highest | +| `model_full_load_and_qfloat8` | Full load + FP8 quantization | High | +| `model_cpu_offload` | Offload model to CPU after use | Medium | +| `model_cpu_offload_and_qfloat8` | CPU offload + FP8 quantization | Medium-Low | +| `model_group_offload` | Layer groups switch between CPU/CUDA | Low | +| `sequential_cpu_offload` | Layer-by-layer offload (slowest) | Lowest | + +### 4.2 Single GPU Inference + +Run single GPU inference with the following command: + +```bash +python examples/flux/predict_t2i.py +``` + +Edit `examples/flux/predict_t2i.py` according to your needs. For first-time inference, focus on these parameters. For other parameters, see the inference parameter section above. + +```python +# Choose based on GPU VRAM +GPU_memory_mode = "model_cpu_offload_and_qfloat8" +# Based on actual model path +model_name = "models/Diffusion_Transformer/FLUX.1-dev" +# LoRA weights path, e.g., "output_dir_flux_lora/checkpoint-xxx/lora_weights.safetensors" +lora_path = None +# LoRA weight strength +lora_weight = 0.55 +# Write based on content to generate +prompt = "1girl, black_hair, brown_eyes, earrings, freckles, grey_background, jewelry, lips, long_hair, looking_at_viewer, nose, piercing, realistic, red_lips, solo, upper_body" +# ... 
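+# Other settings worth a look on a first run. The values below are only the example values
+# from the table in section 4.1, not necessarily the script defaults:
+# sample_size = [1344, 768]        # output resolution [height, width]
+# seed = 43                        # fix the seed to reproduce a result
+# num_inference_steps = 50         # sampling steps
+# save_path = "samples/flux-t2i"   # where generated images are written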
+``` + +### 4.3 Multi-GPU Parallel Inference + +**Suitable for**: High-resolution generation, faster inference + +#### Install Parallel Inference Dependencies + +```bash +pip install xfuser==0.4.2 yunchang==0.6.2 +``` + +#### Configure Parallel Strategy + +Edit `examples/flux/predict_t2i.py`: + +```python +# Ensure ulysses_degree × ring_degree = number of GPUs +# For example, using 2 GPUs: +ulysses_degree = 2 # Head dimension parallelism +ring_degree = 1 # Sequence dimension parallelism +``` + +**Configuration Principles**: +- `ulysses_degree` must divide the model's head count evenly. +- `ring_degree` splits on sequence dimension, affecting communication overhead. Avoid using it when head count can be divided. + +**Example Configurations**: + +| GPU Count | ulysses_degree | ring_degree | Description | +|---------|---------------|-------------|------| +| 1 | 1 | 1 | Single GPU | +| 4 | 4 | 1 | Head parallelism | +| 8 | 8 | 1 | Head parallelism | +| 8 | 4 | 2 | Hybrid parallelism | + +#### Run Multi-GPU Inference + +```bash +torchrun --nproc-per-node=2 examples/flux/predict_t2i.py +``` + +## 5. Additional Resources + +- **Official GitHub**: https://github.com/aigc-apps/VideoX-Fun diff --git a/scripts/flux/README_TRAIN_LORA_zh-CN.md b/scripts/flux/README_TRAIN_LORA_zh-CN.md new file mode 100755 index 00000000..2a2e69fb --- /dev/null +++ b/scripts/flux/README_TRAIN_LORA_zh-CN.md @@ -0,0 +1,555 @@ +# FLUX.1 LoRA 微调训练指南 + +本文档提供 FLUX.1 LoRA 微调训练的完整流程,包括环境配置、数据准备、多种分布式训练策略和推理测试。 + +--- + +## 目录 +- [一、环境配置](#一环境配置) +- [二、数据准备](#二数据准备) + - [2.1 快速测试数据集](#21-快速测试数据集) + - [2.2 数据集结构](#22-数据集结构) + - [2.3 metadata.json 格式](#23-metadatajson-格式) + - [2.4 相对路径与绝对路径使用方案](#24-相对路径与绝对路径使用方案) +- [三、LoRA 训练](#三lora-训练) + - [3.1 下载预训练模型](#31-下载预训练模型) + - [3.2 快速开始(DeepSpeed-Zero-2)](#32-快速开始deepspeed-zero-2) + - [3.3 LoRA 专用参数解析](#33-lora-专用参数解析) + - [3.4 使用 FSDP 训练](#34-使用-fsdp-训练) + - [3.5 其他后端](#35-其他后端) + - [3.6 多机分布式训练](#36-多机分布式训练) +- [四、推理测试](#四推理测试) + - [4.1 推理参数解析](#41-推理参数解析) + - [4.2 单卡推理](#42-单卡推理) + - [4.3 多卡并行推理](#43-多卡并行推理) +- [五、更多资源](#五更多资源) + +--- + +## 一、环境配置 + +**方式 1:使用requirements.txt** + +```bash +pip install -r requirements.txt +``` + +**方式 2:手动安装依赖** + +```bash +pip install Pillow einops safetensors timm tomesd librosa "torch>=2.1.2" torchdiffeq torchsde decord datasets numpy scikit-image +pip install omegaconf SentencePiece imageio[ffmpeg] imageio[pyav] tensorboard beautifulsoup4 ftfy func_timeout onnxruntime +pip install "peft>=0.17.0" "accelerate>=0.25.0" "gradio>=3.41.2" "diffusers>=0.30.1" "transformers>=4.46.2" +pip install yunchang xfuser modelscope openpyxl deepspeed==0.17.0 numpy==1.26.4 +pip uninstall opencv-python opencv-contrib-python opencv-python-headless -y +pip install opencv-python-headless +``` + +**方式 3:使用docker** + +使用docker的情况下,请保证机器中已经正确安装显卡驱动与CUDA环境,然后以此执行以下命令: + +``` +# pull image +docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun + +# enter image +docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun +``` + +--- + +## 二、数据准备 + +### 2.1 快速测试数据集 + +我们提供了一个测试的数据集,其中包含若干训练数据。 + +```bash +# 下载官方示例数据集 +modelscope download --dataset PAI/X-Fun-Images-Demo --local_dir ./datasets/X-Fun-Images-Demo +``` + +### 2.2 数据集结构 + +``` +📦 datasets/ +├── 📂 my_dataset/ +│ ├── 📂 train/ +│ │ ├── 📄 image001.jpg +│ │ ├── 📄 image002.png +│ │ └── 📄 ... 
+│ └── 📄 metadata.json +``` + +### 2.3 metadata.json 格式 + +**相对路径格式**(示例格式): +```json +[ + { + "file_path": "train/image001.jpg", + "text": "A beautiful sunset over the ocean, golden hour lighting", + "width": 1024, + "height": 1024 + }, + { + "file_path": "train/image002.png", + "text": "Portrait of a young woman, studio lighting, high quality", + "width": 1024, + "height": 1024 + } +] +``` + +**绝对路径格式**: +```json +[ + { + "file_path": "/mnt/data/images/sunset.jpg", + "text": "A beautiful sunset over the ocean", + "width": 1024, + "height": 1024 + } +] +``` + +**关键字段说明**: +- `file_path`:图片路径(相对或绝对路径) +- `text`:图片描述(英文提示词) +- `width` / `height`:图片宽高(**建议**提供以支持 bucket 训练;若不提供,训练时会自动读取,但在 OSS 等较慢系统中可能拖慢训练速度) + - 可使用 `scripts/process_json_add_width_and_height.py` 为没有宽高字段的 JSON 文件添加,支持图片和视频 + - 用法:`python scripts/process_json_add_width_and_height.py --input_file datasets/X-Fun-Images-Demo/metadata.json --output_file datasets/X-Fun-Images-Demo/metadata_add_width_height.json` + +### 2.4 相对路径与绝对路径使用方案 + +**使用相对路径**: + +如果你的数据使用相对路径,训练脚本中这样配置: + +```bash +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +``` + +**使用绝对路径**: + +如果你的数据使用绝对路径,训练脚本中这样配置: + +```bash +export DATASET_NAME="" +export DATASET_META_NAME="/mnt/data/metadata.json" +``` + +> 💡 **建议**:如果数据集较小且存储在本地,使用相对路径。如果数据集存储在外部存储(如 NAS、OSS)或多机共享,使用绝对路径。 + +--- + +## 三、LoRA 训练 + +### 3.1 下载预训练模型 + +```bash +# 创建模型目录 +mkdir -p models/Diffusion_Transformer + +# 下载 FLUX.1 官方权重 +modelscope download --model black-forest-labs/FLUX.1-dev --local_dir models/Diffusion_Transformer/FLUX.1-dev +``` + +### 3.2 快速开始(DeepSpeed-Zero-2) + +如果你已经按照 **2.1 快速测试数据集** 下载了数据,按照 **3.1 下载预训练模型** 下载了权重,则可以直接复制运行快速开始的命令。 + +训练推荐使用 DeepSpeed-Zero-2 或 FSDP。这里以 DeepSpeed-Zero-2 为例。 + +DeepSpeed-Zero-2 和 FSDP 的区别在于是否对模型权重进行分片。**如果多卡使用 DeepSpeed-Zero-2 时显存不足**,可以切换为 FSDP。 + +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 和 NCCL_P2P_DISABLE=1 用于无 RDMA 的多机环境 +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/flux/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1024 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_flux_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=64 \ + --network_alpha=32 \ + --target_name="to_q,to_k,to_v,ff.0,ff.2,ff_context.0,ff_context.2" \ + --use_peft_lora \ + --uniform_sampling +``` + +### 3.3 LoRA 专用参数解析 + +**LoRA 关键参数说明**: + +| 参数 | 说明 | 示例值 | +|-----|------|-------| +| `--pretrained_model_name_or_path` | 预训练模型路径 | `models/Diffusion_Transformer/FLUX.1-dev` | +| `--train_data_dir` | 训练数据目录 | `datasets/internal_datasets/` | +| `--train_data_meta` | 训练数据元文件 | `datasets/internal_datasets/metadata.json` | +| `--train_batch_size` | 每张卡的批次大小 | 1 | +| `--image_sample_size` | 最大训练分辨率(自动分桶) | 1024 | +| 
`--gradient_accumulation_steps` | 梯度累积步数(等效增大 batch size) | 1 | +| `--dataloader_num_workers` | DataLoader 子进程数 | 8 | +| `--num_train_epochs` | 训练 epoch 数 | 100 | +| `--checkpointing_steps` | 每 N 步保存检查点 | 50 | +| `--learning_rate` | 初始学习率(LoRA 推荐值) | 1e-04 | +| `--lr_scheduler` | 学习率调度器 | `constant_with_warmup` | +| `--lr_warmup_steps` | 学习率预热步数 | 100 | +| `--seed` | 随机种子(可复现训练) | 42 | +| `--output_dir` | 输出目录 | `output_dir_flux_lora` | +| `--gradient_checkpointing` | 启用梯度检查点 | - | +| `--mixed_precision` | 混合精度:`fp16/bf16` | `bf16` | +| `--enable_bucket` | 启用桶训练(不中心裁剪,按分辨率分组后训练完整图像) | - | +| `--uniform_sampling` | 均匀时间步采样(推荐) | - | +| `--resume_from_checkpoint` | 恢复训练的路径,使用 `"latest"` 自动选择最新检查点 | None | +| `--rank` | LoRA 更新矩阵维度(rank 越高表达能力越强但显存占用越大) | 64 | +| `--network_alpha` | LoRA 更新矩阵缩放系数(通常为 rank 的一半或相同) | 32 | +| `--target_name` | 应用 LoRA 的组件/模块,用逗号分隔 | `to_q,to_k,to_v,ff.0,ff.2,ff_context.0,ff_context.2` | +| `--use_peft_lora` | 使用 PEFT 模块添加 LoRA(更节省显存) | - | + +### 3.4 使用 FSDP 训练 + +**如果多卡使用 DeepSpeed-Zero-2 时显存不足**,可以切换为 FSDP。 + +> ✅ **推荐**:FSDP 在本仓库中经过充分测试,错误更少且更稳定。 + +```sh +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 和 NCCL_P2P_DISABLE=1 用于无 RDMA 的多机环境 +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap FluxSingleTransformerBlock,FluxTransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/flux/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1024 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_flux_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=64 \ + --network_alpha=32 \ + --target_name="to_q,to_k,to_v,ff.0,ff.2,ff_context.0,ff_context.2" \ + --use_peft_lora \ + --uniform_sampling +``` + +### 3.5 其他后端 + +#### 3.5.1 使用 DeepSpeed-Zero-3 训练 + +当前不推荐使用 DeepSpeed Zero-3。在本仓库中,FSDP 错误更少且更稳定。 + +DeepSpeed Zero-3: + +训练完成后,可以使用以下命令获取最终模型: + +```sh +python scripts/zero_to_bf16.py output_dir/checkpoint-{your-num-steps} output_dir/checkpoint-{your-num-steps}-outputs --max_shard_size 80GB --safe_serialization +``` + +执行命令: +```sh +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 和 NCCL_P2P_DISABLE=1 用于无 RDMA 的多机环境 +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag true --use_deepspeed --deepspeed_config_file config/zero_stage3_config.json --deepspeed_multinode_launcher standard scripts/flux/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + 
--image_sample_size=1024 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_flux_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=64 \ + --network_alpha=32 \ + --target_name="to_q,to_k,to_v,ff.0,ff.2,ff_context.0,ff_context.2" \ + --use_peft_lora \ + --uniform_sampling +``` + +#### 3.5.2 不使用 DeepSpeed 或 FSDP 训练 + +**不推荐此方法,因为没有显存优化的后端,可能导致显存不足**。仅供参考。 + +```sh +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 和 NCCL_P2P_DISABLE=1 用于无 RDMA 的多机环境 +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" scripts/flux/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1024 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_flux_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=64 \ + --network_alpha=32 \ + --target_name="to_q,to_k,to_v,ff.0,ff.2,ff_context.0,ff_context.2" \ + --use_peft_lora \ + --uniform_sampling +``` + +### 3.6 多机分布式训练 + +**适用场景**:大规模数据集,更快的训练速度 + +#### 3.6.1 环境配置 + +假设 2 台机器,每台 8 张 GPU: + +**机器 0(Master)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Master 机器 IP +export MASTER_PORT=10086 +export WORLD_SIZE=2 # 总机器数 +export NUM_PROCESS=16 # 总进程数 = 机器数 × 8 +export RANK=0 # 当前机器 rank(0 或 1) +# NCCL_IB_DISABLE=1 和 NCCL_P2P_DISABLE=1 用于无 RDMA 的多机环境 +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/flux/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1024 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_flux_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=64 \ + --network_alpha=32 \ + --target_name="to_q,to_k,to_v,ff.0,ff.2,ff_context.0,ff_context.2" \ + --use_peft_lora \ + --uniform_sampling +``` + +**机器 1(Worker)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export 
DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # 与 Master 相同 +export MASTER_PORT=10086 +export WORLD_SIZE=2 +export NUM_PROCESS=16 +export RANK=1 # 注意这里是 1 +# NCCL_IB_DISABLE=1 和 NCCL_P2P_DISABLE=1 用于无 RDMA 的多机环境 +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +# 使用与机器 0 相同的 accelerate launch 命令 +``` + +#### 3.6.2 多机训练注意事项 + +- **网络要求**: + - 推荐使用 RDMA/InfiniBand(高性能) + - 无 RDMA 时,添加环境变量: + ```bash + export NCCL_IB_DISABLE=1 + export NCCL_P2P_DISABLE=1 + ``` + +- **数据同步**:所有机器必须能访问相同的数据路径(NFS/共享存储) + +--- + +## 四、推理测试 + +### 4.1 推理参数解析 + +**关键参数说明**: + +| 参数 | 说明 | 示例值 | +|------|------|-------| +| `GPU_memory_mode` | GPU 显存管理模式,见下表 | `model_cpu_offload_and_qfloat8` | +| `ulysses_degree` | 头维度并行度,单卡为 1 | 1 | +| `ring_degree` | 序列维度并行度,单卡为 1 | 1 | +| `fsdp_dit` | 多卡推理时对 Transformer 使用 FSDP 以节省显存 | `False` | +| `fsdp_text_encoder` | 多卡推理时对文本编码器使用 FSDP | `False` | +| `compile_dit` | 编译 Transformer 以加速推理(固定分辨率时有效) | `False` | +| `model_name` | 模型路径 | `models/Diffusion_Transformer/FLUX.1-dev` | +| `sampler_name` | 采样器类型:`Flow`、`Flow_Unipc`、`Flow_DPM++` | `Flow` | +| `transformer_path` | 训练后的 Transformer 权重路径 | `None` | +| `vae_path` | 训练后的 VAE 权重路径 | `None` | +| `lora_path` | LoRA 权重路径 | `None` | +| `sample_size` | 生成图像分辨率 `[height, width]` | `[1344, 768]` | +| `weight_dtype` | 模型权重精度,不支持 bf16 的显卡使用 `torch.float16` | `torch.bfloat16` | +| `prompt` | 正向提示词,描述要生成的内容 | `"1girl, black_hair..."` | +| `negative_prompt` | 反向提示词,描述要避免的内容 | `"The video is not of a high quality..."` | +| `guidance_scale` | 引导强度 | 1.0 | +| `seed` | 随机种子,用于复现结果 | 43 | +| `num_inference_steps` | 推理步数 | 50 | +| `lora_weight` | LoRA 权重强度 | 0.55 | +| `save_path` | 生成图像保存路径 | `samples/flux-t2i` | + +**GPU 显存管理模式**: + +| 模式 | 说明 | 显存占用 | +|------|------|---------| +| `model_full_load` | 加载整个模型到 GPU | 最高 | +| `model_full_load_and_qfloat8` | 完整加载 + FP8 量化 | 高 | +| `model_cpu_offload` | 使用后卸载模型到 CPU | 中 | +| `model_cpu_offload_and_qfloat8` | CPU 卸载 + FP8 量化 | 中低 | +| `model_group_offload` | 层组在 CPU/CUDA 间切换 | 低 | +| `sequential_cpu_offload` | 逐层卸载(最慢) | 最低 | + +### 4.2 单卡推理 + +使用以下命令运行单卡推理: + +```bash +python examples/flux/predict_t2i.py +``` + +根据需要编辑 `examples/flux/predict_t2i.py`。首次推理时重点关注以下参数,其他参数见上方推理参数解析。 + +```python +# 根据 GPU 显存选择 +GPU_memory_mode = "model_cpu_offload_and_qfloat8" +# 根据实际模型路径 +model_name = "models/Diffusion_Transformer/FLUX.1-dev" +# LoRA 权重路径,例如 "output_dir_flux_lora/checkpoint-xxx/lora_weights.safetensors" +lora_path = None +# LoRA 权重强度 +lora_weight = 0.55 +# 根据要生成的内容编写 +prompt = "1girl, black_hair, brown_eyes, earrings, freckles, grey_background, jewelry, lips, long_hair, looking_at_viewer, nose, piercing, realistic, red_lips, solo, upper_body" +# ... 
+``` + +### 4.3 多卡并行推理 + +**适用场景**:高分辨率生成,更快的推理速度 + +#### 安装并行推理依赖 + +```bash +pip install xfuser==0.4.2 yunchang==0.6.2 +``` + +#### 配置并行策略 + +编辑 `examples/flux/predict_t2i.py`: + +```python +# 确保 ulysses_degree × ring_degree = 使用的 GPU 数 +# 例如使用 2 张 GPU: +ulysses_degree = 2 # 头维度并行 +ring_degree = 1 # 序列维度并行 +``` + +**配置原则**: +- `ulysses_degree` 必须能整除模型的头数。 +- `ring_degree` 在序列维度切分,影响通信开销。头数能整除时尽量避免使用。 + +**示例配置**: + +| GPU 数量 | ulysses_degree | ring_degree | 说明 | +|---------|---------------|-------------|------| +| 1 | 1 | 1 | 单卡 | +| 4 | 4 | 1 | 头并行 | +| 8 | 8 | 1 | 头并行 | +| 8 | 4 | 2 | 混合并行 | + +#### 运行多卡推理 + +```bash +torchrun --nproc-per-node=2 examples/flux/predict_t2i.py +``` + +## 五、更多资源 + +- **官方 GitHub**:https://github.com/aigc-apps/VideoX-Fun \ No newline at end of file diff --git a/scripts/flux/README_TRAIN_zh-CN.md b/scripts/flux/README_TRAIN_zh-CN.md new file mode 100755 index 00000000..dabee566 --- /dev/null +++ b/scripts/flux/README_TRAIN_zh-CN.md @@ -0,0 +1,547 @@ +# FLUX.1 全量参数训练指南 + +本文档提供 FLUX.1 Diffusion Transformer 全量参数训练的完整流程,包括环境配置、数据准备、分布式训练和推理测试。 + +--- + +## 目录 +- [一、环境配置](#一环境配置) +- [二、数据准备](#二数据准备) + - [2.1 快速测试数据集](#21-快速测试数据集) + - [2.2 数据集结构](#22-数据集结构) + - [2.3 metadata.json 格式](#23-metadatajson-格式) + - [2.4 相对路径与绝对路径使用方案](#24-相对路径与绝对路径使用方案) +- [三、全量参数训练](#三全量参数训练) + - [3.1 下载预训练模型](#31-下载预训练模型) + - [3.2 快速开始(DeepSpeed-Zero-2)](#32-快速开始deepspeed-zero-2) + - [3.3 训练常用参数解析](#33-训练常用参数解析) + - [3.4 使用 FSDP 训练](#34-使用-fsdp-训练) + - [3.5 其他后端](#35-其他后端) + - [3.6 多机分布式训练](#36-多机分布式训练) +- [四、推理测试](#四推理测试) + - [4.1 推理参数解析](#41-推理参数解析) + - [4.2 单卡推理](#42-单卡推理) + - [4.3 多卡并行推理](#43-多卡并行推理) +- [五、更多资源](#五更多资源) + +--- + +## 一、环境配置 + +**方式 1:使用requirements.txt** + +```bash +pip install -r requirements.txt +``` + +**方式 2:手动安装依赖** + +```bash +pip install Pillow einops safetensors timm tomesd librosa "torch>=2.1.2" torchdiffeq torchsde decord datasets numpy scikit-image +pip install omegaconf SentencePiece imageio[ffmpeg] imageio[pyav] tensorboard beautifulsoup4 ftfy func_timeout onnxruntime +pip install "peft>=0.17.0" "accelerate>=0.25.0" "gradio>=3.41.2" "diffusers>=0.30.1" "transformers>=4.46.2" +pip install yunchang xfuser modelscope openpyxl deepspeed==0.17.0 numpy==1.26.4 +pip uninstall opencv-python opencv-contrib-python opencv-python-headless -y +pip install opencv-python-headless +``` + +**方式 3:使用docker** + +使用docker的情况下,请保证机器中已经正确安装显卡驱动与CUDA环境,然后以此执行以下命令: + +``` +# pull image +docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun + +# enter image +docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun +``` + +--- + +## 二、数据准备 + +### 2.1 快速测试数据集 + +我们提供了一个测试的数据集,其中包含若干训练数据。 + +```bash +# 下载官方示例数据集 +modelscope download --dataset PAI/X-Fun-Images-Demo --local_dir ./datasets/X-Fun-Images-Demo +``` + +### 2.2 数据集结构 + +``` +📦 datasets/ +├── 📂 my_dataset/ +│ ├── 📂 train/ +│ │ ├── 📄 image001.jpg +│ │ ├── 📄 image002.png +│ │ └── 📄 ... 
+│ └── 📄 metadata.json +``` + +### 2.3 metadata.json 格式 + +**相对路径格式**(示例格式): +```json +[ + { + "file_path": "train/image001.jpg", + "text": "A beautiful sunset over the ocean, golden hour lighting", + "width": 1024, + "height": 1024 + }, + { + "file_path": "train/image002.png", + "text": "Portrait of a young woman, studio lighting, high quality", + "width": 1024, + "height": 1024 + } +] +``` + +**绝对路径格式**: +```json +[ + { + "file_path": "/mnt/data/images/sunset.jpg", + "text": "A beautiful sunset over the ocean", + "width": 1024, + "height": 1024 + } +] +``` + +**关键字段说明**: +- `file_path`:图片路径(相对或绝对路径) +- `text`:图片描述(英文提示词) +- `width` / `height`:图片宽高(**建议**提供以支持 bucket 训练;若不提供,训练时会自动读取,但在 OSS 等较慢系统中可能拖慢训练速度) + - 可使用 `scripts/process_json_add_width_and_height.py` 为没有宽高字段的 JSON 文件添加,支持图片和视频 + - 用法:`python scripts/process_json_add_width_and_height.py --input_file datasets/X-Fun-Images-Demo/metadata.json --output_file datasets/X-Fun-Images-Demo/metadata_add_width_height.json` + +### 2.4 相对路径与绝对路径使用方案 + +**使用相对路径**: + +如果你的数据使用相对路径,训练脚本中这样配置: + +```bash +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +``` + +**使用绝对路径**: + +如果你的数据使用绝对路径,训练脚本中这样配置: + +```bash +export DATASET_NAME="" +export DATASET_META_NAME="/mnt/data/metadata.json" +``` + +> 💡 **建议**:如果数据集较小且存储在本地,使用相对路径。如果数据集存储在外部存储(如 NAS、OSS)或多机共享,使用绝对路径。 + +--- + +## 三、全量参数训练 + +### 3.1 下载预训练模型 + +```bash +# 创建模型目录 +mkdir -p models/Diffusion_Transformer + +# 下载 FLUX.1 官方权重 +modelscope download --model black-forest-labs/FLUX.1-dev --local_dir models/Diffusion_Transformer/FLUX.1-dev +``` + +### 3.2 快速开始(DeepSpeed-Zero-2) + +如果你已经按照 **2.1 快速测试数据集** 下载了数据,按照 **3.1 下载预训练模型** 下载了权重,则可以直接复制运行快速开始的命令。 + +训练推荐使用 DeepSpeed-Zero-2 或 FSDP。这里以 DeepSpeed-Zero-2 为例。 + +DeepSpeed-Zero-2 和 FSDP 的区别在于是否对模型权重进行分片。**如果多卡使用 DeepSpeed-Zero-2 时显存不足**,可以切换为 FSDP。 + +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 和 NCCL_P2P_DISABLE=1 用于无 RDMA 的多机环境 +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/flux/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1024 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_flux" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." 
+``` + +### 3.3 训练常用参数解析 + +**关键参数说明**: + +| 参数 | 说明 | 示例值 | +|-----|------|-------| +| `--pretrained_model_name_or_path` | 预训练模型路径 | `models/Diffusion_Transformer/FLUX.1-dev` | +| `--train_data_dir` | 训练数据目录 | `datasets/internal_datasets/` | +| `--train_data_meta` | 训练数据元文件 | `datasets/internal_datasets/metadata.json` | +| `--train_batch_size` | 每张卡的批次大小 | 1 | +| `--image_sample_size` | 最大训练分辨率(自动分桶) | 1024 | +| `--gradient_accumulation_steps` | 梯度累积步数(等效增大 batch size) | 1 | +| `--dataloader_num_workers` | DataLoader 子进程数 | 8 | +| `--num_train_epochs` | 训练 epoch 数 | 100 | +| `--checkpointing_steps` | 每 N 步保存检查点 | 50 | +| `--learning_rate` | 初始学习率 | 2e-05 | +| `--lr_scheduler` | 学习率调度器 | `constant_with_warmup` | +| `--lr_warmup_steps` | 学习率预热步数 | 100 | +| `--seed` | 随机种子 | 42 | +| `--output_dir` | 输出目录 | `output_dir_flux` | +| `--gradient_checkpointing` | 启用梯度检查点 | - | +| `--mixed_precision` | 混合精度:`fp16/bf16` | `bf16` | +| `--adam_weight_decay` | AdamW 权重衰减 | 3e-2 | +| `--adam_epsilon` | AdamW epsilon 值 | 1e-10 | +| `--vae_mini_batch` | VAE 编码的 mini batch 大小 | 1 | +| `--max_grad_norm` | 梯度裁剪阈值 | 0.05 | +| `--enable_bucket` | 启用桶训练(不中心裁剪,按分辨率分组后训练完整图像) | - | +| `--random_hw_adapt` | 自动将图像缩放到 `[512, image_sample_size]` 范围内的随机尺寸 | - | +| `--resume_from_checkpoint` | 恢复训练的路径,使用 `"latest"` 自动选择最新检查点 | None | +| `--uniform_sampling` | 均匀时间步采样 | - | +| `--trainable_modules` | 可训练模块(`.` 表示所有模块) | `"."` | + + +### 3.4 使用 FSDP 训练 + +**如果多卡使用 DeepSpeed-Zero-2 时显存不足**,可以切换为 FSDP。 + +```sh +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 和 NCCL_P2P_DISABLE=1 用于无 RDMA 的多机环境 +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap FluxSingleTransformerBlock,FluxTransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/flux/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1024 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_flux" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." 
+``` + +### 3.5 其他后端 + +#### 3.5.1 使用 DeepSpeed-Zero-3 训练 + +当前不推荐使用 DeepSpeed Zero-3。在本仓库中,FSDP 错误更少且更稳定。 + +DeepSpeed Zero-3: + +训练完成后,可以使用以下命令获取最终模型: + +```sh +python scripts/zero_to_bf16.py output_dir/checkpoint-{your-num-steps} output_dir/checkpoint-{your-num-steps}-outputs --max_shard_size 80GB --safe_serialization +``` + +执行命令: +```sh +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 和 NCCL_P2P_DISABLE=1 用于无 RDMA 的多机环境 +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag true --use_deepspeed --deepspeed_config_file config/zero_stage3_config.json --deepspeed_multinode_launcher standard scripts/flux/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1024 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_flux" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` + +#### 3.5.2 不使用 DeepSpeed 或 FSDP 训练 + +**不推荐此方法,因为没有显存优化的后端,可能导致显存不足**。仅供参考。 + +```sh +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 和 NCCL_P2P_DISABLE=1 用于无 RDMA 的多机环境 +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" scripts/flux/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1024 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_flux" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." 
+``` + +### 3.6 多机分布式训练 + +**适用场景**:大规模数据集,更快的训练速度 + +#### 3.6.1 环境配置 + +假设 2 台机器,每台 8 张 GPU: + +**机器 0(Master)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Master 机器 IP +export MASTER_PORT=10086 +export WORLD_SIZE=2 # 总机器数 +export NUM_PROCESS=16 # 总进程数 = 机器数 × 8 +export RANK=0 # 当前机器 rank(0 或 1) +# NCCL_IB_DISABLE=1 和 NCCL_P2P_DISABLE=1 用于无 RDMA 的多机环境 +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/flux/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1024 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_flux" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` + +**机器 1(Worker)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # 与 Master 相同 +export MASTER_PORT=10086 +export WORLD_SIZE=2 +export NUM_PROCESS=16 +export RANK=1 # 注意这里是 1 +# NCCL_IB_DISABLE=1 和 NCCL_P2P_DISABLE=1 用于无 RDMA 的多机环境 +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +# 使用与机器 0 相同的 accelerate launch 命令 +``` + +#### 3.6.2 多机训练注意事项 + +- **网络要求**: + - 推荐使用 RDMA/InfiniBand(高性能) + - 无 RDMA 时,添加环境变量: + ```bash + export NCCL_IB_DISABLE=1 + export NCCL_P2P_DISABLE=1 + ``` + +- **数据同步**:所有机器必须能访问相同的数据路径(NFS/共享存储) + +## 四、推理测试 + +### 4.1 推理参数解析 + +**关键参数说明**: + +| 参数 | 说明 | 示例值 | +|------|------|-------| +| `GPU_memory_mode` | GPU 显存管理模式,见下表 | `model_cpu_offload_and_qfloat8` | +| `ulysses_degree` | 头维度并行度,单卡为 1 | 1 | +| `ring_degree` | 序列维度并行度,单卡为 1 | 1 | +| `fsdp_dit` | 多卡推理时对 Transformer 使用 FSDP 以节省显存 | `False` | +| `fsdp_text_encoder` | 多卡推理时对文本编码器使用 FSDP | `False` | +| `compile_dit` | 编译 Transformer 以加速推理(固定分辨率时有效) | `False` | +| `model_name` | 模型路径 | `models/Diffusion_Transformer/FLUX.1-dev` | +| `sampler_name` | 采样器类型:`Flow`、`Flow_Unipc`、`Flow_DPM++` | `Flow` | +| `transformer_path` | 训练后的 Transformer 权重路径 | `None` | +| `vae_path` | 训练后的 VAE 权重路径 | `None` | +| `lora_path` | LoRA 权重路径 | `None` | +| `sample_size` | 生成图像分辨率 `[height, width]` | `[1344, 768]` | +| `weight_dtype` | 模型权重精度,不支持 bf16 的显卡使用 `torch.float16` | `torch.bfloat16` | +| `prompt` | 正向提示词,描述要生成的内容 | `"1girl, black_hair..."` | +| `negative_prompt` | 反向提示词,描述要避免的内容 | `"The video is not of a high quality..."` | +| `guidance_scale` | 引导强度 | 1.0 | +| `seed` | 随机种子,用于复现结果 | 43 | +| `num_inference_steps` | 推理步数 | 50 | +| `lora_weight` | LoRA 权重强度 | 0.55 | +| `save_path` | 生成图像保存路径 | `samples/flux-t2i` 
| + +**GPU 显存管理模式**: + +| 模式 | 说明 | 显存占用 | +|------|------|---------| +| `model_full_load` | 加载整个模型到 GPU | 最高 | +| `model_full_load_and_qfloat8` | 完整加载 + FP8 量化 | 高 | +| `model_cpu_offload` | 使用后卸载模型到 CPU | 中 | +| `model_cpu_offload_and_qfloat8` | CPU 卸载 + FP8 量化 | 中低 | +| `model_group_offload` | 层组在 CPU/CUDA 间切换 | 低 | +| `sequential_cpu_offload` | 逐层卸载(最慢) | 最低 | + +### 4.2 单卡推理 + +使用以下命令运行单卡推理: + +```bash +python examples/flux/predict_t2i.py +``` + +根据需要编辑 `examples/flux/predict_t2i.py`。首次推理时重点关注以下参数,其他参数见上方推理参数解析。 + +```python +# 根据 GPU 显存选择 +GPU_memory_mode = "model_cpu_offload_and_qfloat8" +# 根据实际模型路径 +model_name = "models/Diffusion_Transformer/FLUX.1-dev" +# 训练后的权重路径,例如 "output_dir_flux/checkpoint-xxx/diffusion_pytorch_model.safetensors" +transformer_path = None +# 根据要生成的内容编写 +prompt = "1girl, black_hair, brown_eyes, earrings, freckles, grey_background, jewelry, lips, long_hair, looking_at_viewer, nose, piercing, realistic, red_lips, solo, upper_body" +# ... +``` + +### 4.3 多卡并行推理 + +**适用场景**:高分辨率生成,更快的推理速度 + +#### 安装并行推理依赖 + +```bash +pip install xfuser==0.4.2 yunchang==0.6.2 +``` + +#### 配置并行策略 + +编辑 `examples/flux/predict_t2i.py`: + +```python +# 确保 ulysses_degree × ring_degree = 使用的 GPU 数 +# 例如使用 2 张 GPU: +ulysses_degree = 2 # 头维度并行 +ring_degree = 1 # 序列维度并行 +``` + +**配置原则**: +- `ulysses_degree` 必须能整除模型的头数。 +- `ring_degree` 在序列维度切分,影响通信开销。头数能整除时尽量避免使用。 + +**示例配置**: + +| GPU 数量 | ulysses_degree | ring_degree | 说明 | +|---------|---------------|-------------|------| +| 1 | 1 | 1 | 单卡 | +| 4 | 4 | 1 | 头并行 | +| 8 | 8 | 1 | 头并行 | +| 8 | 4 | 2 | 混合并行 | + +#### 运行多卡推理 + +```bash +torchrun --nproc-per-node=2 examples/flux/predict_t2i.py +``` + +## 五、更多资源 + +- **官方 GitHub**:https://github.com/aigc-apps/VideoX-Fun \ No newline at end of file diff --git a/scripts/flux2/README_TRAIN.md b/scripts/flux2/README_TRAIN.md index c6837128..364786c0 100644 --- a/scripts/flux2/README_TRAIN.md +++ b/scripts/flux2/README_TRAIN.md @@ -1,30 +1,244 @@ -## Training Code +# FLUX.2 Full Parameter Training Guide -We can choose whether to use deepspeed or fsdp in flux2, which can save a lot of video memory. +This document provides a complete workflow for full parameter training of FLUX.2 Diffusion Transformer, including environment configuration, data preparation, distributed training, and inference testing. -Some parameters in the sh file can be confusing, and they are explained in this document: +--- -- `enable_bucket` is used to enable bucket training. When enabled, the model does not crop the images at the center, but instead, it trains the entire images after grouping them into buckets based on resolution. -- `random_hw_adapt` is used to enable automatic height and width scaling for images. When `random_hw_adapt` is enabled, the training images will have their height and width set to `image_sample_size` as the maximum and `512` as the minimum. - - For example, when `random_hw_adapt` is enabled, `image_sample_size=1024`, the resolution of image inputs for training is `512x512` to `1024x1024` -- `resume_from_checkpoint` is used to set the training should be resumed from a previous checkpoint. Use a path or `"latest"` to automatically select the last available checkpoint. +## Table of Contents +- [1. Environment Configuration](#1-environment-configuration) +- [2. 
Data Preparation](#2-data-preparation) + - [2.1 Quick Test Dataset](#21-quick-test-dataset) + - [2.2 Dataset Structure](#22-dataset-structure) + - [2.3 metadata.json Format](#23-metadatajson-format) + - [2.4 Relative vs Absolute Path Usage](#24-relative-vs-absolute-path-usage) +- [3. Full Parameter Training](#3-full-parameter-training) + - [3.1 Download Pretrained Model](#31-download-pretrained-model) + - [3.2 Quick Start (DeepSpeed-Zero-2)](#32-quick-start-deepspeed-zero-2) + - [3.3 Common Training Parameters](#33-common-training-parameters) + - [3.4 Training with FSDP](#34-training-with-fsdp) + - [3.5 Other Backends](#35-other-backends) + - [3.6 Multi-Machine Distributed Training](#36-multi-machine-distributed-training) +- [4. Inference Testing](#4-inference-testing) + - [4.1 Inference Parameters](#41-inference-parameters) + - [4.2 Single GPU Inference](#42-single-gpu-inference) + - [4.3 Multi-GPU Parallel Inference](#43-multi-gpu-parallel-inference) +- [5. Additional Resources](#5-additional-resources) -When train model with multi machines, please set the params as follows: -```sh -export MASTER_ADDR="your master address" -export MASTER_PORT=10086 -export WORLD_SIZE=1 # The number of machines -export NUM_PROCESS=8 # The number of processes, such as WORLD_SIZE * 8 -export RANK=0 # The rank of this machine +--- + +## 1. Environment Configuration + +**Method 1: Using requirements.txt** + +```bash +pip install -r requirements.txt +``` + +**Method 2: Manual Dependency Installation** + +```bash +pip install Pillow einops safetensors timm tomesd librosa "torch>=2.1.2" torchdiffeq torchsde decord datasets numpy scikit-image +pip install omegaconf SentencePiece imageio[ffmpeg] imageio[pyav] tensorboard beautifulsoup4 ftfy func_timeout onnxruntime +pip install "peft>=0.17.0" "accelerate>=0.25.0" "gradio>=3.41.2" "diffusers>=0.30.1" "transformers>=4.46.2" +pip install yunchang xfuser modelscope openpyxl deepspeed==0.17.0 numpy==1.26.4 +pip uninstall opencv-python opencv-contrib-python opencv-python-headless -y +pip install opencv-python-headless +``` + +**Method 3: Using Docker** + +When using Docker, please ensure that the GPU driver and CUDA environment are correctly installed on your machine, then execute the following commands: + +``` +# pull image +docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun + +# enter image +docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun +``` + +--- + +## 2. Data Preparation + +### 2.1 Quick Test Dataset + +We provide a test dataset containing several training samples. + +```bash +# Download official demo dataset +modelscope download --dataset PAI/X-Fun-Images-Demo --local_dir ./datasets/X-Fun-Images-Demo +``` + +### 2.2 Dataset Structure + +``` +📦 datasets/ +├── 📂 my_dataset/ +│ ├── 📂 train/ +│ │ ├── 📄 image001.jpg +│ │ ├── 📄 image002.png +│ │ └── 📄 ... 
+│ └── 📄 metadata.json +``` + +### 2.3 metadata.json Format + +**Relative Path Format** (example): +```json +[ + { + "file_path": "train/image001.jpg", + "text": "A beautiful sunset over the ocean, golden hour lighting", + "width": 1024, + "height": 1024 + }, + { + "file_path": "train/image002.png", + "text": "Portrait of a young woman, studio lighting, high quality", + "width": 1328, + "height": 1328 + } +] +``` + +**Absolute Path Format**: +```json +[ + { + "file_path": "/mnt/data/images/sunset.jpg", + "text": "A beautiful sunset over the ocean", + "width": 1024, + "height": 1024 + } +] +``` + +**Key Fields Description**: +- `file_path`: Image path (relative or absolute) +- `text`: Image description (English prompt) +- `width` / `height`: Image dimensions (**recommended** to provide for bucket training; if not provided, they will be automatically read during training, which may slow down training when data is stored on slow systems like OSS) + - You can use `scripts/process_json_add_width_and_height.py` to add width and height fields to JSON files without these fields, supporting both images and videos + - Usage: `python scripts/process_json_add_width_and_height.py --input_file datasets/X-Fun-Images-Demo/metadata.json --output_file datasets/X-Fun-Images-Demo/metadata_add_width_height.json` + +### 2.4 Relative vs Absolute Path Usage + +**Relative Paths**: + +If your data uses relative paths, configure the training script as follows: + +```bash +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +``` + +**Absolute Paths**: + +If your data uses absolute paths, configure the training script as follows: + +```bash +export DATASET_NAME="" +export DATASET_META_NAME="/mnt/data/metadata.json" +``` + +> 💡 **Recommendation**: If the dataset is small and stored locally, use relative paths. If the dataset is stored on external storage (e.g., NAS, OSS) or shared across multiple machines, use absolute paths. + +--- -accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK scripts/xxx/xxx.py +## 3. Full Parameter Training + +### 3.1 Download Pretrained Model + +```bash +# Create model directory +mkdir -p models/Diffusion_Transformer + +# Download FLUX.2 official weights +modelscope download --model black-forest-labs/FLUX.2-dev --local_dir models/Diffusion_Transformer/FLUX.2-dev +``` + +### 3.2 Quick Start (DeepSpeed-Zero-2) + +If you have downloaded the data as per **2.1 Quick Test Dataset** and the weights as per **3.1 Download Pretrained Model**, you can directly copy and run the quick start command. + +DeepSpeed-Zero-2 and FSDP are recommended for training. Here we use DeepSpeed-Zero-2 as an example. + +The difference between DeepSpeed-Zero-2 and FSDP lies in whether the model weights are sharded. **If VRAM is insufficient when using multiple GPUs with DeepSpeed-Zero-2**, you can switch to FSDP. + +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
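+# For single-node runs both can stay commented out. Note that NCCL_DEBUG=INFO
+# below only reaches the training processes if it is exported or placed on the
+# same line as the launch command.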
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/flux2/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_flux2" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." ``` -Without deepspeed: +### 3.3 Common Training Parameters + +**Key Parameters Description**: + +| Parameter | Description | Example Value | +|-----|------|-------| +| `--pretrained_model_name_or_path` | Pretrained model path | `models/Diffusion_Transformer/FLUX.2-dev` | +| `--train_data_dir` | Training data directory | `datasets/internal_datasets/` | +| `--train_data_meta` | Training data metadata file | `datasets/internal_datasets/metadata.json` | +| `--train_batch_size` | Batch size per device | 1 | +| `--image_sample_size` | Maximum training resolution (auto bucketing) | 1328 | +| `--gradient_accumulation_steps` | Gradient accumulation steps (effective batch size increase) | 1 | +| `--dataloader_num_workers` | DataLoader subprocess count | 8 | +| `--num_train_epochs` | Number of training epochs | 100 | +| `--checkpointing_steps` | Save checkpoint every N steps | 50 | +| `--learning_rate` | Initial learning rate | 2e-05 | +| `--lr_scheduler` | Learning rate scheduler | `constant_with_warmup` | +| `--lr_warmup_steps` | Learning rate warmup steps | 100 | +| `--seed` | Random seed | 42 | +| `--output_dir` | Output directory | `output_dir_flux2` | +| `--gradient_checkpointing` | Enable gradient checkpointing | - | +| `--mixed_precision` | Mixed precision: `fp16/bf16` | `bf16` | +| `--adam_weight_decay` | AdamW weight decay | 3e-2 | +| `--adam_epsilon` | AdamW epsilon value | 1e-10 | +| `--vae_mini_batch` | Mini batch size for VAE encoding | 1 | +| `--max_grad_norm` | Gradient clipping threshold | 0.05 | +| `--enable_bucket` | Enable bucket training (no center crop, train full images grouped by resolution) | - | +| `--random_hw_adapt` | Auto-scale images to random sizes in `[512, image_sample_size]` range | - | +| `--resume_from_checkpoint` | Resume training path, use `"latest"` to auto-select latest checkpoint | None | +| `--uniform_sampling` | Uniform timestep sampling | - | +| `--trainable_modules` | Trainable modules (`"."` means all modules) | `"."` | + + +### 3.4 Training with FSDP + +**If VRAM is insufficient when using multiple GPUs with DeepSpeed-Zero-2**, you can switch to FSDP. -Training flux2 without DeepSpeed may result in insufficient GPU memory. ```sh -export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" export DATASET_NAME="datasets/internal_datasets/" export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
@@ -32,7 +246,7 @@ export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --mixed_precision="bf16" scripts/flux2/train.py \ +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap Flux2SingleTransformerBlock,Flux2TransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/flux2/train.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ @@ -46,7 +260,7 @@ accelerate launch --mixed_precision="bf16" scripts/flux2/train.py \ --lr_scheduler="constant_with_warmup" \ --lr_warmup_steps=100 \ --seed=42 \ - --output_dir="output_dir" \ + --output_dir="output_dir_flux2" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -58,10 +272,23 @@ accelerate launch --mixed_precision="bf16" scripts/flux2/train.py \ --trainable_modules "." ``` -With Deepspeed Zero-2: +### 3.5 Other Backends + +#### 3.5.1 Training with DeepSpeed-Zero-3 + +DeepSpeed Zero-3 is not currently recommended. In this repository, FSDP has fewer errors and is more stable. + +DeepSpeed Zero-3: + +After training, you can use the following command to get the final model: ```sh -export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +python scripts/zero_to_bf16.py output_dir/checkpoint-{our-num-steps} output_dir/checkpoint-{your-num-steps}-outputs --max_shard_size 80GB --safe_serialization +``` + +Execution command: +```sh +export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" export DATASET_NAME="datasets/internal_datasets/" export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. @@ -69,7 +296,7 @@ export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/flux2/train.py \ +accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag true --use_deepspeed --deepspeed_config_file config/zero_stage3_config.json --deepspeed_multinode_launcher standard scripts/flux2/train.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ @@ -83,7 +310,7 @@ accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_con --lr_scheduler="constant_with_warmup" \ --lr_warmup_steps=100 \ --seed=42 \ - --output_dir="output_dir" \ + --output_dir="output_dir_flux2" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -95,10 +322,12 @@ accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_con --trainable_modules "." ``` -With FSDP: +#### 3.5.2 Training without DeepSpeed or FSDP + +**This approach is not recommended as there is no memory-saving backend, which may cause insufficient VRAM.** This is provided for reference only. 
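+
+Before falling back to this mode, it can help to check how much free VRAM each GPU actually has. Below is a minimal sketch using plain PyTorch (not one of this repository's scripts):
+
+```python
+import torch
+
+# Report free vs. total memory for every visible GPU, to judge whether a run
+# without CPU offload or weight sharding is realistic on this machine.
+for i in range(torch.cuda.device_count()):
+    free, total = torch.cuda.mem_get_info(i)
+    print(f"GPU {i}: {free / 1024**3:.1f} GiB free / {total / 1024**3:.1f} GiB total")
+```
+
+The plain `accelerate launch` command, kept for reference: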
```sh -export MODEL_NAME="models/Diffusion_Transformer/FLUX.1-dev" +export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" export DATASET_NAME="datasets/internal_datasets/" export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. @@ -106,7 +335,7 @@ export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap Flux2SingleTransformerBlock,Flux2TransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/flux2/train.py \ +accelerate launch --mixed_precision="bf16" scripts/flux2/train.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ @@ -120,7 +349,7 @@ accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TR --lr_scheduler="constant_with_warmup" \ --lr_warmup_steps=100 \ --seed=42 \ - --output_dir="output_dir" \ + --output_dir="output_dir_flux2" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -130,4 +359,189 @@ accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TR --enable_bucket \ --uniform_sampling \ --trainable_modules "." -``` \ No newline at end of file +``` + +### 3.6 Multi-Machine Distributed Training + +**Suitable for**: Large-scale datasets, faster training speed + +#### 3.6.1 Environment Configuration + +Assume 2 machines, each with 8 GPUs: + +**Machine 0 (Master)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Master machine IP +export MASTER_PORT=10086 +export WORLD_SIZE=2 # Total number of machines +export NUM_PROCESS=16 # Total processes = machines × 8 +export RANK=0 # Current machine rank (0 or 1) +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/flux2/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_flux2" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." 
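+# Machine 1 runs this same command; only RANK changes (see the Worker section below).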
+``` + +**Machine 1 (Worker)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Same as Master +export MASTER_PORT=10086 +export WORLD_SIZE=2 +export NUM_PROCESS=16 +export RANK=1 # Note this is 1 +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +# Use the same accelerate launch command as machine 0 +``` + +#### 3.6.2 Multi-Machine Training Notes + +- **Network Requirements**: + - RDMA/InfiniBand recommended (high performance) + - Without RDMA, add environment variables: + ```bash + export NCCL_IB_DISABLE=1 + export NCCL_P2P_DISABLE=1 + ``` + +- **Data Synchronization**: All machines must have access to the same data paths (NFS/shared storage) + +## 4. Inference Testing + +### 4.1 Inference Parameters + +**Key Parameters Description**: + +| Parameter | Description | Example Value | +|------|------|-------| +| `GPU_memory_mode` | GPU memory management mode, see table below | `model_cpu_offload` | +| `ulysses_degree` | Head dimension parallelism degree, 1 for single GPU | 1 | +| `ring_degree` | Sequence dimension parallelism degree, 1 for single GPU | 1 | +| `fsdp_dit` | Use FSDP for Transformer in multi-GPU inference to save VRAM | `False` | +| `fsdp_text_encoder` | Use FSDP for text encoder in multi-GPU inference | `False` | +| `compile_dit` | Compile Transformer for faster inference (effective at fixed resolution) | `False` | +| `model_name` | Model path | `models/Diffusion_Transformer/FLUX.2-dev` | +| `sampler_name` | Sampler type: `Flow`, `Flow_Unipc`, `Flow_DPM++` | `Flow` | +| `transformer_path` | Path to trained Transformer weights | `None` | +| `vae_path` | Path to trained VAE weights | `None` | +| `lora_path` | LoRA weights path | `None` | +| `sample_size` | Generated image resolution `[height, width]` | `[1344, 768]` | +| `weight_dtype` | Model weight precision, use `torch.float16` for GPUs without bf16 support | `torch.bfloat16` | +| `prompt` | Positive prompt describing the content | `"1girl, black_hair..."` | +| `negative_prompt` | Negative prompt for content to avoid | `" "` | +| `guidance_scale` | Guidance strength | 4.0 | +| `seed` | Random seed for reproducibility | 43 | +| `num_inference_steps` | Inference steps | 50 | +| `lora_weight` | LoRA weight strength | 0.55 | +| `save_path` | Generated image save path | `samples/flux2-t2i` | + +**GPU Memory Management Modes**: + +| Mode | Description | VRAM Usage | +|------|------|---------| +| `model_full_load` | Load entire model to GPU | Highest | +| `model_full_load_and_qfloat8` | Full load + FP8 quantization | High | +| `model_cpu_offload` | Offload model to CPU after use | Medium | +| `model_cpu_offload_and_qfloat8` | CPU offload + FP8 quantization | Medium-Low | +| `model_group_offload` | Layer groups switch between CPU/CUDA | Low | +| `sequential_cpu_offload` | Layer-by-layer offload (slowest) | Lowest | + +### 4.2 Single GPU Inference + +Run single GPU inference with the following command: + +```bash +python examples/flux2/predict_t2i.py +``` + +Edit `examples/flux2/predict_t2i.py` according to your needs. For first-time inference, focus on these parameters. For other parameters, see the inference parameter section above. 
+ +```python +# Choose based on GPU VRAM +GPU_memory_mode = "sequential_cpu_offload" +# Based on actual model path +model_name = "models/Diffusion_Transformer/FLUX.2-dev" +# Trained weights path, e.g., "output_dir_flux2/checkpoint-xxx/diffusion_pytorch_model.safetensors" +transformer_path = None +# Write based on content to generate +prompt = "1girl, black_hair, brown_eyes, earrings, freckles, grey_background, jewelry, lips, long_hair, looking_at_viewer, nose, piercing, realistic, red_lips, solo, upper_body" +# ... +``` + +### 4.3 Multi-GPU Parallel Inference + +**Suitable for**: High-resolution generation, faster inference + +#### Install Parallel Inference Dependencies + +```bash +pip install xfuser==0.4.2 yunchang==0.6.2 +``` + +#### Configure Parallel Strategy + +Edit `examples/flux2/predict_t2i.py`: + +```python +# Ensure ulysses_degree × ring_degree = number of GPUs +# For example, using 2 GPUs: +ulysses_degree = 2 # Head dimension parallelism +ring_degree = 1 # Sequence dimension parallelism +``` + +**Configuration Principles**: +- `ulysses_degree` must divide the model's head count evenly. +- `ring_degree` splits on sequence dimension, affecting communication overhead. Avoid using it when head count can be divided. + +**Example Configurations**: + +| GPU Count | ulysses_degree | ring_degree | Description | +|---------|---------------|-------------|------| +| 1 | 1 | 1 | Single GPU | +| 4 | 4 | 1 | Head parallelism | +| 8 | 8 | 1 | Head parallelism | +| 8 | 4 | 2 | Hybrid parallelism | + +#### Run Multi-GPU Inference + +```bash +torchrun --nproc-per-node=2 examples/flux2/predict_t2i.py +``` + +## 5. Additional Resources + +- **Official GitHub**: https://github.com/aigc-apps/VideoX-Fun diff --git a/scripts/flux2/README_TRAIN_LORA.md b/scripts/flux2/README_TRAIN_LORA.md index a577ed82..c8455f2f 100644 --- a/scripts/flux2/README_TRAIN_LORA.md +++ b/scripts/flux2/README_TRAIN_LORA.md @@ -1,32 +1,241 @@ -## Lora Training Code +# FLUX.2 LoRA Fine-Tuning Training Guide -We can choose whether to use deepspeed or fsdp in flux2, which can save a lot of video memory. +This document provides a complete workflow for FLUX.2 LoRA fine-tuning training, including environment configuration, data preparation, multiple distributed training strategies, and inference testing. -Some parameters in the sh file can be confusing, and they are explained in this document: +--- -- `enable_bucket` is used to enable bucket training. When enabled, the model does not crop the images at the center, but instead, it trains the entire images after grouping them into buckets based on resolution. -- `random_hw_adapt` is used to enable automatic height and width scaling for images. When `random_hw_adapt` is enabled, the training images will have their height and width set to `image_sample_size` as the maximum and `512` as the minimum. - - For example, when `random_hw_adapt` is enabled, `image_sample_size=1024`, the resolution of image inputs for training is `512x512` to `1024x1024` -- `resume_from_checkpoint` is used to set the training should be resumed from a previous checkpoint. Use a path or `"latest"` to automatically select the last available checkpoint. -- `target_name` represents the components/modules to which LoRA will be applied, separated by commas. -- `use_peft_lora` indicates whether to use the PEFT module for adding LoRA. Using this module will be more memory-efficient. -- `rank` means the dimension of the LoRA update matrices. -- `network_alpha` means the scale of the LoRA update matrices. 
+## Table of Contents +- [1. Environment Configuration](#1-environment-configuration) +- [2. Data Preparation](#2-data-preparation) + - [2.1 Quick Test Dataset](#21-quick-test-dataset) + - [2.2 Dataset Structure](#22-dataset-structure) + - [2.3 metadata.json Format](#23-metadatajson-format) + - [2.4 Relative vs Absolute Path Usage](#24-relative-vs-absolute-path-usage) +- [3. LoRA Training](#3-lora-training) + - [3.1 Download Pretrained Model](#31-download-pretrained-model) + - [3.2 Quick Start (DeepSpeed-Zero-2)](#32-quick-start-deepspeed-zero-2) + - [3.3 LoRA-Specific Parameters](#33-lora-specific-parameters) + - [3.4 Training with FSDP](#34-training-with-fsdp) + - [3.5 Other Backends](#35-other-backends) + - [3.6 Multi-Machine Distributed Training](#36-multi-machine-distributed-training) +- [4. Inference Testing](#4-inference-testing) + - [4.1 Inference Parameter Parsing](#41-inference-parameter-parsing) + - [4.2 Single GPU Inference](#42-single-gpu-inference) + - [4.3 Multi-GPU Parallel Inference](#43-multi-gpu-parallel-inference) +- [5. Additional Resources](#5-additional-resources) -When train model with multi machines, please set the params as follows: -```sh -export MASTER_ADDR="your master address" -export MASTER_PORT=10086 -export WORLD_SIZE=1 # The number of machines -export NUM_PROCESS=8 # The number of processes, such as WORLD_SIZE * 8 -export RANK=0 # The rank of this machine +--- + +## 1. Environment Configuration + +**Method 1: Using requirements.txt** + +```bash +pip install -r requirements.txt +``` + +**Method 2: Manual Dependency Installation** + +```bash +pip install Pillow einops safetensors timm tomesd librosa "torch>=2.1.2" torchdiffeq torchsde decord datasets numpy scikit-image +pip install omegaconf SentencePiece imageio[ffmpeg] imageio[pyav] tensorboard beautifulsoup4 ftfy func_timeout onnxruntime +pip install "peft>=0.17.0" "accelerate>=0.25.0" "gradio>=3.41.2" "diffusers>=0.30.1" "transformers>=4.46.2" +pip install yunchang xfuser modelscope openpyxl deepspeed==0.17.0 numpy==1.26.4 +pip uninstall opencv-python opencv-contrib-python opencv-python-headless -y +pip install opencv-python-headless +``` + +**Method 3: Using Docker** + +When using Docker, please ensure that the GPU driver and CUDA environment are correctly installed on your machine, then execute the following commands: + +``` +# pull image +docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun + +# enter image +docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun +``` + +--- + +## 2. Data Preparation + +### 2.1 Quick Test Dataset + +We provide a test dataset containing several training samples. + +```bash +# Download official demo dataset +modelscope download --dataset PAI/X-Fun-Images-Demo --local_dir ./datasets/X-Fun-Images-Demo +``` + +### 2.2 Dataset Structure + +``` +📦 datasets/ +├── 📂 my_dataset/ +│ ├── 📂 train/ +│ │ ├── 📄 image001.jpg +│ │ ├── 📄 image002.png +│ │ └── 📄 ... 
+│ └── 📄 metadata.json +``` + +### 2.3 metadata.json Format -accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK scripts/xxx/xxx.py +**Relative Path Format** (example): +```json +[ + { + "file_path": "train/image001.jpg", + "text": "A beautiful sunset over the ocean, golden hour lighting", + "width": 1024, + "height": 1024 + }, + { + "file_path": "train/image002.png", + "text": "Portrait of a young woman, studio lighting, high quality", + "width": 1328, + "height": 1328 + } +] ``` -Without deepspeed: +**Absolute Path Format**: +```json +[ + { + "file_path": "/mnt/data/images/sunset.jpg", + "text": "A beautiful sunset over the ocean", + "width": 1024, + "height": 1024 + } +] +``` + +**Key Fields Description**: +- `file_path`: Image path (relative or absolute) +- `text`: Image description (English prompt) +- `width` / `height`: Image dimensions (**recommended** to provide for bucket training; if not provided, they will be automatically read during training, which may slow down training when data is stored on slow systems like OSS) + - You can use `scripts/process_json_add_width_and_height.py` to add width and height fields to JSON files without these fields, supporting both images and videos + - Usage: `python scripts/process_json_add_width_and_height.py --input_file datasets/X-Fun-Images-Demo/metadata.json --output_file datasets/X-Fun-Images-Demo/metadata_add_width_height.json` + +### 2.4 Relative vs Absolute Path Usage + +**Relative Paths**: + +If your data uses relative paths, configure the training script as follows: + +```bash +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +``` + +**Absolute Paths**: + +If your data uses absolute paths, configure the training script as follows: + +```bash +export DATASET_NAME="" +export DATASET_META_NAME="/mnt/data/metadata.json" +``` + +> 💡 **Recommendation**: If the dataset is small and stored locally, use relative paths. If the dataset is stored on external storage (e.g., NAS, OSS) or shared across multiple machines, use absolute paths. + +--- + +## 3. LoRA Training + +### 3.1 Download Pretrained Model + +```bash +# Create model directory +mkdir -p models/Diffusion_Transformer + +# Download FLUX.2 official weights +modelscope download --model black-forest-labs/FLUX.2-dev --local_dir models/Diffusion_Transformer/FLUX.2-dev +``` + +### 3.2 Quick Start (DeepSpeed-Zero-2) + +If you have downloaded the data as per **2.1 Quick Test Dataset** and the weights as per **3.1 Download Pretrained Model**, you can directly copy and run the quick start command. + +DeepSpeed-Zero-2 and FSDP are recommended for training. Here we use DeepSpeed-Zero-2 as an example. + +The difference between DeepSpeed-Zero-2 and FSDP lies in whether the model weights are sharded. **If VRAM is insufficient when using multiple GPUs with DeepSpeed-Zero-2**, you can switch to FSDP. + +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/flux2/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_flux2_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=64 \ + --network_alpha=32 \ + --target_name="to_q,to_k,to_v,ff.0,ff.2,ff_context.0,ff_context.2" \ + --use_peft_lora \ + --uniform_sampling +``` + +### 3.3 LoRA-Specific Parameters + +**LoRA Key Parameters Description**: + +| Parameter | Description | Example Value | +|-----|------|-------| +| `--pretrained_model_name_or_path` | Pretrained model path | `models/Diffusion_Transformer/FLUX.2-dev` | +| `--train_data_dir` | Training data directory | `datasets/internal_datasets/` | +| `--train_data_meta` | Training data metadata file | `datasets/internal_datasets/metadata.json` | +| `--train_batch_size` | Batch size per device | 1 | +| `--image_sample_size` | Maximum training resolution (auto bucketing) | 1328 | +| `--gradient_accumulation_steps` | Gradient accumulation steps (effective batch size increase) | 1 | +| `--dataloader_num_workers` | DataLoader subprocess count | 8 | +| `--num_train_epochs` | Number of training epochs | 100 | +| `--checkpointing_steps` | Save checkpoint every N steps | 50 | +| `--learning_rate` | Initial learning rate (recommended for LoRA) | 1e-04 | +| `--lr_warmup_steps` | Learning rate warmup steps | 100 | +| `--seed` | Random seed (reproducible training) | 42 | +| `--output_dir` | Output directory | `output_dir_flux2_lora` | +| `--gradient_checkpointing` | Enable gradient checkpointing | - | +| `--mixed_precision` | Mixed precision: `fp16/bf16` | `bf16` | +| `--enable_bucket` | Enable bucket training (no center crop, train full images grouped by resolution) | - | +| `--uniform_sampling` | Uniform timestep sampling (recommended) | - | +| `--resume_from_checkpoint` | Resume training path, use `"latest"` to auto-select latest checkpoint | None | +| `--rank` | LoRA update matrix dimension (higher rank = more expressive but more VRAM) | 64 | +| `--network_alpha` | LoRA update matrix scaling coefficient (typically half of rank or same) | 32 | +| `--target_name` | Components/modules to apply LoRA, comma-separated | `to_q,to_k,to_v,ff.0,ff.2,ff_context.0,ff_context.2` | +| `--use_peft_lora` | Use PEFT module to add LoRA (more memory efficient) | - | + +### 3.4 Training with FSDP + +**If VRAM is insufficient when using multiple GPUs with DeepSpeed-Zero-2**, you can switch to FSDP. + +> ✅ **Recommended**: FSDP has been thoroughly tested in this repository with fewer errors and more stability. -Training flux2 without DeepSpeed may result in insufficient GPU memory. 
```sh export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" export DATASET_NAME="datasets/internal_datasets/" @@ -36,7 +245,7 @@ export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --mixed_precision="bf16" scripts/flux2/train_lora.py \ +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap Flux2SingleTransformerBlock,Flux2TransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/flux2/train_lora.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ @@ -48,7 +257,7 @@ accelerate launch --mixed_precision="bf16" scripts/flux2/train_lora.py \ --checkpointing_steps=50 \ --learning_rate=1e-04 \ --seed=42 \ - --output_dir="output_dir" \ + --output_dir="output_dir_flux2_lora" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -63,8 +272,21 @@ accelerate launch --mixed_precision="bf16" scripts/flux2/train_lora.py \ --uniform_sampling ``` -With Deepspeed Zero-2: +### 3.5 Other Backends + +#### 3.5.1 Training with DeepSpeed-Zero-3 +DeepSpeed Zero-3 is not currently recommended. In this repository, FSDP has fewer errors and is more stable. + +DeepSpeed Zero-3: + +After training, you can use the following command to get the final model: + +```sh +python scripts/zero_to_bf16.py output_dir/checkpoint-{our-num-steps} output_dir/checkpoint-{your-num-steps}-outputs --max_shard_size 80GB --safe_serialization +``` + +Execution command: ```sh export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" export DATASET_NAME="datasets/internal_datasets/" @@ -74,7 +296,7 @@ export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/flux2/train_lora.py \ +accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag true --use_deepspeed --deepspeed_config_file config/zero_stage3_config.json --deepspeed_multinode_launcher standard scripts/flux2/train_lora.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ @@ -86,7 +308,7 @@ accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_con --checkpointing_steps=50 \ --learning_rate=1e-04 \ --seed=42 \ - --output_dir="output_dir" \ + --output_dir="output_dir_flux2_lora" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -101,7 +323,9 @@ accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_con --uniform_sampling ``` -With FSDP: +#### 3.5.2 Training without DeepSpeed or FSDP + +**This approach is not recommended as there is no memory-saving backend, which may cause insufficient VRAM.** This is provided for reference only. 
```sh export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" @@ -112,7 +336,57 @@ export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap Flux2SingleTransformerBlock,Flux2TransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/flux2/train_lora.py \ +accelerate launch --mixed_precision="bf16" scripts/flux2/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_flux2_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=64 \ + --network_alpha=32 \ + --target_name="to_q,to_k,to_v,ff.0,ff.2,ff_context.0,ff_context.2" \ + --use_peft_lora \ + --uniform_sampling +``` + +### 3.6 Multi-Machine Distributed Training + +**Suitable for**: Large-scale datasets, faster training speed + +#### 3.6.1 Environment Configuration + +Assume 2 machines, each with 8 GPUs: + +**Machine 0 (Master)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Master machine IP +export MASTER_PORT=10086 +export WORLD_SIZE=2 # Total number of machines +export NUM_PROCESS=16 # Total processes = machines × 8 +export RANK=0 # Current machine rank (0 or 1) +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/flux2/train_lora.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ @@ -124,7 +398,7 @@ accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TR --checkpointing_steps=50 \ --learning_rate=1e-04 \ --seed=42 \ - --output_dir="output_dir" \ + --output_dir="output_dir_flux2_lora" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -137,4 +411,144 @@ accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TR --target_name="to_q,to_k,to_v,ff.0,ff.2,ff_context.0,ff_context.2" \ --use_peft_lora \ --uniform_sampling -``` \ No newline at end of file +``` + +**Machine 1 (Worker)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Same as Master +export MASTER_PORT=10086 +export WORLD_SIZE=2 +export NUM_PROCESS=16 +export RANK=1 # Note this is 1 +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +# Use the same accelerate launch command as machine 0 +``` + +#### 3.6.2 Multi-Machine Training Notes + +- **Network Requirements**: + - RDMA/InfiniBand recommended (high performance) + - Without RDMA, add environment variables: + ```bash + export NCCL_IB_DISABLE=1 + export NCCL_P2P_DISABLE=1 + ``` + +- **Data Synchronization**: All machines must have access to the same data paths (NFS/shared storage) + +--- + +## 4. 
Inference Testing + +### 4.1 Inference Parameter Parsing + +**Key Parameters Description**: + +| Parameter | Description | Example Value | +|------|------|-------| +| `GPU_memory_mode` | GPU memory management mode, see table below | `model_cpu_offload` | +| `ulysses_degree` | Head dimension parallelism degree, 1 for single GPU | 1 | +| `ring_degree` | Sequence dimension parallelism degree, 1 for single GPU | 1 | +| `fsdp_dit` | Use FSDP for Transformer in multi-GPU inference to save VRAM | `False` | +| `fsdp_text_encoder` | Use FSDP for text encoder in multi-GPU inference | `False` | +| `compile_dit` | Compile Transformer for faster inference (effective at fixed resolution) | `False` | +| `model_name` | Model path | `models/Diffusion_Transformer/FLUX.2-dev` | +| `sampler_name` | Sampler type: `Flow`, `Flow_Unipc`, `Flow_DPM++` | `Flow` | +| `transformer_path` | Path to trained Transformer weights | `None` | +| `vae_path` | Path to trained VAE weights | `None` | +| `lora_path` | LoRA weights path | `None` | +| `sample_size` | Generated image resolution `[height, width]` | `[1344, 768]` | +| `weight_dtype` | Model weight precision, use `torch.float16` for GPUs without bf16 support | `torch.bfloat16` | +| `prompt` | Positive prompt describing the content | `"1girl, black_hair..."` | +| `negative_prompt` | Negative prompt for content to avoid | `" "` | +| `guidance_scale` | Guidance strength | 4.0 | +| `seed` | Random seed for reproducibility | 43 | +| `num_inference_steps` | Inference steps | 50 | +| `lora_weight` | LoRA weight strength | 0.55 | +| `save_path` | Generated image save path | `samples/flux2-t2i` | + +**GPU Memory Management Modes**: + +| Mode | Description | VRAM Usage | +|------|------|---------| +| `model_full_load` | Load entire model to GPU | Highest | +| `model_full_load_and_qfloat8` | Full load + FP8 quantization | High | +| `model_cpu_offload` | Offload model to CPU after use | Medium | +| `model_cpu_offload_and_qfloat8` | CPU offload + FP8 quantization | Medium-Low | +| `model_group_offload` | Layer groups switch between CPU/CUDA | Low | +| `sequential_cpu_offload` | Layer-by-layer offload (slowest) | Lowest | + +### 4.2 Single GPU Inference + +Run single GPU inference with the following command: + +```bash +python examples/flux2/predict_t2i.py +``` + +Edit `examples/flux2/predict_t2i.py` according to your needs. For first-time inference, focus on these parameters. For other parameters, see the inference parameter section above. + +```python +# Choose based on GPU VRAM +GPU_memory_mode = "sequential_cpu_offload" +# Based on actual model path +model_name = "models/Diffusion_Transformer/FLUX.2-dev" +# LoRA weights path, e.g., "output_dir_flux2_lora/checkpoint-xxx/lora_weights.safetensors" +lora_path = None +# LoRA weight strength +lora_weight = 0.55 +# Write based on content to generate +prompt = "1girl, black_hair, brown_eyes, earrings, freckles, grey_background, jewelry, lips, long_hair, looking_at_viewer, nose, piercing, realistic, red_lips, solo, upper_body" +# ... 
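+# Other commonly adjusted values; the numbers below are the example values
+# listed in the parameter table in section 4.1.
+# sample_size = [1344, 768]   # [height, width]
+# guidance_scale = 4.0
+# seed = 43
+# num_inference_steps = 50
+# save_path = "samples/flux2-t2i"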
+``` + +### 4.3 Multi-GPU Parallel Inference + +**Suitable for**: High-resolution generation, faster inference + +#### Install Parallel Inference Dependencies + +```bash +pip install xfuser==0.4.2 yunchang==0.6.2 +``` + +#### Configure Parallel Strategy + +Edit `examples/flux2/predict_t2i.py`: + +```python +# Ensure ulysses_degree × ring_degree = number of GPUs +# For example, using 2 GPUs: +ulysses_degree = 2 # Head dimension parallelism +ring_degree = 1 # Sequence dimension parallelism +``` + +**Configuration Principles**: +- `ulysses_degree` must divide the model's head count evenly. +- `ring_degree` splits on sequence dimension, affecting communication overhead. Avoid using it when head count can be divided. + +**Example Configurations**: + +| GPU Count | ulysses_degree | ring_degree | Description | +|---------|---------------|-------------|------| +| 1 | 1 | 1 | Single GPU | +| 4 | 4 | 1 | Head parallelism | +| 8 | 8 | 1 | Head parallelism | +| 8 | 4 | 2 | Hybrid parallelism | + +#### Run Multi-GPU Inference + +```bash +torchrun --nproc-per-node=2 examples/flux2/predict_t2i.py +``` + +## 5. Additional Resources + +- **Official GitHub**: https://github.com/aigc-apps/VideoX-Fun diff --git a/scripts/flux2/README_TRAIN_LORA_zh-CN.md b/scripts/flux2/README_TRAIN_LORA_zh-CN.md new file mode 100644 index 00000000..f625a195 --- /dev/null +++ b/scripts/flux2/README_TRAIN_LORA_zh-CN.md @@ -0,0 +1,554 @@ +# FLUX.2 LoRA 微调训练指南 + +本文档提供 FLUX.2 LoRA 微调训练的完整流程,包括环境配置、数据准备、多种分布式训练策略和推理测试。 + +--- + +## 目录 +- [一、环境配置](#一环境配置) +- [二、数据准备](#二数据准备) + - [2.1 快速测试数据集](#21-快速测试数据集) + - [2.2 数据集结构](#22-数据集结构) + - [2.3 metadata.json 格式](#23-metadatajson-格式) + - [2.4 相对路径与绝对路径使用方案](#24-相对路径与绝对路径使用方案) +- [三、LoRA 训练](#三lora-训练) + - [3.1 下载预训练模型](#31-下载预训练模型) + - [3.2 快速开始(DeepSpeed-Zero-2)](#32-快速开始deepspeed-zero-2) + - [3.3 LoRA 专用参数解析](#33-lora-专用参数解析) + - [3.4 使用 FSDP 训练](#34-使用-fsdp-训练) + - [3.5 其他后端](#35-其他后端) + - [3.6 多机分布式训练](#36-多机分布式训练) +- [四、推理测试](#四推理测试) + - [4.1 推理参数解析](#41-推理参数解析) + - [4.2 单卡推理](#42-单卡推理) + - [4.3 多卡并行推理](#43-多卡并行推理) +- [五、更多资源](#五更多资源) + +--- + +## 一、环境配置 + +**方式 1:使用requirements.txt** + +```bash +pip install -r requirements.txt +``` + +**方式 2:手动安装依赖** + +```bash +pip install Pillow einops safetensors timm tomesd librosa "torch>=2.1.2" torchdiffeq torchsde decord datasets numpy scikit-image +pip install omegaconf SentencePiece imageio[ffmpeg] imageio[pyav] tensorboard beautifulsoup4 ftfy func_timeout onnxruntime +pip install "peft>=0.17.0" "accelerate>=0.25.0" "gradio>=3.41.2" "diffusers>=0.30.1" "transformers>=4.46.2" +pip install yunchang xfuser modelscope openpyxl deepspeed==0.17.0 numpy==1.26.4 +pip uninstall opencv-python opencv-contrib-python opencv-python-headless -y +pip install opencv-python-headless +``` + +**方式 3:使用docker** + +使用docker的情况下,请保证机器中已经正确安装显卡驱动与CUDA环境,然后以此执行以下命令: + +``` +# pull image +docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun + +# enter image +docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun +``` + +--- + +## 二、数据准备 + +### 2.1 快速测试数据集 + +我们提供了一个测试的数据集,其中包含若干训练数据。 + +```bash +# 下载官方示例数据集 +modelscope download --dataset PAI/X-Fun-Images-Demo --local_dir ./datasets/X-Fun-Images-Demo +``` + +### 2.2 数据集结构 + +``` +📦 datasets/ +├── 📂 my_dataset/ +│ ├── 📂 train/ +│ │ ├── 📄 image001.jpg +│ │ ├── 📄 image002.png +│ │ └── 📄 ... 
+│ └── 📄 metadata.json +``` + +### 2.3 metadata.json 格式 + +**相对路径格式**(示例格式): +```json +[ + { + "file_path": "train/image001.jpg", + "text": "A beautiful sunset over the ocean, golden hour lighting", + "width": 1024, + "height": 1024 + }, + { + "file_path": "train/image002.png", + "text": "Portrait of a young woman, studio lighting, high quality", + "width": 1328, + "height": 1328 + } +] +``` + +**绝对路径格式**: +```json +[ + { + "file_path": "/mnt/data/images/sunset.jpg", + "text": "A beautiful sunset over the ocean", + "width": 1024, + "height": 1024 + } +] +``` + +**关键字段说明**: +- `file_path`:图片路径(相对或绝对路径) +- `text`:图片描述(英文提示词) +- `width` / `height`:图片宽高(**最好提供**,用于分桶训练,如果不提供则自动在训练时读取,当数据存储在如oss这样的速度较慢的系统上时,可能会影响训练速度)。 + - 可以使用`scripts/process_json_add_width_and_height.py`文件对无width与height字段的json进行提取,支持处理图片与视频。 + - 使用方案为`python scripts/process_json_add_width_and_height.py --input_file datasets/X-Fun-Images-Demo/metadata.json --output_file datasets/X-Fun-Images-Demo/metadata_add_width_height.json`。 + +### 2.4 相对路径与绝对路径使用方案 + +**相对路径**: + +如果数据的路径为相对路径,则在训练脚本中设置: + +```bash +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +``` + +**绝对路径**: + +如果数据的路径为绝对路径,则在训练脚本中设置: + +```bash +export DATASET_NAME="" +export DATASET_META_NAME="/mnt/data/metadata.json" +``` + +> 💡 **建议**:如果数据集较小且存储在本地,推荐使用相对路径;如果数据集存储在外部存储(如 NAS、OSS)或多个机器共享存储,推荐使用绝对路径。 + +--- + +## 三、LoRA 训练 + +### 3.1 下载预训练模型 + +```bash +# 创建模型目录 +mkdir -p models/Diffusion_Transformer + +# 下载 FLUX.2 官方权重 +modelscope download --model black-forest-labs/FLUX.2-dev --local_dir models/Diffusion_Transformer/FLUX.2-dev +``` + +### 3.2 快速开始(DeepSpeed-Zero-2) + +如果按照 **2.1 快速测试数据集下载数据** 与 **3.1 下载预训练模型下载权重**后,直接复制快速开始的启动指令进行启动。 + +推荐使用 DeepSpeed-Zero-2 与 FSDP 方案进行训练。这里使用 DeepSpeed-Zero-2 为例配置 shell 文件。 + +本文中 DeepSpeed-Zero-2 与 FSDP 的差别在于是否对模型权重进行分片,**如果使用多卡且使用 DeepSpeed-Zero-2 的情况下显存不足**,可以切换使用 FSDP 进行训练。 + +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/flux2/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_flux2_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=64 \ + --network_alpha=32 \ + --target_name="to_q,to_k,to_v,ff.0,ff.2,ff_context.0,ff_context.2" \ + --use_peft_lora \ + --uniform_sampling +``` + +### 3.3 LoRA 专用参数解析 + +**LoRA 关键参数说明**: + +| 参数 | 说明 | 示例值 | +|-----|------|-------| +| `--pretrained_model_name_or_path` | 预训练模型路径 | `models/Diffusion_Transformer/FLUX.2-dev` | +| `--train_data_dir` | 训练数据目录 | `datasets/internal_datasets/` | +| `--train_data_meta` | 训练数据元文件 | `datasets/internal_datasets/metadata.json` | +| `--train_batch_size` | 每批次样本数 | 1 | +| `--image_sample_size` | 最大训练分辨率,代码会自动分桶 | 1328 | +| `--gradient_accumulation_steps` | 梯度累积步数(等效增大 batch) | 1 | +| `--dataloader_num_workers` | DataLoader 子进程数 | 8 | +| `--num_train_epochs` | 训练 epoch 数 | 100 | +| `--checkpointing_steps` | 每 N 步保存 checkpoint | 50 | +| `--learning_rate` | 初始学习率(LoRA 推荐值) | 1e-04 | +| `--lr_warmup_steps` | 学习率预热步数 | 100 | +| `--seed` | 随机种子(可复现训练) | 42 | +| `--output_dir` | 输出目录 | `output_dir_flux2_lora` | +| `--gradient_checkpointing` | 激活重计算 | - | +| `--mixed_precision` | 混合精度:`fp16/bf16` | `bf16` | +| `--enable_bucket` | 启用分桶训练,不裁剪图片,按分辨率分组训练整个图像 | - | +| `--uniform_sampling` | 均匀采样 timestep(推荐启用) | - | +| `--resume_from_checkpoint` | 恢复训练路径,使用 `"latest"` 自动选择最新 checkpoint | None | +| `--rank` | LoRA 更新矩阵的维度(rank 越大表达能力越强,但显存占用越高) | 64 | +| `--network_alpha` | LoRA 更新矩阵的缩放系数(通常设置为 rank 的一半或相同) | 32 | +| `--target_name` | 应用 LoRA 的组件/模块,用逗号分隔 | `to_q,to_k,to_v,ff.0,ff.2,ff_context.0,ff_context.2` | +| `--use_peft_lora` | 使用 PEFT 模块添加 LoRA(更节省显存) | - | + +### 3.4 使用 FSDP 训练 + +**如果使用多卡且使用 DeepSpeed-Zero-2 的情况下显存不足**,可以切换使用 FSDP 进行训练。 + +> ✅ **推荐**:FSDP 在当前仓库中经过充分测试,错误更少、更稳定。 + +```sh +export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap Flux2SingleTransformerBlock,Flux2TransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/flux2/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_flux2_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=64 \ + --network_alpha=32 \ + --target_name="to_q,to_k,to_v,ff.0,ff.2,ff_context.0,ff_context.2" \ + --use_peft_lora \ + --uniform_sampling +``` + +### 3.5 其他后端 + +#### 3.5.1 使用DeepSpeed-Zero-3进行训练 + +目前不太推荐使用 DeepSpeed Zero-3。在本仓库中,使用 FSDP 出错更少且更稳定。 + +DeepSpeed Zero-3: + +训练完成后,您可以使用以下命令获取最终模型: + +```sh +python scripts/zero_to_bf16.py output_dir/checkpoint-{our-num-steps} output_dir/checkpoint-{your-num-steps}-outputs --max_shard_size 80GB --safe_serialization +``` + +执行命令为: +```sh +export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag true --use_deepspeed --deepspeed_config_file config/zero_stage3_config.json --deepspeed_multinode_launcher standard scripts/flux2/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_flux2_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=64 \ + --network_alpha=32 \ + --target_name="to_q,to_k,to_v,ff.0,ff.2,ff_context.0,ff_context.2" \ + --use_peft_lora \ + --uniform_sampling +``` + +#### 3.5.2 不使用 DeepSpeed 与 FSDP 训练 + +**该方案并不被推荐,因为没有显存节约后端,容易造成显存不足**。这里仅提供训练 Shell 用于参考训练。 + +```sh +export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" scripts/flux2/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_flux2_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=64 \ + --network_alpha=32 \ + --target_name="to_q,to_k,to_v,ff.0,ff.2,ff_context.0,ff_context.2" \ + --use_peft_lora \ + --uniform_sampling +``` + +### 3.6 多机分布式训练 + +**适合场景**:超大规模数据集、需要更快的训练速度 + +#### 3.6.1 环境配置 + +假设有 2 台机器,每台 8 张 GPU: + +**机器 0(Master)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Master 机器 IP +export MASTER_PORT=10086 +export WORLD_SIZE=2 # 机器总数 +export NUM_PROCESS=16 # 总进程数 = 机器数 × 8 +export RANK=0 # 当前机器 rank(0 或 1) +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/flux2/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_flux2_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=64 \ + --network_alpha=32 \ + --target_name="to_q,to_k,to_v,ff.0,ff.2,ff_context.0,ff_context.2" \ + --use_peft_lora \ + --uniform_sampling +``` + +**机器 1(Worker)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # 与 Master 相同 +export MASTER_PORT=10086 +export WORLD_SIZE=2 +export NUM_PROCESS=16 +export RANK=1 # 注意这里是 1 +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +# 使用与机器 0 相同的 accelerate launch 命令 +``` + +#### 3.6.2 多机训练注意事项 + +- **网络要求**: + - 推荐 RDMA/InfiniBand(高性能) + - 无 RDMA 时添加环境变量: + ```bash + export NCCL_IB_DISABLE=1 + export NCCL_P2P_DISABLE=1 + ``` + +- **数据同步**:所有机器必须能够访问相同的数据路径(NFS/共享存储) + +--- + +## 四、推理测试 + +### 4.1 推理参数解析 + +**关键参数说明**: + +| 参数 | 说明 | 示例值 | +|------|------|-------| +| `GPU_memory_mode` | 显存管理模式,可选值见下表 | `model_cpu_offload` | +| `ulysses_degree` | Head 维度并行度,单卡时为 1 | 1 | +| `ring_degree` | Sequence 维度并行度,单卡时为 1 | 1 | +| `fsdp_dit` | 多卡推理时对 Transformer 使用 FSDP 节省显存 | `False` | +| `fsdp_text_encoder` | 多卡推理时对文本编码器使用 FSDP | `False` | +| `compile_dit` | 编译 Transformer 加速推理(固定分辨率下有效) | `False` | +| `model_name` | 模型路径 | `models/Diffusion_Transformer/FLUX.2-dev` | +| `sampler_name` | 采样器类型:`Flow`、`Flow_Unipc`、`Flow_DPM++` | `Flow` | +| `transformer_path` | 加载训练好的 Transformer 权重路径 | `None` | +| `vae_path` | 加载训练好的 VAE 权重路径 | `None` | +| `lora_path` | LoRA 权重路径 | `None` | +| `sample_size` | 生成图像分辨率 `[高度, 宽度]` | `[1344, 768]` | +| `weight_dtype` | 模型权重精度,不支持 bf16 的显卡使用 `torch.float16` | `torch.bfloat16` | +| `prompt` | 正向提示词,描述生成内容 | `"1girl, black_hair..."` | +| `negative_prompt` | 负向提示词,避免生成的内容 | `" "` | +| `guidance_scale` | 引导强度 | 4.0 | +| `seed` | 随机种子,用于复现结果 | 43 | +| `num_inference_steps` | 推理步数 | 50 | +| `lora_weight` | LoRA 权重强度 | 0.55 | +| `save_path` | 生成图像保存路径 | `samples/flux2-t2i` | + +**显存管理模式说明**: + +| 模式 | 说明 | 显存占用 | +|------|------|---------| +| `model_full_load` | 整个模型加载到 GPU | 最高 | +| `model_full_load_and_qfloat8` | 全量加载 + FP8 量化 | 高 | +| `model_cpu_offload` | 使用后将模型卸载到 CPU | 中等 | +| `model_cpu_offload_and_qfloat8` | CPU 卸载 + FP8 量化 | 中低 | +| `model_group_offload` | 层组在 CPU/CUDA 间切换 | 低 | +| `sequential_cpu_offload` | 逐层卸载(速度最慢) | 最低 | + +### 4.2 单卡推理 + +单卡推理运行如下命令: + +```bash +python examples/flux2/predict_t2i.py +``` + +根据需求修改编辑 `examples/flux2/predict_t2i.py`,初次推理重点关注如下参数,如果对其他参数感兴趣,请查看上方的推理参数解析。 + +```python +# 根据显卡显存选择 +GPU_memory_mode = "sequential_cpu_offload" +# 根据实际模型路径 +model_name = "models/Diffusion_Transformer/FLUX.2-dev" +# LoRA 权重路径,如 "output_dir_flux2_lora/checkpoint-xxx/lora_weights.safetensors" +lora_path = None +# LoRA 权重强度 +lora_weight = 0.55 +# 根据生成内容编写 +prompt = "1girl, black_hair, brown_eyes, earrings, freckles, grey_background, jewelry, lips, long_hair, looking_at_viewer, nose, piercing, realistic, red_lips, solo, upper_body" +# ... 
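+# 其他常用可调参数;以下数值为 4.1 推理参数解析表格中的示例值。
+# sample_size = [1344, 768]   # [高度, 宽度]
+# guidance_scale = 4.0
+# seed = 43
+# num_inference_steps = 50
+# save_path = "samples/flux2-t2i"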
+``` + +### 4.3 多卡并行推理 + +**适合场景**:高分辨率生成、加速推理 + +#### 安装并行推理依赖 + +```bash +pip install xfuser==0.4.2 yunchang==0.6.2 +``` + +#### 配置并行策略 + +编辑 `examples/flux2/predict_t2i.py`: + +```python +# 确保 ulysses_degree × ring_degree = GPU 数量 +# 例如使用 2 张 GPU: +ulysses_degree = 2 # Head 维度并行 +ring_degree = 1 # Sequence 维度并行 +``` + +**配置原则**: +- `ulysses_degree` 必须能整除模型的head数。 +- `ring_degree` 会在sequence上切分,影响通信开销,在head数能切分的时候尽量不用。 + +**示例配置**: + +| GPU 数量 | ulysses_degree | ring_degree | 说明 | +|---------|---------------|-------------|------| +| 1 | 1 | 1 | 单卡 | +| 4 | 4 | 1 | Head 并行 | +| 8 | 8 | 1 | Head 并行 | +| 8 | 4 | 2 | 混合并行 | + +#### 运行多卡推理 + +```bash +torchrun --nproc-per-node=2 examples/flux2/predict_t2i.py +``` + +## 五、更多资源 + +- **官方 GitHub**:https://github.com/aigc-apps/VideoX-Fun diff --git a/scripts/flux2/README_TRAIN_zh-CN.md b/scripts/flux2/README_TRAIN_zh-CN.md new file mode 100644 index 00000000..95dfa80f --- /dev/null +++ b/scripts/flux2/README_TRAIN_zh-CN.md @@ -0,0 +1,547 @@ +# FLUX.2 全量参数训练指南 + +本文档提供 FLUX.2 Diffusion Transformer 全量参数训练的完整流程,包括环境配置、数据准备、分布式训练和推理测试。 + +--- + +## 目录 +- [一、环境配置](#一环境配置) +- [二、数据准备](#二数据准备) + - [2.1 快速测试数据集](#21-快速测试数据集) + - [2.2 数据集结构](#22-数据集结构) + - [2.3 metadata.json 格式](#23-metadatajson-格式) + - [2.4 相对路径与绝对路径使用方案](#24-相对路径与绝对路径使用方案) +- [三、全量参数训练](#三全量参数训练) + - [3.1 下载预训练模型](#31-下载预训练模型) + - [3.2 快速开始(DeepSpeed-Zero-2)](#32-快速开始deepspeed-zero-2) + - [3.3 训练常用参数解析](#33-训练常用参数解析) + - [3.4 使用 FSDP 训练](#34-使用-fsdp-训练) + - [3.5 其他后端](#35-其他后端) + - [3.6 多机分布式训练](#36-多机分布式训练) +- [四、推理测试](#四推理测试) + - [4.1 推理参数解析](#41-推理参数解析) + - [4.2 单卡推理](#42-单卡推理) + - [4.3 多卡并行推理](#43-多卡并行推理) +- [五、更多资源](#五更多资源) + +--- + +## 一、环境配置 + +**方式 1:使用requirements.txt** + +```bash +pip install -r requirements.txt +``` + +**方式 2:手动安装依赖** + +```bash +pip install Pillow einops safetensors timm tomesd librosa "torch>=2.1.2" torchdiffeq torchsde decord datasets numpy scikit-image +pip install omegaconf SentencePiece imageio[ffmpeg] imageio[pyav] tensorboard beautifulsoup4 ftfy func_timeout onnxruntime +pip install "peft>=0.17.0" "accelerate>=0.25.0" "gradio>=3.41.2" "diffusers>=0.30.1" "transformers>=4.46.2" +pip install yunchang xfuser modelscope openpyxl deepspeed==0.17.0 numpy==1.26.4 +pip uninstall opencv-python opencv-contrib-python opencv-python-headless -y +pip install opencv-python-headless +``` + +**方式 3:使用docker** + +使用docker的情况下,请保证机器中已经正确安装显卡驱动与CUDA环境,然后以此执行以下命令: + +``` +# pull image +docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun + +# enter image +docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun +``` + +--- + +## 二、数据准备 + +### 2.1 快速测试数据集 + +我们提供了一个测试的数据集,其中包含若干训练数据。 + +```bash +# 下载官方示例数据集 +modelscope download --dataset PAI/X-Fun-Images-Demo --local_dir ./datasets/X-Fun-Images-Demo +``` + +### 2.2 数据集结构 + +``` +📦 datasets/ +├── 📂 my_dataset/ +│ ├── 📂 train/ +│ │ ├── 📄 image001.jpg +│ │ ├── 📄 image002.png +│ │ └── 📄 ... 
+│ └── 📄 metadata.json +``` + +### 2.3 metadata.json 格式 + +**相对路径格式**(示例格式): +```json +[ + { + "file_path": "train/image001.jpg", + "text": "A beautiful sunset over the ocean, golden hour lighting", + "width": 1024, + "height": 1024 + }, + { + "file_path": "train/image002.png", + "text": "Portrait of a young woman, studio lighting, high quality", + "width": 1328, + "height": 1328 + } +] +``` + +**绝对路径格式**: +```json +[ + { + "file_path": "/mnt/data/images/sunset.jpg", + "text": "A beautiful sunset over the ocean", + "width": 1024, + "height": 1024 + } +] +``` + +**关键字段说明**: +- `file_path`:图片路径(相对或绝对路径) +- `text`:图片描述(英文提示词) +- `width` / `height`:图片宽高(**最好提供**,用于分桶训练,如果不提供则自动在训练时读取,当数据存储在如oss这样的速度较慢的系统上时,可能会影响训练速度)。 + - 可以使用`scripts/process_json_add_width_and_height.py`文件对无width与height字段的json进行提取,支持处理图片与视频。 + - 使用方案为`python scripts/process_json_add_width_and_height.py --input_file datasets/X-Fun-Images-Demo/metadata.json --output_file datasets/X-Fun-Images-Demo/metadata_add_width_height.json`。 + +### 2.4 相对路径与绝对路径使用方案 + +**相对路径**: + +如果数据的路径为相对路径,则在训练脚本中设置: + +```bash +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +``` + +**绝对路径**: + +如果数据的路径为绝对路径,则在训练脚本中设置: + +```bash +export DATASET_NAME="" +export DATASET_META_NAME="/mnt/data/metadata.json" +``` + +> 💡 **建议**:如果数据集较小且存储在本地,推荐使用相对路径;如果数据集存储在外部存储(如 NAS、OSS)或多个机器共享存储,推荐使用绝对路径。 + +--- + +## 三、全量参数训练 + +### 3.1 下载预训练模型 + +```bash +# 创建模型目录 +mkdir -p models/Diffusion_Transformer + +# 下载 FLUX.2 官方权重 +modelscope download --model black-forest-labs/FLUX.2-dev --local_dir models/Diffusion_Transformer/FLUX.2-dev +``` + +### 3.2 快速开始(DeepSpeed-Zero-2) + +如果按照 **2.1 快速测试数据集下载数据** 与 **3.1 下载预训练模型下载权重**后,直接复制快速开始的启动指令进行启动。 + +推荐使用DeepSpeed-Zero-2与FSDP方案进行训练。这里使用DeepSpeed-Zero-2为例配置shell文件。 + +本文中DeepSpeed-Zero-2与FSDP的差别在于是否对模型权重进行分片,**如果使用多卡且使用DeepSpeed-Zero-2的情况下显存不足**,可以切换使用FSDP进行训练。 + +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/flux2/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_flux2" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." 
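+
+# 提示:若训练中断,可在上方命令中追加 --resume_from_checkpoint="latest",
+# 从 output_dir 中最近一次保存的 checkpoint 恢复训练(参数含义见 3.3 节)。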
+``` + +### 3.3 训练常用参数解析 + +**关键参数说明**: + +| 参数 | 说明 | 示例值 | +|-----|------|-------| +| `--pretrained_model_name_or_path` | 预训练模型路径 | `models/Diffusion_Transformer/FLUX.2-dev` | +| `--train_data_dir` | 训练数据目录 | `datasets/internal_datasets/` | +| `--train_data_meta` | 训练数据元文件 | `datasets/internal_datasets/metadata.json` | +| `--train_batch_size` | 每批次样本数 | 1 | +| `--image_sample_size` | 最大训练分辨率,代码会自动分桶 | 1328 | +| `--gradient_accumulation_steps` | 梯度累积步数(等效增大 batch) | 1 | +| `--dataloader_num_workers` | DataLoader 子进程数 | 8 | +| `--num_train_epochs` | 训练 epoch 数 | 100 | +| `--checkpointing_steps` | 每 N 步保存 checkpoint | 50 | +| `--learning_rate` | 初始学习率 | 2e-05 | +| `--lr_scheduler` | 学习率调度器 | `constant_with_warmup` | +| `--lr_warmup_steps` | 学习率预热步数 | 100 | +| `--seed` | 随机种子 | 42 | +| `--output_dir` | 输出目录 | `output_dir_flux2` | +| `--gradient_checkpointing` | 激活重计算 | - | +| `--mixed_precision` | 混合精度:`fp16/bf16` | `bf16` | +| `--adam_weight_decay` | AdamW 权重衰减 | 3e-2 | +| `--adam_epsilon` | AdamW epsilon 值 | 1e-10 | +| `--vae_mini_batch` | VAE 编码时的迷你批次大小 | 1 | +| `--max_grad_norm` | 梯度裁剪阈值 | 0.05 | +| `--enable_bucket` | 启用分桶训练,不裁剪图片,按分辨率分组训练整个图像 | - | +| `--random_hw_adapt` | 自动缩放图片到 `[512, image_sample_size]` 范围内的随机尺寸 | - | +| `--resume_from_checkpoint` | 恢复训练路径,使用 `"latest"` 自动选择最新 checkpoint | None | +| `--uniform_sampling` | 均匀采样 timestep | - | +| `--trainable_modules` | 可训练模块(`"."` 表示所有模块) | `"."` | + + +### 3.4 使用 FSDP 训练 + +**如果使用多卡且使用DeepSpeed-Zero-2的情况下显存不足**,可以切换使用FSDP进行训练。 + +```sh +export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap Flux2SingleTransformerBlock,Flux2TransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/flux2/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_flux2" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` + +### 3.5 其他后端 + +#### 3.5.1 使用DeepSpeed-Zero-3进行训练 + +目前不太推荐使用 DeepSpeed Zero-3。在本仓库中,使用 FSDP 出错更少且更稳定。 + +DeepSpeed Zero-3: + +训练完成后,您可以使用以下命令获取最终模型: + +```sh +python scripts/zero_to_bf16.py output_dir/checkpoint-{our-num-steps} output_dir/checkpoint-{your-num-steps}-outputs --max_shard_size 80GB --safe_serialization +``` + +执行命令为: +```sh +export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag true --use_deepspeed --deepspeed_config_file config/zero_stage3_config.json --deepspeed_multinode_launcher standard scripts/flux2/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_flux2" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` + +#### 3.5.2 不使用 DeepSpeed 与 FSDP 训练 + +**该方案并不被推荐,因为没有显存节约后端,容易造成显存不足**。这里仅提供训练Shell用于参考训练。 + +```sh +export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" scripts/flux2/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_flux2" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` + +### 3.6 多机分布式训练 + +**适合场景**:超大规模数据集、需要更快的训练速度 + +#### 3.6.1 环境配置 + +假设有 2 台机器,每台 8 张 GPU: + +**机器 0(Master)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Master 机器 IP +export MASTER_PORT=10086 +export WORLD_SIZE=2 # 机器总数 +export NUM_PROCESS=16 # 总进程数 = 机器数 × 8 +export RANK=0 # 当前机器 rank(0 或 1) +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
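+# (即:仅当节点间没有 RDMA/InfiniBand 高速网络时,才需要取消下面两行的注释,详见 3.6.2 节)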
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/flux2/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_flux2" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` + +**机器 1(Worker)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/FLUX.2-dev" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # 与 Master 相同 +export MASTER_PORT=10086 +export WORLD_SIZE=2 +export NUM_PROCESS=16 +export RANK=1 # 注意这里是 1 +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +# 使用与机器 0 相同的 accelerate launch 命令 +``` + +#### 3.6.2 多机训练注意事项 + +- **网络要求**: + - 推荐 RDMA/InfiniBand(高性能) + - 无 RDMA 时添加环境变量: + ```bash + export NCCL_IB_DISABLE=1 + export NCCL_P2P_DISABLE=1 + ``` + +- **数据同步**:所有机器必须能够访问相同的数据路径(NFS/共享存储) + +## 四、推理测试 + +### 4.1 推理参数解析 + +**关键参数说明**: + +| 参数 | 说明 | 示例值 | +|------|------|-------| +| `GPU_memory_mode` | 显存管理模式,可选值见下表 | `model_cpu_offload` | +| `ulysses_degree` | Head 维度并行度,单卡时为 1 | 1 | +| `ring_degree` | Sequence 维度并行度,单卡时为 1 | 1 | +| `fsdp_dit` | 多卡推理时对 Transformer 使用 FSDP 节省显存 | `False` | +| `fsdp_text_encoder` | 多卡推理时对文本编码器使用 FSDP | `False` | +| `compile_dit` | 编译 Transformer 加速推理(固定分辨率下有效) | `False` | +| `model_name` | 模型路径 | `models/Diffusion_Transformer/FLUX.2-dev` | +| `sampler_name` | 采样器类型:`Flow`、`Flow_Unipc`、`Flow_DPM++` | `Flow` | +| `transformer_path` | 加载训练好的 Transformer 权重路径 | `None` | +| `vae_path` | 加载训练好的 VAE 权重路径 | `None` | +| `lora_path` | LoRA 权重路径 | `None` | +| `sample_size` | 生成图像分辨率 `[高度, 宽度]` | `[1344, 768]` | +| `weight_dtype` | 模型权重精度,不支持 bf16 的显卡使用 `torch.float16` | `torch.bfloat16` | +| `prompt` | 正向提示词,描述生成内容 | `"1girl, black_hair..."` | +| `negative_prompt` | 负向提示词,避免生成的内容 | `" "` | +| `guidance_scale` | 引导强度 | 4.0 | +| `seed` | 随机种子,用于复现结果 | 43 | +| `num_inference_steps` | 推理步数 | 50 | +| `lora_weight` | LoRA 权重强度 | 0.55 | +| `save_path` | 生成图像保存路径 | `samples/flux2-t2i` | + +**显存管理模式说明**: + +| 模式 | 说明 | 显存占用 | +|------|------|---------| +| `model_full_load` | 整个模型加载到 GPU | 最高 | +| `model_full_load_and_qfloat8` | 全量加载 + FP8 量化 | 高 | +| `model_cpu_offload` | 使用后将模型卸载到 CPU | 中等 | +| `model_cpu_offload_and_qfloat8` | CPU 卸载 + FP8 量化 | 中低 | +| `model_group_offload` | 层组在 CPU/CUDA 间切换 | 低 | +| `sequential_cpu_offload` | 逐层卸载(速度最慢) | 最低 | + +### 4.2 单卡推理 + +单卡推理运行如下命令: + +```bash +python examples/flux2/predict_t2i.py +``` + +根据需求修改编辑 `examples/flux2/predict_t2i.py`,初次推理重点关注如下参数,如果对其他参数感兴趣,请查看上方的推理参数解析。 + +```python +# 根据显卡显存选择 
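+# 可选值包括 model_full_load、model_cpu_offload、model_cpu_offload_and_qfloat8、sequential_cpu_offload 等,
+# 各模式的显存占用差异详见上方“显存管理模式说明”表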
+GPU_memory_mode = "sequential_cpu_offload" +# 根据实际模型路径 +model_name = "models/Diffusion_Transformer/FLUX.2-dev" +# 训练好的权重路径,如 "output_dir_flux2/checkpoint-xxx/diffusion_pytorch_model.safetensors" +transformer_path = None +# 根据生成内容编写 +prompt = "1girl, black_hair, brown_eyes, earrings, freckles, grey_background, jewelry, lips, long_hair, looking_at_viewer, nose, piercing, realistic, red_lips, solo, upper_body" +# ... +``` + +### 4.3 多卡并行推理 + +**适合场景**:高分辨率生成、加速推理 + +#### 安装并行推理依赖 + +```bash +pip install xfuser==0.4.2 yunchang==0.6.2 +``` + +#### 配置并行策略 + +编辑 `examples/flux2/predict_t2i.py`: + +```python +# 确保 ulysses_degree × ring_degree = GPU 数量 +# 例如使用 2 张 GPU: +ulysses_degree = 2 # Head 维度并行 +ring_degree = 1 # Sequence 维度并行 +``` + +**配置原则**: +- `ulysses_degree` 必须能整除模型的head数。 +- `ring_degree` 会在sequence上切分,影响通信开销,在head数能切分的时候尽量不用。 + +**示例配置**: + +| GPU 数量 | ulysses_degree | ring_degree | 说明 | +|---------|---------------|-------------|------| +| 1 | 1 | 1 | 单卡 | +| 4 | 4 | 1 | Head 并行 | +| 8 | 8 | 1 | Head 并行 | +| 8 | 4 | 2 | 混合并行 | + +#### 运行多卡推理 + +```bash +torchrun --nproc-per-node=2 examples/flux2/predict_t2i.py +``` + +## 五、更多资源 + +- **官方 GitHub**:https://github.com/aigc-apps/VideoX-Fun diff --git a/scripts/ltx2.3/README_TRAIN.md b/scripts/ltx2.3/README_TRAIN.md new file mode 100644 index 00000000..fb8e1f72 --- /dev/null +++ b/scripts/ltx2.3/README_TRAIN.md @@ -0,0 +1,183 @@ +## Training Code + +The default training commands for the different versions are as follows: + +We can choose whether to use DeepSpeed and FSDP in LTX2.3, which can save a lot of video memory. + +The metadata.json is a little different from normal json in VideoX-Fun, you need to add a audio_path. + +```json +[ + { + "file_path": "train/00000001.mp4", + "audio_path": "wav/00000001.wav", + "text": "A group of young men in suits and sunglasses are walking down a city street.", + "type": "video" + }, + ..... +] +``` + +Some parameters in the sh file can be confusing, and they are explained in this document: + +- `enable_bucket` is used to enable bucket training. When enabled, the model does not crop the videos at the center, but instead, it trains the videos after grouping them into buckets based on resolution. +- `random_frame_crop` is used for random cropping on video frames to simulate videos with different frame counts. +- `random_hw_adapt` is used to enable automatic height and width scaling for videos. When `random_hw_adapt` is enabled, for training videos, the height and width will be set to `video_sample_size` as the maximum and `512` as the minimum. + - For example, when `random_hw_adapt` is enabled, with `video_sample_n_frames=49`, `video_sample_size=768`, the resolution of video inputs for training is `512x512x49`, `768x768x49`. +- `training_with_video_token_length` specifies training the model according to token length. For training videos, the height and width will be set to `video_sample_size` as the maximum and `256` as the minimum. + - For example, when `training_with_video_token_length` is enabled, with `video_sample_n_frames=49`, `token_sample_size=512`, `video_sample_size=768`, the resolution of video inputs for training is `256x256x49`, `512x512x49`, `768x768x21`. + - The token length for a video with dimensions 512x512 and 49 frames is 13,312. We need to set the `token_sample_size = 512`. + - At 512x512 resolution, the number of video frames is 49 (~= 512 * 512 * 49 / 512 / 512). + - At 768x768 resolution, the number of video frames is 21 (~= 512 * 512 * 49 / 768 / 768). 
+ - At 1024x1024 resolution, the number of video frames is 9 (~= 512 * 512 * 49 / 1024 / 1024). + - These resolutions combined with their corresponding lengths allow the model to generate videos of different sizes. +- `resume_from_checkpoint` is used to set the training should be resumed from a previous checkpoint. Use a path or `"latest"` to automatically select the last available checkpoint. + +When train model with multi machines, please set the params as follows: +```sh +export MASTER_ADDR="your master address" +export MASTER_PORT=10086 +export WORLD_SIZE=1 # The number of machines +export NUM_PROCESS=8 # The number of processes, such as WORLD_SIZE * 8 +export RANK=0 # The rank of this machine + +accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK scripts/xxx/xxx.py +``` + +LTX2.3 without deepspeed: + +```sh +export MODEL_NAME="models/Diffusion_Transformer/LTX-2.3-Diffusers" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" scripts/ltx2.3/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --image_sample_size=640 \ + --video_sample_size=640 \ + --token_sample_size=640 \ + --video_sample_stride=1 \ + --video_sample_n_frames=81 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_ltx2" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --low_vram \ + --trainable_modules "." +``` + +LTX2.3 with Deepspeed Zero-2: + +```sh +export MODEL_NAME="models/Diffusion_Transformer/LTX-2.3-Diffusers" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
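+# (That is, only uncomment the two exports below when your nodes are not connected via RDMA/InfiniBand.)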
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/ltx2.3/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --image_sample_size=640 \ + --video_sample_size=640 \ + --token_sample_size=640 \ + --video_sample_stride=1 \ + --video_sample_n_frames=81 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_ltx2" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --low_vram \ + --trainable_modules "." +``` + +LTX2.3 with FSDP: + +```sh +export MODEL_NAME="models/Diffusion_Transformer/LTX-2.3-Diffusers" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP \ + --fsdp_transformer_layer_cls_to_wrap=LTX2VideoTransformerBlock --fsdp_sharding_strategy "FULL_SHARD" \ + --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" \ + --fsdp_cpu_ram_efficient_loading False scripts/ltx2.3/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --image_sample_size=640 \ + --video_sample_size=640 \ + --token_sample_size=640 \ + --video_sample_stride=1 \ + --video_sample_n_frames=81 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_ltx2" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --low_vram \ + --trainable_modules "." +``` \ No newline at end of file diff --git a/scripts/ltx2.3/README_TRAIN_LORA.md b/scripts/ltx2.3/README_TRAIN_LORA.md new file mode 100644 index 00000000..52f57a39 --- /dev/null +++ b/scripts/ltx2.3/README_TRAIN_LORA.md @@ -0,0 +1,190 @@ +## Lora Training Code + +The default training commands for the different versions are as follows: + +We can choose whether to use DeepSpeed and FSDP in LTX2.3, which can save a lot of video memory. + +The metadata.json is a little different from normal json in VideoX-Fun, you need to add a audio_path. + +```json +[ + { + "file_path": "train/00000001.mp4", + "audio_path": "wav/00000001.wav", + "text": "A group of young men in suits and sunglasses are walking down a city street.", + "type": "video" + }, + ..... 
+] +``` + +Some parameters in the sh file can be confusing, and they are explained in this document: + +- `enable_bucket` is used to enable bucket training. When enabled, the model does not crop the videos at the center, but instead, it trains the videos after grouping them into buckets based on resolution. +- `random_frame_crop` is used for random cropping on video frames to simulate videos with different frame counts. +- `random_hw_adapt` is used to enable automatic height and width scaling for videos. When `random_hw_adapt` is enabled, for training videos, the height and width will be set to `video_sample_size` as the maximum and `512` as the minimum. + - For example, when `random_hw_adapt` is enabled, with `video_sample_n_frames=49`, `video_sample_size=768`, the resolution of video inputs for training is `512x512x49`, `768x768x49`. +- `training_with_video_token_length` specifies training the model according to token length. For training videos, the height and width will be set to `video_sample_size` as the maximum and `256` as the minimum. + - For example, when `training_with_video_token_length` is enabled, with `video_sample_n_frames=49`, `token_sample_size=512`, `video_sample_size=768`, the resolution of video inputs for training is `256x256x49`, `512x512x49`, `768x768x21`. + - The token length for a video with dimensions 512x512 and 49 frames is 13,312. We need to set the `token_sample_size = 512`. + - At 512x512 resolution, the number of video frames is 49 (~= 512 * 512 * 49 / 512 / 512). + - At 768x768 resolution, the number of video frames is 21 (~= 512 * 512 * 49 / 768 / 768). + - At 1024x1024 resolution, the number of video frames is 9 (~= 512 * 512 * 49 / 1024 / 1024). + - These resolutions combined with their corresponding lengths allow the model to generate videos of different sizes. +- `resume_from_checkpoint` is used to set the training should be resumed from a previous checkpoint. Use a path or `"latest"` to automatically select the last available checkpoint. +- `target_name` represents the components/modules to which LoRA will be applied, separated by commas. +- `use_peft_lora` indicates whether to use the PEFT module for adding LoRA. Using this module will be more memory-efficient. +- `rank` means the dimension of the LoRA update matrices. +- `network_alpha` means the scale of the LoRA update matrices. + +When train model with multi machines, please set the params as follows: +```sh +export MASTER_ADDR="your master address" +export MASTER_PORT=10086 +export WORLD_SIZE=1 # The number of machines +export NUM_PROCESS=8 # The number of processes, such as WORLD_SIZE * 8 +export RANK=0 # The rank of this machine + +accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK scripts/xxx/xxx.py +``` + +LTX2.3 without deepspeed: + +```sh +export MODEL_NAME="models/Diffusion_Transformer/LTX-2.3-Diffusers" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
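+# Note: this variant runs without DeepSpeed/FSDP and therefore uses the most GPU memory;
+# if memory is insufficient, prefer the DeepSpeed Zero-2 or FSDP commands below.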
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" scripts/ltx2.3/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --image_sample_size=640 \ + --video_sample_size=640 \ + --token_sample_size=640 \ + --video_sample_stride=1 \ + --video_sample_n_frames=81 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_ltx2_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --rank=64 \ + --network_alpha=32 \ + --target_name="to_q,to_k,to_v,ff.0,ff.2,audio_ff.0,audio_ff.2" \ + --use_peft_lora \ + --low_vram +``` + +LTX2.3 with Deepspeed Zero-2: + +```sh +export MODEL_NAME="models/Diffusion_Transformer/LTX-2.3-Diffusers" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/ltx2.3/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --image_sample_size=640 \ + --video_sample_size=640 \ + --token_sample_size=640 \ + --video_sample_stride=1 \ + --video_sample_n_frames=81 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_ltx2_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --rank=64 \ + --network_alpha=32 \ + --target_name="to_q,to_k,to_v,ff.0,ff.2,audio_ff.0,audio_ff.2" \ + --use_peft_lora \ + --low_vram +``` + +LTX2.3 with FSDP: + +```sh +export MODEL_NAME="models/Diffusion_Transformer/LTX-2.3-Diffusers" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP \ + --fsdp_transformer_layer_cls_to_wrap=LTX2VideoTransformerBlock --fsdp_sharding_strategy "FULL_SHARD" \ + --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" \ + --fsdp_cpu_ram_efficient_loading False scripts/ltx2.3/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --image_sample_size=640 \ + --video_sample_size=640 \ + --token_sample_size=640 \ + --video_sample_stride=1 \ + --video_sample_n_frames=81 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_ltx2_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --rank=64 \ + --network_alpha=32 \ + --target_name="to_q,to_k,to_v,ff.0,ff.2,audio_ff.0,audio_ff.2" \ + --use_peft_lora \ + --low_vram +``` \ No newline at end of file diff --git a/scripts/ltx2.3/train.py b/scripts/ltx2.3/train.py new file mode 100644 index 00000000..021e7ee9 --- /dev/null +++ b/scripts/ltx2.3/train.py @@ -0,0 +1,2143 @@ +"""Modified from https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py +""" +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and + +import argparse +import gc +import logging +import math +import os +import pickle +import random +import shutil +import sys + +import accelerate +import diffusers +import numpy as np +import torch +import torch.nn.functional as F +import torchaudio +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.state import AcceleratorState +from accelerate.utils import ProjectConfiguration, set_seed +from diffusers import DDIMScheduler, FlowMatchEulerDiscreteScheduler +from diffusers.optimization import get_scheduler +from diffusers.training_utils import (EMAModel, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3) +from diffusers.utils import check_min_version, deprecate, is_wandb_available +from diffusers.utils.torch_utils import is_compiled_module +from einops import rearrange +from packaging import version +from PIL import Image +from torch.utils.data import RandomSampler +from torch.utils.tensorboard import SummaryWriter +from torchvision import transforms +from tqdm.auto import tqdm +from transformers.utils import ContextManagers + +import datasets + +current_file_path = os.path.abspath(__file__) +project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))] +for project_root in project_roots: + sys.path.insert(0, project_root) if project_root not in sys.path else None + +from videox_fun.data.bucket_sampler import (ASPECT_RATIO_512, + ASPECT_RATIO_RANDOM_CROP_512, + ASPECT_RATIO_RANDOM_CROP_PROB, + AspectRatioBatchImageVideoSampler, + RandomSampler, get_closest_ratio) +from videox_fun.data.dataset_image_video import (ImageVideoDataset, + ImageVideoSampler, + get_random_mask) +from videox_fun.data.dataset_video import VideoSpeechDataset +from videox_fun.models import (AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, + Gemma3ForConditionalGeneration, Gemma3Processor, + LTX2TextConnectors, + LTX2VideoTransformer3DModel, LTX2VocoderWithBWE) +from videox_fun.pipeline import LTX2Pipeline +from videox_fun.utils.discrete_sampler import DiscreteSampling +from videox_fun.utils.utils import (calculate_dimensions, get_image_latent, + get_image_to_video_latent, + save_videos_grid, + save_videos_with_audio_grid) + +if is_wandb_available(): + import wandb + +def filter_kwargs(cls, kwargs): + import inspect + sig = inspect.signature(cls.__init__) + valid_params = set(sig.parameters.keys()) - {'self', 'cls'} + filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} + return filtered_kwargs + +def linear_decay(initial_value, final_value, total_steps, current_step): + if current_step >= total_steps: + return final_value + current_step = max(0, current_step) + step_size = (final_value - initial_value) / total_steps + current_value = initial_value + step_size * current_step + return current_value + +def generate_timestep_with_lognorm(low, high, shape, device="cpu", generator=None): + u = torch.normal(mean=0.0, std=1.0, size=shape, device=device, generator=generator) + t = 1 / (1 + torch.exp(-u)) * (high - low) + low + return torch.clip(t.to(torch.int32), low, high - 1) + +# LTX2 helper functions for packing text embeddings and latents +def _pack_text_embeds( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device, + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, +) -> 
torch.Tensor: + """Packs and normalizes text encoder hidden states, respecting padding.""" + batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + # Create padding mask + token_indices = torch.arange(seq_len, device=device).unsqueeze(0) + if padding_side == "right": + mask = token_indices < sequence_lengths[:, None] + elif padding_side == "left": + start_indices = seq_len - sequence_lengths[:, None] + mask = token_indices >= start_indices + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = mask[:, :, None, None] + + # Compute masked mean + masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) + num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) + masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) + + # Compute min/max + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + # Normalization + normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) + normalized_hidden_states = normalized_hidden_states * scale_factor + + # Pack the hidden states to 3D tensor + normalized_hidden_states = normalized_hidden_states.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) + normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) + return normalized_hidden_states + +def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + """Packs latents [B, C, F, H, W] into token sequence [B, S, D].""" + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + +def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 +) -> torch.Tensor: + """Unpacks token sequence [B, S, D] back to latents [B, C, F, H, W].""" + batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + +def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 +) -> torch.Tensor: + """Normalizes latents across the channel dimension [B, C, F, H, W].""" + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + +def _pack_audio_latents( + latents: torch.Tensor, patch_size: int | None = None, patch_size_t: int | None = None +) -> torch.Tensor: + """Packs audio latents [B, C, L, M] into token sequence.""" + if patch_size is not None and patch_size_t is not None: + batch_size, 
num_channels, latent_length, latent_mel_bins = latents.shape + post_patch_latent_length = latent_length / patch_size_t + post_patch_mel_bins = latent_mel_bins / patch_size + latents = latents.reshape( + batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size + ) + latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2) + else: + latents = latents.transpose(1, 2).flatten(2, 3) + return latents + +def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + """Normalizes audio latents.""" + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents - latents_mean) / latents_std + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.18.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + +def log_validation(vae, audio_vae, text_encoder, tokenizer, processor, connectors, vocoder, transformer3d, args, accelerator, weight_dtype, global_step): + try: + is_deepspeed = type(transformer3d).__name__ == 'DeepSpeedEngine' + if is_deepspeed: + origin_config = transformer3d.config + transformer3d.config = accelerator.unwrap_model(transformer3d).config + with torch.no_grad(), torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=accelerator.device): + logger.info("Running validation... ") + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler" + ) + + pipeline = LTX2Pipeline( + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + processor=processor, + connectors=connectors, + transformer=accelerator.unwrap_model(transformer3d) if type(transformer3d).__name__ == 'DistributedDataParallel' else transformer3d, + vocoder=vocoder, + scheduler=scheduler, + ) + pipeline = pipeline.to(accelerator.device) + + if args.seed is None: + generator = None + else: + rank_seed = args.seed + accelerator.process_index + generator = torch.Generator(device=accelerator.device).manual_seed(rank_seed) + logger.info(f"Rank {accelerator.process_index} using seed: {rank_seed}") + + for i in range(len(args.validation_prompts)): + output = pipeline( + args.validation_prompts[i], + negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted, static, low quality, artifacts", + height = args.video_sample_size, + width = args.video_sample_size, + num_frames = args.video_sample_n_frames, + num_inference_steps = 25, + guidance_scale = 3.0, + audio_guidance_scale = 7.0, + stg_scale = 1.0, + audio_stg_scale = 1.0, + modality_scale = 3.0, + audio_modality_scale = 3.0, + guidance_rescale = 0.7, + audio_guidance_rescale = 0.7, + spatio_temporal_guidance_blocks = [28], + generator = generator, + ) + sample = output.videos + audio = output.audio + os.makedirs(os.path.join(args.output_dir, "sample"), exist_ok=True) + sr = getattr(pipeline.vocoder.config, "output_sampling_rate", 24000) + save_videos_with_audio_grid( + sample, + audio, + os.path.join( + args.output_dir, + f"sample/sample-{global_step}-rank{accelerator.process_index}-image-{i}.mp4" + ), + fps=24, + audio_sample_rate=sr, + ) + + del pipeline + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + if not args.enable_text_encoder_in_dataloader: + text_encoder.to(accelerator.device if not args.low_vram else 
"cpu", dtype=weight_dtype) + if is_deepspeed: + transformer3d.config = origin_config + except Exception as e: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + print(f"Eval error on rank {accelerator.process_index} with info {e}") + vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + if not args.enable_text_encoder_in_dataloader: + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1." + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. " + ), + ) + parser.add_argument( + "--train_data_meta", + type=str, + default=None, + help=( + "A csv containing the training data. " + ), + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--validation_prompts", + type=str, + default=None, + nargs="+", + help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument( + "--validation_paths", + type=str, + default=None, + nargs="+", + help=("A set of control videos evaluated every `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--use_came", + action="store_true", + help="whether to use came", + ) + parser.add_argument( + "--multi_stream", + action="store_true", + help="whether to use cuda multi-stream", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--vae_mini_batch", type=int, default=32, help="mini batch size for vae." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_model_info", action="store_true", help="Whether or not to report more info about model (such as norm, grad)." + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument( + "--validation_epochs", + type=int, + default=5, + help="Run validation every X epochs.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=2000, + help="Run validation every X steps.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + parser.add_argument( + "--snr_loss", action="store_true", help="Whether or not to use snr_loss." + ) + parser.add_argument( + "--uniform_sampling", action="store_true", help="Whether or not to use uniform_sampling." + ) + parser.add_argument( + "--enable_text_encoder_in_dataloader", action="store_true", help="Whether or not to use text encoder in dataloader." + ) + parser.add_argument( + "--enable_bucket", action="store_true", help="Whether enable bucket sample in datasets." + ) + parser.add_argument( + "--random_ratio_crop", action="store_true", help="Whether enable random ratio crop sample in datasets." + ) + parser.add_argument( + "--random_frame_crop", action="store_true", help="Whether enable random frame crop sample in datasets." + ) + parser.add_argument( + "--random_hw_adapt", action="store_true", help="Whether enable random adapt height and width in datasets." 
+ ) + parser.add_argument( + "--training_with_video_token_length", action="store_true", help="The training stage of the model in training.", + ) + parser.add_argument( + "--auto_tile_batch_size", action="store_true", help="Whether to auto tile batch size.", + ) + parser.add_argument( + "--motion_sub_loss", action="store_true", help="Whether enable motion sub loss." + ) + parser.add_argument( + "--motion_sub_loss_ratio", type=float, default=0.25, help="The ratio of motion sub loss." + ) + parser.add_argument( + "--train_sampling_steps", + type=int, + default=1000, + help="Run train_sampling_steps.", + ) + parser.add_argument( + "--keep_all_node_same_token_length", + action="store_true", + help="Reference of the length token.", + ) + parser.add_argument( + "--token_sample_size", + type=int, + default=512, + help="Sample size of the token.", + ) + parser.add_argument( + "--video_sample_size", + type=int, + default=512, + help="Sample size of the video.", + ) + parser.add_argument( + "--image_sample_size", + type=int, + default=512, + help="Sample size of the image.", + ) + parser.add_argument( + "--fix_sample_size", + nargs=2, type=int, default=None, + help="Fix Sample size [height, width] when using bucket and collate_fn." + ) + parser.add_argument( + "--video_sample_stride", + type=int, + default=4, + help="Sample stride of the video.", + ) + parser.add_argument( + "--video_sample_n_frames", + type=int, + default=17, + help="Num frame of video.", + ) + parser.add_argument( + "--video_repeat", + type=int, + default=0, + help="Num of repeat video.", + ) + parser.add_argument( + "--transformer_path", + type=str, + default=None, + help=("If you want to load the weight from other transformers, input its path."), + ) + parser.add_argument( + "--vae_path", + type=str, + default=None, + help=("If you want to load the weight from other vaes, input its path."), + ) + + parser.add_argument( + '--trainable_modules', + nargs='+', + help='Enter a list of trainable modules' + ) + parser.add_argument( + '--trainable_modules_low_learning_rate', + nargs='+', + default=[], + help='Enter a list of trainable modules with lower learning rate' + ) + parser.add_argument( + '--tokenizer_max_length', + type=int, + default=512, + help='Max length of tokenizer' + ) + parser.add_argument( + "--use_deepspeed", action="store_true", help="Whether or not to use deepspeed." + ) + parser.add_argument( + "--use_fsdp", action="store_true", help="Whether or not to use fsdp." + ) + parser.add_argument( + "--low_vram", action="store_true", help="Whether enable low_vram mode." + ) + parser.add_argument( + "--i2v_ratio", + type=float, + default=0.5, + help=( + 'Ratio of I2V samples in training. 0.0 = pure T2V, 1.0 = pure I2V, ' + '0.5 = 50%% T2V + 50%% I2V (default).' + ), + ) + parser.add_argument( + "--i2v_noise_scale", + type=float, + default=0.0, + help=( + 'Noise scale for I2V first frame conditioning. ' + '0.0 means first frame is kept clean (default). ' + 'Higher values add slight noise to the condition frame.' + ), + ) + parser.add_argument( + "--abnormal_norm_clip_start", + type=int, + default=1000, + help=( + 'When do we start doing additional processing on abnormal gradients. ' + ), + ) + parser.add_argument( + "--initial_grad_norm_ratio", + type=int, + default=5, + help=( + 'The initial gradient is relative to the multiple of the max_grad_norm. 
' + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # default to using the same revision for the non-ema model if not specified + if args.non_ema_revision is None: + args.non_ema_revision = args.revision + + return args + + +def main(): + args = parse_args() + + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if args.non_ema_revision is not None: + deprecate( + "non_ema_revision!=None", + "0.15.0", + message=( + "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" + " use `--variant=non_ema` instead." + ), + ) + logging_dir = os.path.join(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + deepspeed_plugin = accelerator.state.deepspeed_plugin if hasattr(accelerator.state, "deepspeed_plugin") else None + fsdp_plugin = accelerator.state.fsdp_plugin if hasattr(accelerator.state, "fsdp_plugin") else None + if deepspeed_plugin is not None: + zero_stage = int(deepspeed_plugin.zero_stage) + fsdp_stage = 0 + print(f"Using DeepSpeed Zero stage: {zero_stage}") + + args.use_deepspeed = True + if zero_stage == 3: + print(f"Auto set save_state to True because zero_stage == 3") + args.save_state = True + elif fsdp_plugin is not None: + from torch.distributed.fsdp import ShardingStrategy + zero_stage = 0 + if fsdp_plugin.sharding_strategy is ShardingStrategy.FULL_SHARD: + fsdp_stage = 3 + elif fsdp_plugin.sharding_strategy is None: # The fsdp_plugin.sharding_strategy is None in FSDP 2. + fsdp_stage = 3 + elif fsdp_plugin.sharding_strategy is ShardingStrategy.SHARD_GRAD_OP: + fsdp_stage = 2 + else: + fsdp_stage = 0 + print(f"Using FSDP stage: {fsdp_stage}") + + args.use_fsdp = True + if fsdp_stage == 3: + print(f"Auto set save_state to True because fsdp_stage == 3") + args.save_state = True + else: + zero_stage = 0 + fsdp_stage = 0 + print("DeepSpeed is not enabled.") + + if accelerator.is_main_process: + writer = SummaryWriter(log_dir=logging_dir) + + # Make one log on every process with the configuration for debugging. 
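+ # (Illustrative note: every process logs at INFO with a timestamped format; the datasets /
+ # transformers / diffusers loggers are then quieted on non-main ranks below so multi-GPU
+ # runs do not flood the console with duplicate messages.)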
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + rng = np.random.default_rng(np.random.PCG64(args.seed + accelerator.process_index)) + torch_rng = torch.Generator(accelerator.device).manual_seed(args.seed + accelerator.process_index) + else: + rng = None + torch_rng = None + index_rng = np.random.default_rng(np.random.PCG64(43)) + print(f"Init rng with seed {args.seed + accelerator.process_index}. Process_index is {accelerator.process_index}") + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora transformer3d) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + args.mixed_precision = accelerator.mixed_precision + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + args.mixed_precision = accelerator.mixed_precision + + # Load scheduler, tokenizer and models. + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler" + ) + + # Get Processor + processor = Gemma3Processor.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="processor", + ) + + # Get Tokenizer from processor + tokenizer = processor.tokenizer + + def deepspeed_zero_init_disabled_context_manager(): + """ + returns either a context list that includes one that will disable zero.Init or an empty context list + """ + deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None + if deepspeed_plugin is None: + return [] + + return [deepspeed_plugin.zero3_init_context_manager(enable=False)] + + # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. + # For this to work properly all models must be run through `accelerate.prepare`. But accelerate + # will try to assign the same optimizer with the same weights to all models during + # `deepspeed.initialize`, which of course doesn't work. + # + # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 + # frozen models from being partitioned during `zero.Init` which gets called during + # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding + # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. 
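+ # Note: the explanation above appears to be inherited from the upstream diffusers
+ # text-to-image example (hence the mention of CLIPTextModel / AutoencoderKL /
+ # UNet2DConditionModel). In this script the frozen components are the Gemma3 text encoder,
+ # the LTX2 video and audio VAEs, the text connectors and the vocoder, while only the
+ # transformer modules selected via --trainable_modules are actually trained.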
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + # Get Text encoder + text_encoder = Gemma3ForConditionalGeneration.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + low_cpu_mem_usage=True, + torch_dtype=weight_dtype, + ) + text_encoder = text_encoder.eval() + # Get Vae + vae = AutoencoderKLLTX2Video.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + torch_dtype=weight_dtype, + ) + vae.eval() + audio_vae = AutoencoderKLLTX2Audio.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="audio_vae", + torch_dtype=weight_dtype, + ) + audio_vae.eval() + + # Connectors + connectors = LTX2TextConnectors.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="connectors", + torch_dtype=weight_dtype, + ) + # Vocoder + vocoder = LTX2VocoderWithBWE.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vocoder", + torch_dtype=weight_dtype, + ) + + # Get Transformer + transformer3d = LTX2VideoTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + low_cpu_mem_usage=True, + ).to(weight_dtype) + + # Freeze vae and text_encoder and set transformer3d to trainable + vae.requires_grad_(False) + audio_vae.requires_grad_(False) + connectors.requires_grad_(False) + vocoder.requires_grad_(False) + text_encoder.requires_grad_(False) + transformer3d.requires_grad_(False) + + if args.transformer_path is not None: + print(f"From checkpoint: {args.transformer_path}") + if args.transformer_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(args.transformer_path) + else: + state_dict = torch.load(args.transformer_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = transformer3d.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + assert len(u) == 0 + + if args.vae_path is not None: + print(f"From checkpoint: {args.vae_path}") + if args.vae_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(args.vae_path) + else: + state_dict = torch.load(args.vae_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = vae.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + assert len(u) == 0 + + # A good trainable modules is showed below now. + # For 3D Patch: trainable_modules = ['ff.net', 'pos_embed', 'attn2', 'proj_out', 'timepositionalencoding', 'h_position', 'w_position'] + # For 2D Patch: trainable_modules = ['ff.net', 'attn2', 'timepositionalencoding', 'h_position', 'w_position'] + transformer3d.train() + if accelerator.is_main_process: + accelerator.print( + f"Trainable modules '{args.trainable_modules}'." + ) + for name, param in transformer3d.named_parameters(): + for trainable_module_name in args.trainable_modules + args.trainable_modules_low_learning_rate: + if trainable_module_name in name: + param.requires_grad = True + break + + # Create EMA for the transformer3d. 
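+ # The EMA copy keeps an exponential moving average of the trainable weights: it is updated
+ # with ema_transformer3d.step(...) after every optimizer step, and temporarily swapped in
+ # via store() / copy_to() / restore() around validation and before the final save.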
+ if args.use_ema: + if zero_stage == 3: + raise NotImplementedError("FSDP does not support EMA.") + + ema_transformer3d = LTX2VideoTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="transformer" + ).to(weight_dtype) + + ema_transformer3d = EMAModel(ema_transformer3d.parameters(), model_cls=LTX2VideoTransformer3DModel, model_config=ema_transformer3d.config) + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + if fsdp_stage != 0 or zero_stage == 3: + def save_model_hook(models, weights, output_dir): + accelerate_state_dict = accelerator.get_state_dict(models[-1], unwrap=True) + if accelerator.is_main_process: + from safetensors.torch import save_file + + safetensor_save_path = os.path.join(output_dir, f"diffusion_pytorch_model.safetensors") + accelerate_state_dict = {k: v.to(dtype=weight_dtype) for k, v in accelerate_state_dict.items()} + save_file(accelerate_state_dict, safetensor_save_path, metadata={"format": "pt"}) + + with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file: + pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file) + + def load_model_hook(models, input_dir): + pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + loaded_number, _ = pickle.load(file) + batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0) + print(f"Load pkl from {pkl_path}. Get loaded_number = {loaded_number}.") + else: + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + if args.use_ema: + ema_transformer3d.save_pretrained(os.path.join(output_dir, "transformer_ema")) + + models[0].save_pretrained(os.path.join(output_dir, "transformer")) + if not args.use_deepspeed: + weights.pop() + + with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file: + pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file) + + def load_model_hook(models, input_dir): + if args.use_ema: + ema_path = os.path.join(input_dir, "transformer_ema") + _, ema_kwargs = LTX2VideoTransformer3DModel.load_config(ema_path, return_unused_kwargs=True) + load_model = LTX2VideoTransformer3DModel.from_pretrained( + input_dir, subfolder="transformer_ema", + ) + load_model = EMAModel(load_model.parameters(), model_cls=LTX2VideoTransformer3DModel, model_config=load_model.config) + load_model.load_state_dict(ema_kwargs) + + ema_transformer3d.load_state_dict(load_model.state_dict()) + ema_transformer3d.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = LTX2VideoTransformer3DModel.from_pretrained( + input_dir, subfolder="transformer" + ) + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + loaded_number, _ = pickle.load(file) + batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0) + 
print(f"Load pkl from {pkl_path}. Get loaded_number = {loaded_number}.") + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + transformer3d.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + elif args.use_came: + try: + from came_pytorch import CAME + except Exception: + raise ImportError( + "Please install came_pytorch to use CAME. You can do so by running `pip install came_pytorch`" + ) + + optimizer_cls = CAME + else: + optimizer_cls = torch.optim.AdamW + + trainable_params = list(filter(lambda p: p.requires_grad, transformer3d.parameters())) + trainable_params_optim = [ + {'params': [], 'lr': args.learning_rate}, + {'params': [], 'lr': args.learning_rate / 2}, + ] + in_already = [] + for name, param in transformer3d.named_parameters(): + high_lr_flag = False + if name in in_already: + continue + for trainable_module_name in args.trainable_modules: + if trainable_module_name in name: + in_already.append(name) + high_lr_flag = True + trainable_params_optim[0]['params'].append(param) + if accelerator.is_main_process: + print(f"Set {name} to lr : {args.learning_rate}") + break + if high_lr_flag: + continue + for trainable_module_name in args.trainable_modules_low_learning_rate: + if trainable_module_name in name: + in_already.append(name) + trainable_params_optim[1]['params'].append(param) + if accelerator.is_main_process: + print(f"Set {name} to lr : {args.learning_rate / 2}") + break + + if args.use_came: + optimizer = optimizer_cls( + trainable_params_optim, + lr=args.learning_rate, + # weight_decay=args.adam_weight_decay, + betas=(0.9, 0.999, 0.9999), + eps=(1e-30, 1e-16) + ) + else: + optimizer = optimizer_cls( + trainable_params_optim, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the training dataset + sample_n_frames_bucket_interval = vae.config.temporal_compression_ratio + + if args.fix_sample_size is not None and args.enable_bucket: + args.video_sample_size = max(max(args.fix_sample_size), args.video_sample_size) + args.image_sample_size = max(max(args.fix_sample_size), args.image_sample_size) + args.training_with_video_token_length = False + args.random_hw_adapt = False + + # Get the dataset + train_dataset = VideoSpeechDataset( + args.train_data_meta, args.train_data_dir, + video_sample_size=args.video_sample_size, video_sample_stride=args.video_sample_stride, video_sample_n_frames=args.video_sample_n_frames, + enable_bucket=args.enable_bucket, enable_inpaint=True, audio_sr=getattr(audio_vae.config, 'sample_rate', 16000), + ) + + # Pre-create mel spectrogram transform (avoid recreating per iteration) + audio_sampling_rate = getattr(audio_vae.config, 'sample_rate', 16000) + audio_hop_length = getattr(audio_vae.config, 'mel_hop_length', 160) + 
audio_mel_bins = getattr(audio_vae.config, 'mel_bins', 64) + audio_in_channels = getattr(audio_vae.config, 'in_channels', 2) + mel_transform = torchaudio.transforms.MelSpectrogram( + sample_rate=audio_sampling_rate, + n_fft=1024, + win_length=1024, + hop_length=audio_hop_length, + f_min=0.0, + f_max=audio_sampling_rate / 2.0, + n_mels=audio_mel_bins, + window_fn=torch.hann_window, + center=True, + pad_mode="reflect", + power=1.0, + mel_scale='slaney', + norm='slaney', + ) + + def worker_init_fn(_seed): + _seed = _seed * 256 + def _worker_init_fn(worker_id): + print(f"worker_init_fn with {_seed + worker_id}") + np.random.seed(_seed + worker_id) + random.seed(_seed + worker_id) + return _worker_init_fn + + if args.enable_bucket: + aspect_ratio_sample_size = {key : [x / 512 * args.video_sample_size for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} + batch_sampler_generator = torch.Generator().manual_seed(args.seed) + batch_sampler = AspectRatioBatchImageVideoSampler( + sampler=RandomSampler(train_dataset, generator=batch_sampler_generator), dataset=train_dataset.dataset, + batch_size=args.train_batch_size, train_folder = args.train_data_dir, drop_last=True, + aspect_ratios=aspect_ratio_sample_size, + ) + + def collate_fn(examples): + def get_length_to_frame_num(token_length): + if args.video_sample_size > 256: + sample_sizes = list(range(256, args.video_sample_size + 1, 128)) + + if sample_sizes[-1] != args.video_sample_size: + sample_sizes.append(args.video_sample_size) + else: + sample_sizes = [args.video_sample_size] + + length_to_frame_num = { + sample_size: min(token_length / sample_size / sample_size, args.video_sample_n_frames) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1 for sample_size in sample_sizes + } + + return length_to_frame_num + + def get_random_downsample_ratio(sample_size, image_ratio=[], + all_choices=False, rng=None): + def _create_special_list(length): + if length == 1: + return [1.0] + if length >= 2: + first_element = 0.90 + remaining_sum = 1.0 - first_element + other_elements_value = remaining_sum / (length - 1) + special_list = [first_element] + [other_elements_value] * (length - 1) + return special_list + + if sample_size >= 1536: + number_list = [1, 1.25, 1.5, 2, 2.5, 3] + image_ratio + elif sample_size >= 1024: + number_list = [1, 1.25, 1.5, 2] + image_ratio + elif sample_size >= 768: + number_list = [1, 1.25, 1.5] + image_ratio + elif sample_size >= 512: + number_list = [1] + image_ratio + else: + number_list = [1] + + if all_choices: + return number_list + + number_list_prob = np.array(_create_special_list(len(number_list))) + if rng is None: + return np.random.choice(number_list, p = number_list_prob) + else: + return rng.choice(number_list, p = number_list_prob) + + # Get token length + target_token_length = args.video_sample_n_frames * args.token_sample_size * args.token_sample_size + length_to_frame_num = get_length_to_frame_num(target_token_length) + + # Create new output + new_examples = {} + new_examples["target_token_length"] = target_token_length + new_examples["pixel_values"] = [] + new_examples["text"] = [] + new_examples["audio"] = [] + new_examples["fps"] = [] + + # Used in Inpaint mode + new_examples["mask_pixel_values"] = [] + new_examples["mask"] = [] + new_examples["clip_pixel_values"] = [] + + # Get downsample ratio in image and videos + pixel_value = examples[0]["pixel_values"] + f, h, w, c = np.shape(pixel_value) + + if args.random_hw_adapt: + if args.training_with_video_token_length: + 
local_min_size = np.min(np.array([np.mean(np.array([np.shape(example["pixel_values"])[1], np.shape(example["pixel_values"])[2]])) for example in examples])) + # The video will be resized to a lower resolution than its own. + choice_list = [length for length in list(length_to_frame_num.keys()) if length < local_min_size * 1.25] + if len(choice_list) == 0: + choice_list = list(length_to_frame_num.keys()) + local_video_sample_size = np.random.choice(choice_list) + batch_video_length = length_to_frame_num[local_video_sample_size] + random_downsample_ratio = args.video_sample_size / local_video_sample_size + else: + random_downsample_ratio = get_random_downsample_ratio(args.video_sample_size) + batch_video_length = args.video_sample_n_frames + sample_n_frames_bucket_interval + else: + random_downsample_ratio = 1 + batch_video_length = args.video_sample_n_frames + sample_n_frames_bucket_interval + + aspect_ratio_sample_size = {key : [x / 512 * args.video_sample_size / random_downsample_ratio for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} + aspect_ratio_random_crop_sample_size = {key : [x / 512 * args.video_sample_size / random_downsample_ratio for x in ASPECT_RATIO_RANDOM_CROP_512[key]] for key in ASPECT_RATIO_RANDOM_CROP_512.keys()} + + if args.fix_sample_size is not None: + fix_sample_size = [int(x / 16) * 16 for x in args.fix_sample_size] + elif args.random_ratio_crop: + if rng is None: + random_sample_size = aspect_ratio_random_crop_sample_size[ + np.random.choice(list(aspect_ratio_random_crop_sample_size.keys()), p = ASPECT_RATIO_RANDOM_CROP_PROB) + ] + else: + random_sample_size = aspect_ratio_random_crop_sample_size[ + rng.choice(list(aspect_ratio_random_crop_sample_size.keys()), p = ASPECT_RATIO_RANDOM_CROP_PROB) + ] + random_sample_size = [int(x / 64) * 64 for x in random_sample_size] + else: + closest_size, closest_ratio = get_closest_ratio(h, w, ratios=aspect_ratio_sample_size) + closest_size = [int(x / 64) * 64 for x in closest_size] + + min_example_length = min( + [example["pixel_values"].shape[0] for example in examples] + ) + batch_video_length = int(min(batch_video_length, min_example_length)) + + # Magvae needs the number of frames to be 4n + 1. + batch_video_length = (batch_video_length - 1) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1 + + if batch_video_length <= 0: + batch_video_length = 1 + + for example in examples: + if args.fix_sample_size is not None: + # To 0~1 + pixel_values = torch.from_numpy(example["pixel_values"]).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. + + # Get adapt hw for resize + fix_sample_size = list(map(lambda x: int(x), fix_sample_size)) + transform = transforms.Compose([ + transforms.Resize(fix_sample_size, interpolation=transforms.InterpolationMode.BILINEAR), # Image.BICUBIC + transforms.CenterCrop(fix_sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + elif args.random_ratio_crop: + # To 0~1 + pixel_values = torch.from_numpy(example["pixel_values"]).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. 
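+ # (Illustrative note: the resize below scales the frames so that the sampled target box
+ # [th, tw] is fully covered while keeping the aspect ratio, and the CenterCrop then cuts
+ # the result down to exactly that box.)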
+ + # Get adapt hw for resize + b, c, h, w = pixel_values.size() + th, tw = random_sample_size + if th / tw > h / w: + nh = int(th) + nw = int(w / h * nh) + else: + nw = int(tw) + nh = int(h / w * nw) + + transform = transforms.Compose([ + transforms.Resize([nh, nw]), + transforms.CenterCrop([int(x) for x in random_sample_size]), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + else: + # To 0~1 + pixel_values = torch.from_numpy(example["pixel_values"]).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. + + # Get adapt hw for resize + closest_size = list(map(lambda x: int(x), closest_size)) + if closest_size[0] / h > closest_size[1] / w: + resize_size = closest_size[0], int(w * closest_size[0] / h) + else: + resize_size = int(h * closest_size[1] / w), closest_size[1] + + transform = transforms.Compose([ + transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.BILINEAR), # Image.BICUBIC + transforms.CenterCrop(closest_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + new_examples["pixel_values"].append(transform(pixel_values)[:batch_video_length]) + new_examples["text"].append(example["text"]) + + audio_length = np.shape(example["audio"])[0] + batch_audio_length = int(audio_length / pixel_values.size()[0] * batch_video_length) + new_examples["audio"].append(example["audio"][:batch_audio_length]) + new_examples["fps"].append(example.get("fps", 24)) + + mask = get_random_mask(new_examples["pixel_values"][-1].size(), image_start_only=True) + mask_pixel_values = new_examples["pixel_values"][-1] * (1 - mask) + # Wan 2.1 use 0 for masked pixels + # + torch.ones_like(new_examples["pixel_values"][-1]) * -1 * mask + new_examples["mask_pixel_values"].append(mask_pixel_values) + new_examples["mask"].append(mask) + + clip_pixel_values = new_examples["pixel_values"][-1][0].permute(1, 2, 0).contiguous() + clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255 + new_examples["clip_pixel_values"].append(clip_pixel_values) + + # Limit the number of frames to the same + new_examples["pixel_values"] = torch.stack([example for example in new_examples["pixel_values"]]) + new_examples["mask_pixel_values"] = torch.stack([example for example in new_examples["mask_pixel_values"]]) + new_examples["mask"] = torch.stack([example for example in new_examples["mask"]]) + new_examples["clip_pixel_values"] = torch.stack([example for example in new_examples["clip_pixel_values"]]) + + # Pad audio to same length and stack + new_examples["audio"] = torch.stack([example for example in new_examples["audio"]]) + new_examples["fps"] = new_examples["fps"] + + # Encode prompts when enable_text_encoder_in_dataloader=True + if args.enable_text_encoder_in_dataloader: + # Use processor for LTX-2.3 prompt encoding (same as pipeline) + if processor is not None: + prompt_ids = processor( + text=new_examples['text'], + max_length=args.tokenizer_max_length, + padding="max_length", + truncation=True, + return_tensors="pt" + ) + else: + # Fallback to tokenizer + tokenizer.padding_side = "left" + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + prompt_ids = tokenizer( + new_examples['text'], + max_length=args.tokenizer_max_length, + padding="max_length", + add_special_tokens=True, + truncation=True, + return_tensors="pt" + ) + text_encoder_outputs = text_encoder( + input_ids=prompt_ids.input_ids, + attention_mask=prompt_ids.attention_mask, + output_hidden_states=True + ) + 
text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) + + # Pack text embeddings to 3D tensor (flatten last two dims) + prompt_embeds = text_encoder_hidden_states.flatten(2) + new_examples['encoder_attention_mask'] = prompt_ids.attention_mask + new_examples['encoder_hidden_states'] = prompt_embeds + + return new_examples + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn, + persistent_workers=True if args.dataloader_num_workers != 0 else False, + num_workers=args.dataloader_num_workers, + worker_init_fn=worker_init_fn(args.seed + accelerator.process_index) + ) + else: + # DataLoaders creation: + batch_sampler_generator = torch.Generator().manual_seed(args.seed) + batch_sampler = ImageVideoSampler(RandomSampler(train_dataset, generator=batch_sampler_generator), train_dataset, args.train_batch_size) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + persistent_workers=True if args.dataloader_num_workers != 0 else False, + num_workers=args.dataloader_num_workers, + worker_init_fn=worker_init_fn(args.seed + accelerator.process_index) + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + + # Prepare everything with our `accelerator`. + transformer3d, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer3d, optimizer, train_dataloader, lr_scheduler + ) + + if fsdp_stage != 0 or zero_stage != 0: + from functools import partial + + from packaging.version import parse as parse_version + + from videox_fun.dist import set_multi_gpus_devices, shard_model + + if parse_version(transformers.__version__) <= parse_version("4.51.3"): + shard_fn = partial(shard_model, device_id=accelerator.device, param_dtype=weight_dtype, module_to_wrapper=text_encoder.language_model.model.layers) + else: + shard_fn = partial(shard_model, device_id=accelerator.device, param_dtype=weight_dtype, module_to_wrapper=text_encoder.language_model.layers) + text_encoder = shard_fn(text_encoder) + + if args.use_ema: + ema_transformer3d.to(accelerator.device) + + # Move text_encode and vae to gpu and cast to weight_dtype + vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + audio_vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + vocoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + connectors.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + if not args.enable_text_encoder_in_dataloader: + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
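+ # Illustrative arithmetic (assuming the defaults above): with N batches per epoch and
+ # gradient_accumulation_steps = G, num_update_steps_per_epoch = ceil(N / G). If
+ # --max_train_steps was not given it becomes num_train_epochs * num_update_steps_per_epoch,
+ # and num_train_epochs is afterwards re-derived as ceil(max_train_steps / num_update_steps_per_epoch).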
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + keys_to_pop = [k for k, v in tracker_config.items() if isinstance(v, list)] + for k in keys_to_pop: + tracker_config.pop(k) + print(f"Removed tracker_config['{k}']") + accelerator.init_trackers(args.tracker_project_name, tracker_config) + + # Function for unwrapping if model was compiled with `torch.compile`. + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + + pkl_path = os.path.join(os.path.join(args.output_dir, path), "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + _, first_epoch = pickle.load(file) + else: + first_epoch = global_step // num_update_steps_per_epoch + print(f"Load pkl from {pkl_path}. Get first_epoch = {first_epoch}.") + + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. 
+ disable=not accelerator.is_local_main_process, + ) + + if args.multi_stream: + # create extra cuda streams to speedup vae computation + vae_stream_1 = torch.cuda.Stream() + else: + vae_stream_1 = None + + idx_sampling = DiscreteSampling(args.train_sampling_steps, uniform_sampling=args.uniform_sampling) + + for epoch in range(first_epoch, args.num_train_epochs): + train_loss = 0.0 + batch_sampler.sampler.generator = torch.Generator().manual_seed(args.seed + epoch) + for step, batch in enumerate(train_dataloader): + # Data batch sanity check + if epoch == first_epoch and step == 0: + pixel_values, texts = batch['pixel_values'].cpu(), batch['text'] + pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w") + os.makedirs(os.path.join(args.output_dir, "sanity_check"), exist_ok=True) + for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)): + pixel_value = pixel_value[None, ...] + gif_name = '-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_step}-{idx}' + save_videos_grid(pixel_value, f"{args.output_dir}/sanity_check/{gif_name[:10]}.gif", rescale=True) + + with accelerator.accumulate(transformer3d): + # Convert images to latent space + pixel_values = batch["pixel_values"].to(weight_dtype) + audio = batch["audio"] + fps = batch["fps"][0] if batch["fps"] else 24 # Use fps from dataset + + # Increase the batch size when the length of the latent sequence of the current sample is small + if args.auto_tile_batch_size and args.training_with_video_token_length and zero_stage != 3: + if args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 16 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]: + pixel_values = torch.tile(pixel_values, (4, 1, 1, 1, 1)) + if args.enable_text_encoder_in_dataloader: + batch['encoder_hidden_states'] = torch.tile(batch['encoder_hidden_states'], (4, 1, 1)) + batch['encoder_attention_mask'] = torch.tile(batch['encoder_attention_mask'], (4, 1)) + else: + batch['text'] = batch['text'] * 4 + elif args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 4 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]: + pixel_values = torch.tile(pixel_values, (2, 1, 1, 1, 1)) + if args.enable_text_encoder_in_dataloader: + batch['encoder_hidden_states'] = torch.tile(batch['encoder_hidden_states'], (2, 1, 1)) + batch['encoder_attention_mask'] = torch.tile(batch['encoder_attention_mask'], (2, 1)) + else: + batch['text'] = batch['text'] * 2 + + if args.random_frame_crop: + def _create_special_list(length): + if length == 1: + return [1.0] + if length >= 2: + last_element = 0.90 + remaining_sum = 1.0 - last_element + other_elements_value = remaining_sum / (length - 1) + special_list = [other_elements_value] * (length - 1) + [last_element] + return special_list + select_frames = [_tmp for _tmp in list(range(sample_n_frames_bucket_interval + 1, args.video_sample_n_frames + sample_n_frames_bucket_interval, sample_n_frames_bucket_interval))] + select_frames_prob = np.array(_create_special_list(len(select_frames))) + + if len(select_frames) != 0: + if rng is None: + temp_n_frames = np.random.choice(select_frames, p = select_frames_prob) + else: + temp_n_frames = rng.choice(select_frames, p = select_frames_prob) + else: + temp_n_frames = 1 + + # Magvae needs the number of frames to be 4n + 1. 
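+ # (sample_n_frames_bucket_interval is read from vae.config.temporal_compression_ratio above;
+ # "4n + 1" describes the case where that ratio is 4.)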
+ temp_n_frames = (temp_n_frames - 1) // sample_n_frames_bucket_interval + 1 + + pixel_values = pixel_values[:, :temp_n_frames, :, :] + + # Keep all node same token length to accelerate the traning when resolution grows. + if args.keep_all_node_same_token_length: + if args.token_sample_size > 256: + numbers_list = list(range(256, args.token_sample_size + 1, 128)) + + if numbers_list[-1] != args.token_sample_size: + numbers_list.append(args.token_sample_size) + else: + numbers_list = [256] + numbers_list = [_number * _number * args.video_sample_n_frames for _number in numbers_list] + + actual_token_length = index_rng.choice(numbers_list) + actual_video_length = (min( + actual_token_length / pixel_values.size()[-1] / pixel_values.size()[-2], args.video_sample_n_frames + ) - 1) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1 + actual_video_length = int(max(actual_video_length, 1)) + + # Magvae needs the number of frames to be 4n + 1. + actual_video_length = (actual_video_length - 1) // sample_n_frames_bucket_interval + 1 + + pixel_values = pixel_values[:, :actual_video_length, :, :] + + if args.low_vram: + torch.cuda.empty_cache() + vae.to(accelerator.device) + audio_vae.to(accelerator.device) + if not args.enable_text_encoder_in_dataloader: + text_encoder.to("cpu") + + with torch.no_grad(): + # This way is quicker when batch grows up + def _batch_encode_vae(pixel_values): + pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w") + bs = args.vae_mini_batch + new_pixel_values = [] + for i in range(0, pixel_values.shape[0], bs): + pixel_values_bs = pixel_values[i : i + bs] + pixel_values_bs = vae.encode(pixel_values_bs)[0] + pixel_values_bs = pixel_values_bs.sample() + new_pixel_values.append(pixel_values_bs) + return torch.cat(new_pixel_values, dim = 0) + if vae_stream_1 is not None: + vae_stream_1.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(vae_stream_1): + latents = _batch_encode_vae(pixel_values) + else: + latents = _batch_encode_vae(pixel_values) + + # wait for latents = vae.encode(pixel_values) to complete + if vae_stream_1 is not None: + torch.cuda.current_stream().wait_stream(vae_stream_1) + + # Get latent dimensions from VAE output for later use + bsz, _, latent_num_frames, latent_height, latent_width = latents.size() + + # Encode audio to latents + with torch.no_grad(): + audio_batch = audio.to(device=accelerator.device, dtype=torch.float32) + # audio_batch shape: [batch, channels, samples] or [batch, samples] + if audio_batch.ndim == 2: + audio_batch = audio_batch.unsqueeze(1) # [batch, 1, samples] + audio_batch = audio_batch.repeat(1, 2, 1) if audio_batch.dim() == 3 else audio_batch.repeat(2, 1) + + # Convert audio waveform to log-mel spectrogram (following official LTX-2 AudioProcessor) + # mel_transform input: [batch, channels, samples] -> output: [batch, channels, n_mels, time] + mel_spec = mel_transform.to(accelerator.device)(audio_batch) + mel_spec = torch.log(mel_spec.clamp(min=1e-5)) + mel_spectrogram = mel_spec.permute(0, 1, 3, 2).contiguous() # [batch, channels, time, n_mels] + + # Ensure mel spectrogram has the correct number of channels + if mel_spectrogram.shape[1] < audio_in_channels: + mel_spectrogram = mel_spectrogram.repeat(1, audio_in_channels, 1, 1) + elif mel_spectrogram.shape[1] > audio_in_channels: + mel_spectrogram = mel_spectrogram[:, :audio_in_channels, :, :] + + # Encode mel spectrogram to latents using audio_vae + mel_spectrogram = mel_spectrogram.to(dtype=weight_dtype) + audio_encoder_output = 
audio_vae.encode(mel_spectrogram) + audio_latents_raw = audio_encoder_output.latent_dist.sample() + # audio_latents_raw shape: [batch, latent_channels, latent_time, latent_mel] + + # Get the actual audio_num_frames from encoded latents + audio_num_frames = audio_latents_raw.shape[2] + + # Pack audio latents FIRST, then normalize + # This is the correct order as per pipeline implementation + audio_latents = _pack_audio_latents(audio_latents_raw) + # audio_latents shape: [batch, latent_time, latent_channels * latent_mel] + + # Normalize audio latents (after packing) + audio_latents = _normalize_audio_latents( + audio_latents, audio_vae.latents_mean, audio_vae.latents_std + ) + + if args.low_vram: + vae.to('cpu') + audio_vae.to('cpu') + torch.cuda.empty_cache() + if not args.enable_text_encoder_in_dataloader: + text_encoder.to(accelerator.device) + connectors.to(accelerator.device) + + if args.enable_text_encoder_in_dataloader: + prompt_embeds = batch['encoder_hidden_states'].to(device=latents.device, dtype=weight_dtype) + prompt_attention_mask = batch['encoder_attention_mask'].to(device=latents.device) + else: + with torch.no_grad(): + # Use processor for LTX-2.3 prompt encoding (same as pipeline) + if processor is not None: + prompt_ids = processor( + text=batch['text'], + padding="max_length", + max_length=args.tokenizer_max_length, + truncation=True, + return_tensors="pt" + ) + else: + # Fallback to tokenizer + tokenizer.padding_side = "left" + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + prompt_ids = tokenizer( + batch['text'], + padding="max_length", + max_length=args.tokenizer_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt" + ) + text_input_ids = prompt_ids.input_ids.to(latents.device) + prompt_attention_mask = prompt_ids.attention_mask.to(latents.device) + + # Get text encoder hidden states + text_encoder_outputs = text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True + ) + text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) + + # Pack text embeddings to 3D tensor (flatten last two dims) + prompt_embeds = text_encoder_hidden_states.flatten(2) # [B, seq_len, hidden_dim * num_layers] + prompt_embeds = prompt_embeds.to(dtype=weight_dtype) + + # Use connectors to process prompt embeddings + with torch.no_grad(): + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = connectors( + prompt_embeds, prompt_attention_mask, padding_side=tokenizer.padding_side + ) + + if args.low_vram and not args.enable_text_encoder_in_dataloader: + text_encoder.to('cpu') + connectors.to('cpu') + torch.cuda.empty_cache() + + noise = torch.randn(latents.size(), device=latents.device, generator=torch_rng, dtype=weight_dtype) + audio_noise = torch.randn(audio_latents.size(), device=latents.device, generator=torch_rng, dtype=weight_dtype) + + if not args.uniform_sampling: + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + else: + # Sample a random timestep for each image + # timesteps = generate_timestep_with_lognorm(0, args.train_sampling_steps, (bsz,), device=latents.device, generator=torch_rng) + # timesteps = torch.randint(0, args.train_sampling_steps, (bsz,), 
device=latents.device, generator=torch_rng) + indices = idx_sampling(bsz, generator=torch_rng, device=latents.device) + indices = indices.long().cpu() + timesteps = noise_scheduler.timesteps[indices].to(device=latents.device) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + # Get transformer config for patch sizes + patch_size = getattr(accelerator.unwrap_model(transformer3d).config, 'patch_size', 1) + patch_size_t = getattr(accelerator.unwrap_model(transformer3d).config, 'patch_size_t', 1) + + # ------------------ I2V Conditioning Mask ------------------ + # Create conditioning mask for I2V training + # conditioning_mask: 1 for condition frames (first frame), 0 for frames to generate + conditioning_mask = None + + if args.i2v_ratio > 0: + # Randomly select samples for I2V based on i2v_ratio + i2v_prob = torch.rand(bsz, generator=torch_rng, device=latents.device) + is_i2v_sample = i2v_prob < args.i2v_ratio + + if is_i2v_sample.any(): + # Create conditioning mask: [B, 1, F, H, W] + # First frame is condition (mask=1), rest are to generate (mask=0) + conditioning_mask = torch.zeros( + (bsz, 1, latent_num_frames, latent_height, latent_width), + device=latents.device, dtype=latents.dtype + ) + conditioning_mask[is_i2v_sample, :, 0, :, :] = 1.0 + + # ------------------ Video Latents ------------------ + # Normalize video latents + latents = _normalize_latents(latents, vae.latents_mean, vae.latents_std, vae.config.scaling_factor) + # Add noise according to flow matching + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype) + + if conditioning_mask is not None: + # I2V mode: first frame is condition, apply noise differently + # For condition frames: keep clean (or add slight noise if i2v_noise_scale > 0) + # For frames to generate: apply full noise + if args.i2v_noise_scale > 0: + # Add slight noise to condition frame + noisy_latents = latents * conditioning_mask * (1 - args.i2v_noise_scale) + \ + noise * conditioning_mask * args.i2v_noise_scale + \ + ((1.0 - sigmas) * latents + sigmas * noise) * (1 - conditioning_mask) + else: + # Keep condition frame clean + noisy_latents = latents * conditioning_mask + \ + ((1.0 - sigmas) * latents + sigmas * noise) * (1 - conditioning_mask) + else: + # T2V mode: all frames get noise + noisy_latents = (1.0 - sigmas) * latents + sigmas * noise + + target = noise - latents + noisy_latents_packed = _pack_latents(noisy_latents, patch_size, patch_size_t) + + # Pack conditioning mask if present + if conditioning_mask is not None: + conditioning_mask_packed = _pack_latents(conditioning_mask, patch_size, patch_size_t).squeeze(-1) + else: + conditioning_mask_packed = None + + # ------------------ Audio Latents ------------------ + # Add noise to audio latents for training (flow matching) + audio_sigmas = get_sigmas(timesteps, n_dim=audio_latents.ndim, dtype=audio_latents.dtype) + noisy_audio_latents = (1.0 - audio_sigmas) * audio_latents + audio_sigmas * audio_noise + audio_target = audio_noise - audio_latents + + # -------- Timesteps Process and RoPE Process -------- + # Prepare timestep + # For T2V: use 
batch-level timestep (same as inference) + # For I2V: condition frames have timestep=0, generate frames have normal timestep + if conditioning_mask_packed is not None: + # I2V mode: video_timestep has shape [B, S] where S is sequence length + # condition frames (mask=1) get timestep 0, generate frames get normal timestep + video_timestep = timesteps.unsqueeze(-1) * (1 - conditioning_mask_packed) + else: + # T2V mode: use batch-level timestep + video_timestep = timesteps + audio_timestep = timesteps + + # Prepare RoPE coordinates + video_coords = accelerator.unwrap_model(transformer3d).rope.prepare_video_coords( + bsz, latent_num_frames, latent_height, latent_width, latents.device, fps=fps + ) + audio_coords = accelerator.unwrap_model(transformer3d).audio_rope.prepare_audio_coords( + bsz, audio_num_frames, audio_latents.device + ) + + # -------- Forward -------- + # Predict the noise residual + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=accelerator.device): + noise_pred_video, noise_pred_audio = transformer3d( + hidden_states=noisy_latents_packed, + audio_hidden_states=noisy_audio_latents, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=audio_timestep, + sigma=timesteps, # LTX-2.3 uses sigma + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=fps, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, + isolate_modalities=False, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + use_cross_timestep=False, + return_dict=False, + ) + + # Unpack predictions for loss computation + noise_pred = _unpack_latents( + noise_pred_video, + latent_num_frames, + latent_height, + latent_width, + patch_size, + patch_size_t, + ) + + def custom_mse_loss(noise_pred, target, weighting=None, threshold=50): + noise_pred = noise_pred.float() + target = target.float() + diff = noise_pred - target + mse_loss = F.mse_loss(noise_pred, target, reduction='none') + mask = (diff.abs() <= threshold).float() + masked_loss = mse_loss * mask + if weighting is not None: + masked_loss = masked_loss * weighting + final_loss = masked_loss.mean() + return final_loss + + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # Video loss + video_loss = custom_mse_loss(noise_pred.float(), target.float(), weighting.float()) + + if args.motion_sub_loss and noise_pred.size()[2] > 2: + gt_sub_noise = noise_pred[:, :, 1:].float() - noise_pred[:, :, :-1].float() + pre_sub_noise = target[:, :, 1:].float() - target[:, :, :-1].float() + sub_loss = F.mse_loss(gt_sub_noise, pre_sub_noise, reduction="mean") + video_loss = video_loss * (1 - args.motion_sub_loss_ratio) + sub_loss * args.motion_sub_loss_ratio + + # Audio loss + audio_weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=audio_sigmas) + audio_loss = F.mse_loss(noise_pred_audio.float(), audio_target.float(), reduction='none') + if audio_weighting is not None: + # Expand weighting to match audio shape + while audio_weighting.ndim < audio_loss.ndim: + audio_weighting = audio_weighting.unsqueeze(-1) + audio_loss = audio_loss * audio_weighting + audio_loss = audio_loss.mean() + + # Combined loss (equal weighting for video and audio) + loss = 0.5 * video_loss + 0.5 * audio_loss 
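+ # Flow-matching recap for this step: both modalities were noised as
+ #   x_t = (1 - sigma) * x_0 + sigma * noise
+ # and the model regresses the velocity target (noise - x_0), so the weighted video and
+ # audio MSE terms above are combined with a fixed 0.5 / 0.5 average.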
+ + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + if not args.use_deepspeed and not args.use_fsdp: + trainable_params_grads = [p.grad for p in trainable_params if p.grad is not None] + trainable_params_total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2) for g in trainable_params_grads]), 2) + max_grad_norm = linear_decay(args.max_grad_norm * args.initial_grad_norm_ratio, args.max_grad_norm, args.abnormal_norm_clip_start, global_step) + if trainable_params_total_norm / max_grad_norm > 5 and global_step > args.abnormal_norm_clip_start: + actual_max_grad_norm = max_grad_norm / min((trainable_params_total_norm / max_grad_norm), 10) + else: + actual_max_grad_norm = max_grad_norm + else: + actual_max_grad_norm = args.max_grad_norm + + if not args.use_deepspeed and not args.use_fsdp and args.report_model_info and accelerator.is_main_process: + if trainable_params_total_norm > 1 and global_step > args.abnormal_norm_clip_start: + for name, param in transformer3d.named_parameters(): + if param.requires_grad: + writer.add_scalar(f'gradients/before_clip_norm/{name}', param.grad.norm(), global_step=global_step) + + norm_sum = accelerator.clip_grad_norm_(trainable_params, actual_max_grad_norm) + if not args.use_deepspeed and not args.use_fsdp and args.report_model_info and accelerator.is_main_process: + writer.add_scalar(f'gradients/norm_sum', norm_sum, global_step=global_step) + writer.add_scalar(f'gradients/actual_max_grad_norm', actual_max_grad_norm, global_step=global_step) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + + if args.use_ema: + ema_transformer3d.step(transformer3d.parameters()) + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if args.use_deepspeed or args.use_fsdp or accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if args.validation_prompts is not None and global_step % 
args.validation_steps == 0: + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_transformer3d.store(transformer3d.parameters()) + ema_transformer3d.copy_to(transformer3d.parameters()) + log_validation( + vae, + audio_vae, + text_encoder, + tokenizer, + processor, + connectors, + vocoder, + transformer3d, + args, + accelerator, + weight_dtype, + global_step, + ) + if args.use_ema: + # Switch back to the original transformer3d parameters. + ema_transformer3d.restore(transformer3d.parameters()) + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.validation_prompts is not None and epoch % args.validation_epochs == 0: + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_transformer3d.store(transformer3d.parameters()) + ema_transformer3d.copy_to(transformer3d.parameters()) + log_validation( + vae, + audio_vae, + text_encoder, + tokenizer, + processor, + connectors, + vocoder, + transformer3d, + args, + accelerator, + weight_dtype, + global_step, + ) + if args.use_ema: + # Switch back to the original transformer3d parameters. + ema_transformer3d.restore(transformer3d.parameters()) + + # Create the pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + transformer3d = unwrap_model(transformer3d) + if args.use_ema: + ema_transformer3d.copy_to(transformer3d.parameters()) + + if args.use_deepspeed or args.use_fsdp or accelerator.is_main_process: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/scripts/ltx2.3/train.sh b/scripts/ltx2.3/train.sh new file mode 100644 index 00000000..e51233e4 --- /dev/null +++ b/scripts/ltx2.3/train.sh @@ -0,0 +1,40 @@ +export MODEL_NAME="models/Diffusion_Transformer/LTX-2.3-Diffusers" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata_control.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" scripts/ltx2/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --image_sample_size=640 \ + --video_sample_size=640 \ + --token_sample_size=640 \ + --video_sample_stride=1 \ + --video_sample_n_frames=121 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_ltx2.3" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --low_vram \ + --trainable_modules "." 
\ No newline at end of file diff --git a/scripts/ltx2.3/train_lora.py b/scripts/ltx2.3/train_lora.py new file mode 100644 index 00000000..5ded5c9f --- /dev/null +++ b/scripts/ltx2.3/train_lora.py @@ -0,0 +1,2180 @@ +"""Modified from https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py +""" +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import gc +import logging +import math +import os +import pickle +import random +import shutil +import sys + +import accelerate +import diffusers +import numpy as np +import torch +import torch.nn.functional as F +import torchaudio +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.state import AcceleratorState +from accelerate.utils import ProjectConfiguration, set_seed +from diffusers import DDIMScheduler, FlowMatchEulerDiscreteScheduler +from diffusers.optimization import get_scheduler +from diffusers.training_utils import (compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3) +from diffusers.utils import check_min_version, deprecate, is_wandb_available +from diffusers.utils.torch_utils import is_compiled_module +from einops import rearrange +from packaging import version +from PIL import Image +from torch.utils.data import RandomSampler +from torch.utils.tensorboard import SummaryWriter +from torchvision import transforms +from tqdm.auto import tqdm +from transformers.utils import ContextManagers + +import datasets + +current_file_path = os.path.abspath(__file__) +project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))] +for project_root in project_roots: + sys.path.insert(0, project_root) if project_root not in sys.path else None +from videox_fun.data.bucket_sampler import (ASPECT_RATIO_512, + ASPECT_RATIO_RANDOM_CROP_512, + ASPECT_RATIO_RANDOM_CROP_PROB, + AspectRatioBatchImageVideoSampler, + RandomSampler, get_closest_ratio) +from videox_fun.data.dataset_image_video import (ImageVideoDataset, + ImageVideoSampler, + get_random_mask) +from videox_fun.data.dataset_video import VideoSpeechDataset +from videox_fun.models import (AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, + Gemma3ForConditionalGeneration, Gemma3Processor, + LTX2TextConnectors, + LTX2VideoTransformer3DModel, LTX2VocoderWithBWE) +from videox_fun.pipeline import LTX2Pipeline +from videox_fun.utils.discrete_sampler import DiscreteSampling +from videox_fun.utils.lora_utils import (convert_peft_lora_to_kohya_lora, + create_network, merge_lora, + unmerge_lora) +from videox_fun.utils.utils import (calculate_dimensions, get_image_latent, + get_image_to_video_latent, + save_videos_grid, + save_videos_with_audio_grid) + +if is_wandb_available(): + import wandb + +def filter_kwargs(cls, kwargs): + import inspect + sig = inspect.signature(cls.__init__) + valid_params = 
set(sig.parameters.keys()) - {'self', 'cls'} + filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} + return filtered_kwargs + +def linear_decay(initial_value, final_value, total_steps, current_step): + if current_step >= total_steps: + return final_value + current_step = max(0, current_step) + step_size = (final_value - initial_value) / total_steps + current_value = initial_value + step_size * current_step + return current_value + +def generate_timestep_with_lognorm(low, high, shape, device="cpu", generator=None): + u = torch.normal(mean=0.0, std=1.0, size=shape, device=device, generator=generator) + t = 1 / (1 + torch.exp(-u)) * (high - low) + low + return torch.clip(t.to(torch.int32), low, high - 1) + +# LTX2 helper functions for packing text embeddings and latents +def _pack_text_embeds( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device, + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, +) -> torch.Tensor: + """Packs and normalizes text encoder hidden states, respecting padding.""" + batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + # Create padding mask + token_indices = torch.arange(seq_len, device=device).unsqueeze(0) + if padding_side == "right": + mask = token_indices < sequence_lengths[:, None] + elif padding_side == "left": + start_indices = seq_len - sequence_lengths[:, None] + mask = token_indices >= start_indices + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = mask[:, :, None, None] + + # Compute masked mean + masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) + num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) + masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) + + # Compute min/max + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + # Normalization + normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) + normalized_hidden_states = normalized_hidden_states * scale_factor + + # Pack the hidden states to 3D tensor + normalized_hidden_states = normalized_hidden_states.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) + normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) + return normalized_hidden_states + +def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + """Packs latents [B, C, F, H, W] into token sequence [B, S, D].""" + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + +def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 +) -> torch.Tensor: + """Unpacks token sequence [B, S, D] back to latents [B, C, F, H, W].""" + batch_size = 
latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + +def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 +) -> torch.Tensor: + """Normalizes latents across the channel dimension [B, C, F, H, W].""" + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + +def _pack_audio_latents( + latents: torch.Tensor, patch_size: int | None = None, patch_size_t: int | None = None +) -> torch.Tensor: + """Packs audio latents [B, C, L, M] into token sequence.""" + if patch_size is not None and patch_size_t is not None: + batch_size, num_channels, latent_length, latent_mel_bins = latents.shape + post_patch_latent_length = latent_length / patch_size_t + post_patch_mel_bins = latent_mel_bins / patch_size + latents = latents.reshape( + batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size + ) + latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2) + else: + latents = latents.transpose(1, 2).flatten(2, 3) + return latents + +def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + """Normalizes audio latents.""" + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents - latents_mean) / latents_std + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.18.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + +def log_validation(vae, audio_vae, text_encoder, tokenizer, processor, connectors, vocoder, transformer3d, network, args, accelerator, weight_dtype, global_step): + try: + is_deepspeed = type(transformer3d).__name__ == 'DeepSpeedEngine' + if is_deepspeed: + origin_config = transformer3d.config + transformer3d.config = accelerator.unwrap_model(transformer3d).config + with torch.no_grad(), torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=accelerator.device): + logger.info("Running validation... 
") + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler" + ) + + pipeline = LTX2Pipeline( + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + processor=processor, + connectors=connectors, + transformer=accelerator.unwrap_model(transformer3d) if type(transformer3d).__name__ == 'DistributedDataParallel' else transformer3d, + vocoder=vocoder, + scheduler=scheduler, + ) + pipeline = pipeline.to(accelerator.device) + + if args.seed is None: + generator = None + else: + rank_seed = args.seed + accelerator.process_index + generator = torch.Generator(device=accelerator.device).manual_seed(rank_seed) + logger.info(f"Rank {accelerator.process_index} using seed: {rank_seed}") + + for i in range(len(args.validation_prompts)): + output = pipeline( + args.validation_prompts[i], + negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted, static, low quality, artifacts", + height = args.video_sample_size, + width = args.video_sample_size, + num_frames = args.video_sample_n_frames, + num_inference_steps = 25, + guidance_scale = 3.0, + audio_guidance_scale = 7.0, + stg_scale = 1.0, + audio_stg_scale = 1.0, + modality_scale = 3.0, + audio_modality_scale = 3.0, + guidance_rescale = 0.7, + audio_guidance_rescale = 0.7, + spatio_temporal_guidance_blocks = [28], + generator = generator, + ) + sample = output.videos + audio = output.audio + sr = getattr(pipeline.vocoder.config, "output_sampling_rate", 24000) + save_videos_with_audio_grid( + sample, + audio, + os.path.join( + args.output_dir, + f"sample/sample-{global_step}-rank{accelerator.process_index}-image-{i}.mp4" + ), + fps=24, + audio_sample_rate=sr, + ) + + del pipeline + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + transformer3d.to(accelerator.device, dtype=weight_dtype) + if not args.enable_text_encoder_in_dataloader: + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + if is_deepspeed: + transformer3d.config = origin_config + except Exception as e: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + print(f"Eval error on rank {accelerator.process_index} with info {e}") + vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + transformer3d.to(accelerator.device, dtype=weight_dtype) + if not args.enable_text_encoder_in_dataloader: + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1." + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. 
" + ), + ) + parser.add_argument( + "--train_data_meta", + type=str, + default=None, + help=( + "A csv containing the training data. " + ), + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--validation_prompts", + type=str, + default=None, + nargs="+", + help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument( + "--validation_paths", + type=str, + default=None, + nargs="+", + help=("A set of control videos evaluated every `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--use_came", + action="store_true", + help="whether to use came", + ) + parser.add_argument( + "--multi_stream", + action="store_true", + help="whether to use cuda multi-stream", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--vae_mini_batch", type=int, default=32, help="mini batch size for vae." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. 
Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument( + "--validation_epochs", + type=int, + default=5, + help="Run validation every X epochs.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=2000, + help="Run validation every X steps.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + parser.add_argument( + "--rank", + type=int, + default=128, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--network_alpha", + type=int, + default=64, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--use_peft_lora", action="store_true", help="Whether or not to use peft lora." + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + ) + parser.add_argument( + "--snr_loss", action="store_true", help="Whether or not to use snr_loss." + ) + parser.add_argument( + "--uniform_sampling", action="store_true", help="Whether or not to use uniform_sampling." + ) + parser.add_argument( + "--enable_text_encoder_in_dataloader", action="store_true", help="Whether or not to use text encoder in dataloader." + ) + parser.add_argument( + "--enable_bucket", action="store_true", help="Whether enable bucket sample in datasets." + ) + parser.add_argument( + "--random_ratio_crop", action="store_true", help="Whether enable random ratio crop sample in datasets." + ) + parser.add_argument( + "--random_frame_crop", action="store_true", help="Whether enable random frame crop sample in datasets." + ) + parser.add_argument( + "--random_hw_adapt", action="store_true", help="Whether enable random adapt height and width in datasets." + ) + parser.add_argument( + "--training_with_video_token_length", action="store_true", help="The training stage of the model in training.", + ) + parser.add_argument( + "--auto_tile_batch_size", action="store_true", help="Whether to auto tile batch size.", + ) + parser.add_argument( + "--motion_sub_loss", action="store_true", help="Whether enable motion sub loss." + ) + parser.add_argument( + "--motion_sub_loss_ratio", type=float, default=0.25, help="The ratio of motion sub loss." + ) + parser.add_argument( + "--train_sampling_steps", + type=int, + default=1000, + help="Run train_sampling_steps.", + ) + parser.add_argument( + "--keep_all_node_same_token_length", + action="store_true", + help="Reference of the length token.", + ) + parser.add_argument( + "--token_sample_size", + type=int, + default=512, + help="Sample size of the token.", + ) + parser.add_argument( + "--video_sample_size", + type=int, + default=512, + help="Sample size of the video.", + ) + parser.add_argument( + "--image_sample_size", + type=int, + default=512, + help="Sample size of the image.", + ) + parser.add_argument( + "--fix_sample_size", + nargs=2, type=int, default=None, + help="Fix Sample size [height, width] when using bucket and collate_fn." 
+ ) + parser.add_argument( + "--video_sample_stride", + type=int, + default=4, + help="Sample stride of the video.", + ) + parser.add_argument( + "--video_sample_n_frames", + type=int, + default=17, + help="Num frame of video.", + ) + parser.add_argument( + "--video_repeat", + type=int, + default=0, + help="Num of repeat video.", + ) + parser.add_argument( + "--transformer_path", + type=str, + default=None, + help=("If you want to load the weight from other transformers, input its path."), + ) + parser.add_argument( + "--vae_path", + type=str, + default=None, + help=("If you want to load the weight from other vaes, input its path."), + ) + parser.add_argument("--save_state", action="store_true", help="Whether or not to save state.") + + parser.add_argument( + '--tokenizer_max_length', + type=int, + default=512, + help='Max length of tokenizer' + ) + parser.add_argument( + "--use_deepspeed", action="store_true", help="Whether or not to use deepspeed." + ) + parser.add_argument( + "--use_fsdp", action="store_true", help="Whether or not to use fsdp." + ) + parser.add_argument( + "--low_vram", action="store_true", help="Whether enable low_vram mode." + ) + parser.add_argument( + "--i2v_ratio", + type=float, + default=0.5, + help=( + 'Ratio of I2V samples in training. 0.0 = pure T2V, 1.0 = pure I2V, ' + '0.5 = 50%% T2V + 50%% I2V (default).' + ), + ) + parser.add_argument( + "--i2v_noise_scale", + type=float, + default=0.0, + help=( + 'Noise scale for I2V first frame conditioning. ' + '0.0 means first frame is kept clean (default). ' + 'Higher values add slight noise to the condition frame.' + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--lora_skip_name", + type=str, + default=None, + help=("The module is not trained in loras. "), + ) + parser.add_argument( + "--target_name", + type=str, + default=None, + help=("The module is trained in loras. "), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # default to using the same revision for the non-ema model if not specified + if args.non_ema_revision is None: + args.non_ema_revision = args.revision + + return args + + +def main(): + args = parse_args() + + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if args.non_ema_revision is not None: + deprecate( + "non_ema_revision!=None", + "0.15.0", + message=( + "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" + " use `--variant=non_ema` instead." 
+ ), + ) + logging_dir = os.path.join(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + deepspeed_plugin = accelerator.state.deepspeed_plugin if hasattr(accelerator.state, "deepspeed_plugin") else None + fsdp_plugin = accelerator.state.fsdp_plugin if hasattr(accelerator.state, "fsdp_plugin") else None + if deepspeed_plugin is not None: + zero_stage = int(deepspeed_plugin.zero_stage) + fsdp_stage = 0 + print(f"Using DeepSpeed Zero stage: {zero_stage}") + + args.use_deepspeed = True + if zero_stage == 3: + print(f"Auto set save_state to True because zero_stage == 3") + args.save_state = True + elif fsdp_plugin is not None: + from torch.distributed.fsdp import ShardingStrategy + zero_stage = 0 + if fsdp_plugin.sharding_strategy is ShardingStrategy.FULL_SHARD: + fsdp_stage = 3 + elif fsdp_plugin.sharding_strategy is None: # The fsdp_plugin.sharding_strategy is None in FSDP 2. + fsdp_stage = 3 + elif fsdp_plugin.sharding_strategy is ShardingStrategy.SHARD_GRAD_OP: + fsdp_stage = 2 + else: + fsdp_stage = 0 + print(f"Using FSDP stage: {fsdp_stage}") + + args.use_fsdp = True + if fsdp_stage == 3: + print(f"Auto set save_state to True because fsdp_stage == 3") + args.save_state = True + else: + zero_stage = 0 + fsdp_stage = 0 + print("DeepSpeed is not enabled.") + + if accelerator.is_main_process: + writer = SummaryWriter(log_dir=logging_dir) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + rng = np.random.default_rng(np.random.PCG64(args.seed + accelerator.process_index)) + torch_rng = torch.Generator(accelerator.device).manual_seed(args.seed + accelerator.process_index) + else: + rng = None + torch_rng = None + index_rng = np.random.default_rng(np.random.PCG64(43)) + print(f"Init rng with seed {args.seed + accelerator.process_index}. Process_index is {accelerator.process_index}") + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora transformer3d) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + args.mixed_precision = accelerator.mixed_precision + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + args.mixed_precision = accelerator.mixed_precision + + # Load scheduler, tokenizer and models. 
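+    # Components loaded below (a descriptive note, matching the code that follows): the
+    # flow-matching scheduler, the Gemma3 processor (whose tokenizer is reused), the Gemma3
+    # text encoder, the video and audio VAEs, the text connectors, the vocoder and the LTX2
+    # video transformer. Every component is frozen here; only the LoRA parameters attached
+    # further below are trained.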
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler" + ) + + # Get Processor + processor = Gemma3Processor.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="processor", + ) + + # Get Tokenizer from processor + tokenizer = processor.tokenizer + + def deepspeed_zero_init_disabled_context_manager(): + """ + returns either a context list that includes one that will disable zero.Init or an empty context list + """ + deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None + if deepspeed_plugin is None: + return [] + + return [deepspeed_plugin.zero3_init_context_manager(enable=False)] + + # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. + # For this to work properly all models must be run through `accelerate.prepare`. But accelerate + # will try to assign the same optimizer with the same weights to all models during + # `deepspeed.initialize`, which of course doesn't work. + # + # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 + # frozen models from being partitioned during `zero.Init` which gets called during + # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding + # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. + with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + # Get Text encoder + text_encoder = Gemma3ForConditionalGeneration.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + low_cpu_mem_usage=True, + torch_dtype=weight_dtype, + ) + text_encoder = text_encoder.eval() + # Get Vae + vae = AutoencoderKLLTX2Video.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + torch_dtype=weight_dtype, + ) + vae.eval() + audio_vae = AutoencoderKLLTX2Audio.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="audio_vae", + torch_dtype=weight_dtype, + ) + audio_vae.eval() + + # Connectors + connectors = LTX2TextConnectors.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="connectors", + torch_dtype=weight_dtype, + ) + # Vocoder + vocoder = LTX2VocoderWithBWE.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vocoder", + torch_dtype=weight_dtype, + ) + + # Get Transformer + transformer3d = LTX2VideoTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + low_cpu_mem_usage=True, + ).to(weight_dtype) + + # Freeze vae and text_encoder and set transformer3d to trainable + vae.requires_grad_(False) + audio_vae.requires_grad_(False) + connectors.requires_grad_(False) + vocoder.requires_grad_(False) + text_encoder.requires_grad_(False) + transformer3d.requires_grad_(False) + + # Lora will work with this... 
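+    # Two LoRA variants are handled below:
+    #  - with --use_peft_lora, a peft LoraConfig is injected directly into the transformer
+    #    via inject_adapter_in_model, so the adapter weights live inside transformer3d itself
+    #    and `network` stays None;
+    #  - otherwise a kohya-style network is built with create_network and applied to the
+    #    text encoder / transformer, and that separate `network` module holds the trainable
+    #    parameters.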
+ if args.use_peft_lora: + from peft import (LoraConfig, get_peft_model_state_dict, + inject_adapter_in_model) + lora_config = LoraConfig(r=args.rank, lora_alpha=args.network_alpha, target_modules=args.target_name.split(",")) + transformer3d = inject_adapter_in_model(lora_config, transformer3d) + + network = None + else: + network = create_network( + 1.0, + args.rank, + args.network_alpha, + text_encoder, + transformer3d, + neuron_dropout=None, + target_name=args.target_name, + skip_name=args.lora_skip_name, + ) + network = network.to(weight_dtype) + network.apply_to(text_encoder, transformer3d, args.train_text_encoder and not args.training_with_video_token_length, True) + + if args.transformer_path is not None: + print(f"From checkpoint: {args.transformer_path}") + if args.transformer_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(args.transformer_path) + else: + state_dict = torch.load(args.transformer_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = transformer3d.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + assert len(u) == 0 + + if args.vae_path is not None: + print(f"From checkpoint: {args.vae_path}") + if args.vae_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(args.vae_path) + else: + state_dict = torch.load(args.vae_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = vae.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + assert len(u) == 0 + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + if fsdp_stage != 0 or zero_stage == 3: + def save_model_hook(models, weights, output_dir): + accelerate_state_dict = accelerator.get_state_dict(models[-1], unwrap=True) + if accelerator.is_main_process: + from safetensors.torch import save_file + safetensor_save_path = os.path.join(output_dir, f"lora_diffusion_pytorch_model.safetensors") + if args.use_peft_lora: + network_state_dict = get_peft_model_state_dict(accelerator.unwrap_model(models[-1]), accelerate_state_dict) + network_state_dict_kohya = convert_peft_lora_to_kohya_lora(network_state_dict) + safetensor_kohya_format_save_path = os.path.join(output_dir, f"lora_diffusion_pytorch_model_compatible_with_comfyui.safetensors") + save_model(safetensor_kohya_format_save_path, network_state_dict_kohya) + else: + network_state_dict = {} + for key in accelerate_state_dict: + if "network" in key: + network_state_dict[key.replace("network.", "")] = accelerate_state_dict[key].to(weight_dtype) + save_file(network_state_dict, safetensor_save_path, metadata={"format": "pt"}) + + with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file: + pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file) + + def load_model_hook(models, input_dir): + pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + loaded_number, _ = pickle.load(file) + batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0) + print(f"Load pkl from 
{pkl_path}. Get loaded_number = {loaded_number}.") + + else: + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + accelerate_state_dict = accelerator.get_state_dict(models[-1], unwrap=True) + if accelerator.is_main_process: + from safetensors.torch import save_file + safetensor_save_path = os.path.join(output_dir, f"lora_diffusion_pytorch_model.safetensors") + if args.use_peft_lora: + network_state_dict = get_peft_model_state_dict(accelerator.unwrap_model(models[-1]), accelerate_state_dict) + network_state_dict_kohya = convert_peft_lora_to_kohya_lora(network_state_dict) + safetensor_kohya_format_save_path = os.path.join(output_dir, f"lora_diffusion_pytorch_model_compatible_with_comfyui.safetensors") + save_model(safetensor_kohya_format_save_path, network_state_dict_kohya) + else: + network_state_dict = {} + for key in accelerate_state_dict: + if "network" in key: + network_state_dict[key.replace("network.", "")] = accelerate_state_dict[key].to(weight_dtype) + save_file(network_state_dict, safetensor_save_path, metadata={"format": "pt"}) + + if not args.use_deepspeed: + for _ in range(len(weights)): + weights.pop() + + with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file: + pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file) + + def load_model_hook(models, input_dir): + pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + loaded_number, _ = pickle.load(file) + batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0) + print(f"Load pkl from {pkl_path}. Get loaded_number = {loaded_number}.") + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + transformer3d.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + elif args.use_came: + try: + from came_pytorch import CAME + except Exception: + raise ImportError( + "Please install came_pytorch to use CAME. 
You can do so by running `pip install came_pytorch`" + ) + + optimizer_cls = CAME + else: + optimizer_cls = torch.optim.AdamW + + if args.use_peft_lora: + logging.info("Add peft parameters") + trainable_params = list(filter(lambda p: p.requires_grad, transformer3d.parameters())) + print(trainable_params[0]) + trainable_params_optim = list(filter(lambda p: p.requires_grad, transformer3d.parameters())) + else: + logging.info("Add network parameters") + trainable_params = list(filter(lambda p: p.requires_grad, network.parameters())) + trainable_params_optim = network.prepare_optimizer_params(args.learning_rate / 2, args.learning_rate, args.learning_rate) + + if args.use_came: + optimizer = optimizer_cls( + trainable_params_optim, + lr=args.learning_rate, + # weight_decay=args.adam_weight_decay, + betas=(0.9, 0.999, 0.9999), + eps=(1e-30, 1e-16) + ) + else: + optimizer = optimizer_cls( + trainable_params_optim, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the training dataset + sample_n_frames_bucket_interval = vae.config.temporal_compression_ratio + + if args.fix_sample_size is not None and args.enable_bucket: + args.video_sample_size = max(max(args.fix_sample_size), args.video_sample_size) + args.image_sample_size = max(max(args.fix_sample_size), args.image_sample_size) + args.training_with_video_token_length = False + args.random_hw_adapt = False + + # Get the dataset + train_dataset = VideoSpeechDataset( + args.train_data_meta, args.train_data_dir, + video_sample_size=args.video_sample_size, video_sample_stride=args.video_sample_stride, video_sample_n_frames=args.video_sample_n_frames, + enable_bucket=args.enable_bucket, enable_inpaint=True, audio_sr=getattr(audio_vae.config, 'sample_rate', 16000), + ) + + # Pre-create mel spectrogram transform (avoid recreating per iteration) + audio_sampling_rate = getattr(audio_vae.config, 'sample_rate', 16000) + audio_hop_length = getattr(audio_vae.config, 'mel_hop_length', 160) + audio_mel_bins = getattr(audio_vae.config, 'mel_bins', 64) + audio_in_channels = getattr(audio_vae.config, 'in_channels', 2) + mel_transform = torchaudio.transforms.MelSpectrogram( + sample_rate=audio_sampling_rate, + n_fft=1024, + win_length=1024, + hop_length=audio_hop_length, + f_min=0.0, + f_max=audio_sampling_rate / 2.0, + n_mels=audio_mel_bins, + window_fn=torch.hann_window, + center=True, + pad_mode="reflect", + power=1.0, + mel_scale='slaney', + norm='slaney', + ) + + def worker_init_fn(_seed): + _seed = _seed * 256 + def _worker_init_fn(worker_id): + print(f"worker_init_fn with {_seed + worker_id}") + np.random.seed(_seed + worker_id) + random.seed(_seed + worker_id) + return _worker_init_fn + + if args.enable_bucket: + aspect_ratio_sample_size = {key : [x / 512 * args.video_sample_size for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} + batch_sampler_generator = torch.Generator().manual_seed(args.seed) + batch_sampler = AspectRatioBatchImageVideoSampler( + sampler=RandomSampler(train_dataset, generator=batch_sampler_generator), dataset=train_dataset.dataset, + batch_size=args.train_batch_size, train_folder = args.train_data_dir, drop_last=True, + aspect_ratios=aspect_ratio_sample_size, + ) + + def collate_fn(examples): + def get_length_to_frame_num(token_length): + if args.video_sample_size > 256: + sample_sizes = list(range(256, args.video_sample_size + 1, 128)) + + if sample_sizes[-1] != args.video_sample_size: + 
sample_sizes.append(args.video_sample_size) + else: + sample_sizes = [args.video_sample_size] + + length_to_frame_num = { + sample_size: min(token_length / sample_size / sample_size, args.video_sample_n_frames) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1 for sample_size in sample_sizes + } + + return length_to_frame_num + + def get_random_downsample_ratio(sample_size, image_ratio=[], + all_choices=False, rng=None): + def _create_special_list(length): + if length == 1: + return [1.0] + if length >= 2: + first_element = 0.90 + remaining_sum = 1.0 - first_element + other_elements_value = remaining_sum / (length - 1) + special_list = [first_element] + [other_elements_value] * (length - 1) + return special_list + + if sample_size >= 1536: + number_list = [1, 1.25, 1.5, 2, 2.5, 3] + image_ratio + elif sample_size >= 1024: + number_list = [1, 1.25, 1.5, 2] + image_ratio + elif sample_size >= 768: + number_list = [1, 1.25, 1.5] + image_ratio + elif sample_size >= 512: + number_list = [1] + image_ratio + else: + number_list = [1] + + if all_choices: + return number_list + + number_list_prob = np.array(_create_special_list(len(number_list))) + if rng is None: + return np.random.choice(number_list, p = number_list_prob) + else: + return rng.choice(number_list, p = number_list_prob) + + # Get token length + target_token_length = args.video_sample_n_frames * args.token_sample_size * args.token_sample_size + length_to_frame_num = get_length_to_frame_num(target_token_length) + + # Create new output + new_examples = {} + new_examples["target_token_length"] = target_token_length + new_examples["pixel_values"] = [] + new_examples["text"] = [] + new_examples["audio"] = [] + new_examples["fps"] = [] + + # Used in Inpaint mode + new_examples["mask_pixel_values"] = [] + new_examples["mask"] = [] + new_examples["clip_pixel_values"] = [] + + # Get downsample ratio in image and videos + pixel_value = examples[0]["pixel_values"] + f, h, w, c = np.shape(pixel_value) + + if args.random_hw_adapt: + if args.training_with_video_token_length: + local_min_size = np.min(np.array([np.mean(np.array([np.shape(example["pixel_values"])[1], np.shape(example["pixel_values"])[2]])) for example in examples])) + # The video will be resized to a lower resolution than its own. 
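+                    # Pick a bucket resolution no larger than ~1.25x the smallest clip in the
+                    # batch, then trade resolution against frame count via length_to_frame_num
+                    # so the total token budget (frames * size * size) stays close to
+                    # target_token_length; random_downsample_ratio records how much the
+                    # nominal sample size was reduced.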
+ choice_list = [length for length in list(length_to_frame_num.keys()) if length < local_min_size * 1.25] + if len(choice_list) == 0: + choice_list = list(length_to_frame_num.keys()) + local_video_sample_size = np.random.choice(choice_list) + batch_video_length = length_to_frame_num[local_video_sample_size] + random_downsample_ratio = args.video_sample_size / local_video_sample_size + else: + random_downsample_ratio = get_random_downsample_ratio(args.video_sample_size) + batch_video_length = args.video_sample_n_frames + sample_n_frames_bucket_interval + else: + random_downsample_ratio = 1 + batch_video_length = args.video_sample_n_frames + sample_n_frames_bucket_interval + + aspect_ratio_sample_size = {key : [x / 512 * args.video_sample_size / random_downsample_ratio for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} + aspect_ratio_random_crop_sample_size = {key : [x / 512 * args.video_sample_size / random_downsample_ratio for x in ASPECT_RATIO_RANDOM_CROP_512[key]] for key in ASPECT_RATIO_RANDOM_CROP_512.keys()} + + if args.fix_sample_size is not None: + fix_sample_size = [int(x / 16) * 16 for x in args.fix_sample_size] + elif args.random_ratio_crop: + if rng is None: + random_sample_size = aspect_ratio_random_crop_sample_size[ + np.random.choice(list(aspect_ratio_random_crop_sample_size.keys()), p = ASPECT_RATIO_RANDOM_CROP_PROB) + ] + else: + random_sample_size = aspect_ratio_random_crop_sample_size[ + rng.choice(list(aspect_ratio_random_crop_sample_size.keys()), p = ASPECT_RATIO_RANDOM_CROP_PROB) + ] + random_sample_size = [int(x / 64) * 64 for x in random_sample_size] + else: + closest_size, closest_ratio = get_closest_ratio(h, w, ratios=aspect_ratio_sample_size) + closest_size = [int(x / 64) * 64 for x in closest_size] + + min_example_length = min( + [example["pixel_values"].shape[0] for example in examples] + ) + batch_video_length = int(min(batch_video_length, min_example_length)) + + # Magvae needs the number of frames to be 4n + 1. + batch_video_length = (batch_video_length - 1) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1 + + if batch_video_length <= 0: + batch_video_length = 1 + + for example in examples: + if args.fix_sample_size is not None: + # To 0~1 + pixel_values = torch.from_numpy(example["pixel_values"]).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. + + # Get adapt hw for resize + fix_sample_size = list(map(lambda x: int(x), fix_sample_size)) + transform = transforms.Compose([ + transforms.Resize(fix_sample_size, interpolation=transforms.InterpolationMode.BILINEAR), # Image.BICUBIC + transforms.CenterCrop(fix_sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + elif args.random_ratio_crop: + # To 0~1 + pixel_values = torch.from_numpy(example["pixel_values"]).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. + + # Get adapt hw for resize + b, c, h, w = pixel_values.size() + th, tw = random_sample_size + if th / tw > h / w: + nh = int(th) + nw = int(w / h * nh) + else: + nw = int(tw) + nh = int(h / w * nw) + + transform = transforms.Compose([ + transforms.Resize([nh, nw]), + transforms.CenterCrop([int(x) for x in random_sample_size]), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + else: + # To 0~1 + pixel_values = torch.from_numpy(example["pixel_values"]).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. 
+ + # Get adapt hw for resize + closest_size = list(map(lambda x: int(x), closest_size)) + if closest_size[0] / h > closest_size[1] / w: + resize_size = closest_size[0], int(w * closest_size[0] / h) + else: + resize_size = int(h * closest_size[1] / w), closest_size[1] + + transform = transforms.Compose([ + transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.BILINEAR), # Image.BICUBIC + transforms.CenterCrop(closest_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + new_examples["pixel_values"].append(transform(pixel_values)[:batch_video_length]) + new_examples["text"].append(example["text"]) + + audio_length = np.shape(example["audio"])[0] + batch_audio_length = int(audio_length / pixel_values.size()[0] * batch_video_length) + new_examples["audio"].append(example["audio"][:batch_audio_length]) + new_examples["fps"].append(example.get("fps", 24)) + + mask = get_random_mask(new_examples["pixel_values"][-1].size(), image_start_only=True) + mask_pixel_values = new_examples["pixel_values"][-1] * (1 - mask) + # Wan 2.1 use 0 for masked pixels + # + torch.ones_like(new_examples["pixel_values"][-1]) * -1 * mask + new_examples["mask_pixel_values"].append(mask_pixel_values) + new_examples["mask"].append(mask) + + clip_pixel_values = new_examples["pixel_values"][-1][0].permute(1, 2, 0).contiguous() + clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255 + new_examples["clip_pixel_values"].append(clip_pixel_values) + + # Limit the number of frames to the same + new_examples["pixel_values"] = torch.stack([example for example in new_examples["pixel_values"]]) + new_examples["mask_pixel_values"] = torch.stack([example for example in new_examples["mask_pixel_values"]]) + new_examples["mask"] = torch.stack([example for example in new_examples["mask"]]) + new_examples["clip_pixel_values"] = torch.stack([example for example in new_examples["clip_pixel_values"]]) + + # Pad audio to same length and stack + new_examples["audio"] = torch.stack([example for example in new_examples["audio"]]) + new_examples["fps"] = new_examples["fps"] + + # Encode prompts when enable_text_encoder_in_dataloader=True + if args.enable_text_encoder_in_dataloader: + # Use processor for LTX-2.3 prompt encoding (same as pipeline) + if processor is not None: + prompt_ids = processor( + text=new_examples['text'], + max_length=args.tokenizer_max_length, + padding="max_length", + truncation=True, + return_tensors="pt" + ) + else: + # Fallback to tokenizer + tokenizer.padding_side = "left" + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + prompt_ids = tokenizer( + new_examples['text'], + max_length=args.tokenizer_max_length, + padding="max_length", + add_special_tokens=True, + truncation=True, + return_tensors="pt" + ) + text_encoder_outputs = text_encoder( + input_ids=prompt_ids.input_ids, + attention_mask=prompt_ids.attention_mask, + output_hidden_states=True + ) + text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) + + # Pack text embeddings to 3D tensor (flatten last two dims) + prompt_embeds = text_encoder_hidden_states.flatten(2) + new_examples['encoder_attention_mask'] = prompt_ids.attention_mask + new_examples['encoder_hidden_states'] = prompt_embeds + + return new_examples + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn, + 
persistent_workers=True if args.dataloader_num_workers != 0 else False, + num_workers=args.dataloader_num_workers, + worker_init_fn=worker_init_fn(args.seed + accelerator.process_index) + ) + else: + # DataLoaders creation: + batch_sampler_generator = torch.Generator().manual_seed(args.seed) + batch_sampler = ImageVideoSampler(RandomSampler(train_dataset, generator=batch_sampler_generator), train_dataset, args.train_batch_size) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + persistent_workers=True if args.dataloader_num_workers != 0 else False, + num_workers=args.dataloader_num_workers, + worker_init_fn=worker_init_fn(args.seed + accelerator.process_index) + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + + # Prepare everything with our `accelerator`. + if args.use_peft_lora: + transformer3d, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer3d, optimizer, train_dataloader, lr_scheduler + ) + else: + transformer3d.network = network + transformer3d = transformer3d.to(dtype=weight_dtype) + transformer3d, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer3d, optimizer, train_dataloader, lr_scheduler + ) + + if fsdp_stage != 0 or zero_stage != 0: + from functools import partial + + from packaging.version import parse as parse_version + + from videox_fun.dist import set_multi_gpus_devices, shard_model + + if parse_version(transformers.__version__) <= parse_version("4.51.3"): + shard_fn = partial(shard_model, device_id=accelerator.device, param_dtype=weight_dtype, module_to_wrapper=text_encoder.language_model.model.layers) + else: + shard_fn = partial(shard_model, device_id=accelerator.device, param_dtype=weight_dtype, module_to_wrapper=text_encoder.language_model.layers) + text_encoder = shard_fn(text_encoder) + + # Move text_encode and vae to gpu and cast to weight_dtype + vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + audio_vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + vocoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + connectors.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + transformer3d.to(accelerator.device, dtype=weight_dtype) + if not args.enable_text_encoder_in_dataloader: + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. 
+ # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + keys_to_pop = [k for k, v in tracker_config.items() if isinstance(v, list)] + for k in keys_to_pop: + tracker_config.pop(k) + print(f"Removed tracker_config['{k}']") + accelerator.init_trackers(args.tracker_project_name, tracker_config) + + # Function for unwrapping if model was compiled with `torch.compile`. + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + + checkpoint_folder_path = os.path.join(args.output_dir, path) + pkl_path = os.path.join(checkpoint_folder_path, "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + _, first_epoch = pickle.load(file) + else: + first_epoch = global_step // num_update_steps_per_epoch + print(f"Load pkl from {pkl_path}. 
Get first_epoch = {first_epoch}.") + + if zero_stage != 3 and not args.use_fsdp: + from safetensors.torch import load_file + state_dict = load_file(os.path.join(checkpoint_folder_path, "lora_diffusion_pytorch_model.safetensors"), device=str(accelerator.device)) + m, u = accelerator.unwrap_model(network).load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + + optimizer_file_pt = os.path.join(checkpoint_folder_path, "optimizer.pt") + optimizer_file_bin = os.path.join(checkpoint_folder_path, "optimizer.bin") + optimizer_file_to_load = None + + if os.path.exists(optimizer_file_pt): + optimizer_file_to_load = optimizer_file_pt + elif os.path.exists(optimizer_file_bin): + optimizer_file_to_load = optimizer_file_bin + + if optimizer_file_to_load: + try: + accelerator.print(f"Loading optimizer state from {optimizer_file_to_load}") + optimizer_state = torch.load(optimizer_file_to_load, map_location=accelerator.device) + optimizer.load_state_dict(optimizer_state) + accelerator.print("Optimizer state loaded successfully.") + except Exception as e: + accelerator.print(f"Failed to load optimizer state from {optimizer_file_to_load}: {e}") + + scheduler_file_pt = os.path.join(checkpoint_folder_path, "scheduler.pt") + scheduler_file_bin = os.path.join(checkpoint_folder_path, "scheduler.bin") + scheduler_file_to_load = None + + if os.path.exists(scheduler_file_pt): + scheduler_file_to_load = scheduler_file_pt + elif os.path.exists(scheduler_file_bin): + scheduler_file_to_load = scheduler_file_bin + + if scheduler_file_to_load: + try: + accelerator.print(f"Loading scheduler state from {scheduler_file_to_load}") + scheduler_state = torch.load(scheduler_file_to_load, map_location=accelerator.device) + lr_scheduler.load_state_dict(scheduler_state) + accelerator.print("Scheduler state loaded successfully.") + except Exception as e: + accelerator.print(f"Failed to load scheduler state from {scheduler_file_to_load}: {e}") + + if hasattr(accelerator, 'scaler') and accelerator.scaler is not None: + scaler_file = os.path.join(checkpoint_folder_path, "scaler.pt") + if os.path.exists(scaler_file): + try: + accelerator.print(f"Loading GradScaler state from {scaler_file}") + scaler_state = torch.load(scaler_file, map_location=accelerator.device) + accelerator.scaler.load_state_dict(scaler_state) + accelerator.print("GradScaler state loaded successfully.") + except Exception as e: + accelerator.print(f"Failed to load GradScaler state: {e}") + + else: + accelerator.load_state(checkpoint_folder_path) + accelerator.print("accelerator.load_state() completed for zero_stage 3.") + + else: + initial_global_step = 0 + + # function for saving/removing + def save_model(ckpt_file, unwrapped_nw): + os.makedirs(args.output_dir, exist_ok=True) + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + if isinstance(unwrapped_nw, dict): + from safetensors.torch import save_file + save_file(unwrapped_nw, ckpt_file, metadata={"format": "pt"}) + return ckpt_file + unwrapped_nw.save_weights(ckpt_file, weight_dtype, None) + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. 
+ disable=not accelerator.is_local_main_process, + ) + + if args.multi_stream: + # create extra cuda streams to speedup vae computation + vae_stream_1 = torch.cuda.Stream() + else: + vae_stream_1 = None + + idx_sampling = DiscreteSampling(args.train_sampling_steps, uniform_sampling=args.uniform_sampling) + + for epoch in range(first_epoch, args.num_train_epochs): + train_loss = 0.0 + batch_sampler.sampler.generator = torch.Generator().manual_seed(args.seed + epoch) + for step, batch in enumerate(train_dataloader): + # Data batch sanity check + if epoch == first_epoch and step == 0: + pixel_values, texts = batch['pixel_values'].cpu(), batch['text'] + pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w") + os.makedirs(os.path.join(args.output_dir, "sanity_check"), exist_ok=True) + for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)): + pixel_value = pixel_value[None, ...] + gif_name = '-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_step}-{idx}' + save_videos_grid(pixel_value, f"{args.output_dir}/sanity_check/{gif_name[:10]}.gif", rescale=True) + + with accelerator.accumulate(transformer3d): + # Convert images to latent space + pixel_values = batch["pixel_values"].to(weight_dtype) + audio = batch["audio"] + fps = batch["fps"][0] if batch["fps"] else 24 # Use fps from dataset + + # Increase the batch size when the length of the latent sequence of the current sample is small + if args.auto_tile_batch_size and args.training_with_video_token_length and zero_stage != 3: + if args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 16 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]: + pixel_values = torch.tile(pixel_values, (4, 1, 1, 1, 1)) + if args.enable_text_encoder_in_dataloader: + batch['encoder_hidden_states'] = torch.tile(batch['encoder_hidden_states'], (4, 1, 1)) + batch['encoder_attention_mask'] = torch.tile(batch['encoder_attention_mask'], (4, 1)) + else: + batch['text'] = batch['text'] * 4 + elif args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 4 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]: + pixel_values = torch.tile(pixel_values, (2, 1, 1, 1, 1)) + if args.enable_text_encoder_in_dataloader: + batch['encoder_hidden_states'] = torch.tile(batch['encoder_hidden_states'], (2, 1, 1)) + batch['encoder_attention_mask'] = torch.tile(batch['encoder_attention_mask'], (2, 1)) + else: + batch['text'] = batch['text'] * 2 + + if args.random_frame_crop: + def _create_special_list(length): + if length == 1: + return [1.0] + if length >= 2: + last_element = 0.90 + remaining_sum = 1.0 - last_element + other_elements_value = remaining_sum / (length - 1) + special_list = [other_elements_value] * (length - 1) + [last_element] + return special_list + select_frames = [_tmp for _tmp in list(range(sample_n_frames_bucket_interval + 1, args.video_sample_n_frames + sample_n_frames_bucket_interval, sample_n_frames_bucket_interval))] + select_frames_prob = np.array(_create_special_list(len(select_frames))) + + if len(select_frames) != 0: + if rng is None: + temp_n_frames = np.random.choice(select_frames, p = select_frames_prob) + else: + temp_n_frames = rng.choice(select_frames, p = select_frames_prob) + else: + temp_n_frames = 1 + + # Magvae needs the number of frames to be 4n + 1. 
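+                        # Worked example (assuming sample_n_frames_bucket_interval == 4 and
+                        # video_sample_n_frames == 121): temp_n_frames is drawn from {5, 9, ..., 121}
+                        # and the line below maps it to {2, 3, ..., 31}.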
+ temp_n_frames = (temp_n_frames - 1) // sample_n_frames_bucket_interval + 1 + + pixel_values = pixel_values[:, :temp_n_frames, :, :] + + # Keep all node same token length to accelerate the traning when resolution grows. + if args.keep_all_node_same_token_length: + if args.token_sample_size > 256: + numbers_list = list(range(256, args.token_sample_size + 1, 128)) + + if numbers_list[-1] != args.token_sample_size: + numbers_list.append(args.token_sample_size) + else: + numbers_list = [256] + numbers_list = [_number * _number * args.video_sample_n_frames for _number in numbers_list] + + actual_token_length = index_rng.choice(numbers_list) + actual_video_length = (min( + actual_token_length / pixel_values.size()[-1] / pixel_values.size()[-2], args.video_sample_n_frames + ) - 1) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1 + actual_video_length = int(max(actual_video_length, 1)) + + # Magvae needs the number of frames to be 4n + 1. + actual_video_length = (actual_video_length - 1) // sample_n_frames_bucket_interval + 1 + + pixel_values = pixel_values[:, :actual_video_length, :, :] + + if args.low_vram: + torch.cuda.empty_cache() + vae.to(accelerator.device) + audio_vae.to(accelerator.device) + if not args.enable_text_encoder_in_dataloader: + text_encoder.to("cpu") + + with torch.no_grad(): + # This way is quicker when batch grows up + def _batch_encode_vae(pixel_values): + pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w") + bs = args.vae_mini_batch + new_pixel_values = [] + for i in range(0, pixel_values.shape[0], bs): + pixel_values_bs = pixel_values[i : i + bs] + pixel_values_bs = vae.encode(pixel_values_bs)[0] + pixel_values_bs = pixel_values_bs.sample() + new_pixel_values.append(pixel_values_bs) + return torch.cat(new_pixel_values, dim = 0) + if vae_stream_1 is not None: + vae_stream_1.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(vae_stream_1): + latents = _batch_encode_vae(pixel_values) + else: + latents = _batch_encode_vae(pixel_values) + + # wait for latents = vae.encode(pixel_values) to complete + if vae_stream_1 is not None: + torch.cuda.current_stream().wait_stream(vae_stream_1) + + # Get latent dimensions from VAE output for later use + bsz, _, latent_num_frames, latent_height, latent_width = latents.size() + + # Encode audio to latents + with torch.no_grad(): + audio_batch = audio.to(device=accelerator.device, dtype=torch.float32) + # audio_batch shape: [batch, channels, samples] or [batch, samples] + if audio_batch.ndim == 2: + audio_batch = audio_batch.unsqueeze(1) # [batch, 1, samples] + audio_batch = audio_batch.repeat(1, 2, 1) if audio_batch.dim() == 3 else audio_batch.repeat(2, 1) + + # Convert audio waveform to log-mel spectrogram (following official LTX-2 AudioProcessor) + # mel_transform input: [batch, channels, samples] -> output: [batch, channels, n_mels, time] + mel_spec = mel_transform.to(accelerator.device)(audio_batch) + mel_spec = torch.log(mel_spec.clamp(min=1e-5)) + mel_spectrogram = mel_spec.permute(0, 1, 3, 2).contiguous() # [batch, channels, time, n_mels] + + # Ensure mel spectrogram has the correct number of channels + if mel_spectrogram.shape[1] < audio_in_channels: + mel_spectrogram = mel_spectrogram.repeat(1, audio_in_channels, 1, 1) + elif mel_spectrogram.shape[1] > audio_in_channels: + mel_spectrogram = mel_spectrogram[:, :audio_in_channels, :, :] + + # Encode mel spectrogram to latents using audio_vae + mel_spectrogram = mel_spectrogram.to(dtype=weight_dtype) + audio_encoder_output = 
audio_vae.encode(mel_spectrogram) + audio_latents_raw = audio_encoder_output.latent_dist.sample() + # audio_latents_raw shape: [batch, latent_channels, latent_time, latent_mel] + + # Get the actual audio_num_frames from encoded latents + audio_num_frames = audio_latents_raw.shape[2] + + # Pack audio latents FIRST, then normalize + # This is the correct order as per pipeline implementation + audio_latents = _pack_audio_latents(audio_latents_raw) + # audio_latents shape: [batch, latent_time, latent_channels * latent_mel] + + # Normalize audio latents (after packing) + audio_latents = _normalize_audio_latents( + audio_latents, audio_vae.latents_mean, audio_vae.latents_std + ) + + if args.low_vram: + vae.to('cpu') + audio_vae.to('cpu') + torch.cuda.empty_cache() + if not args.enable_text_encoder_in_dataloader: + text_encoder.to(accelerator.device) + connectors.to(accelerator.device) + + if args.enable_text_encoder_in_dataloader: + prompt_embeds = batch['encoder_hidden_states'].to(device=latents.device, dtype=weight_dtype) + prompt_attention_mask = batch['encoder_attention_mask'].to(device=latents.device) + else: + with torch.no_grad(): + # Use processor for LTX-2.3 prompt encoding (same as pipeline) + if processor is not None: + prompt_ids = processor( + text=batch['text'], + padding="max_length", + max_length=args.tokenizer_max_length, + truncation=True, + return_tensors="pt" + ) + else: + # Fallback to tokenizer + tokenizer.padding_side = "left" + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + prompt_ids = tokenizer( + batch['text'], + padding="max_length", + max_length=args.tokenizer_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt" + ) + text_input_ids = prompt_ids.input_ids.to(latents.device) + prompt_attention_mask = prompt_ids.attention_mask.to(latents.device) + + # Get text encoder hidden states + text_encoder_outputs = text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True + ) + text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) + + # Pack text embeddings to 3D tensor (flatten last two dims) + prompt_embeds = text_encoder_hidden_states.flatten(2) # [B, seq_len, hidden_dim * num_layers] + prompt_embeds = prompt_embeds.to(dtype=weight_dtype) + + # Use connectors to process prompt embeddings + with torch.no_grad(): + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = connectors( + prompt_embeds, prompt_attention_mask, padding_side=tokenizer.padding_side + ) + + if args.low_vram and not args.enable_text_encoder_in_dataloader: + text_encoder.to('cpu') + connectors.to('cpu') + torch.cuda.empty_cache() + + noise = torch.randn(latents.size(), device=latents.device, generator=torch_rng, dtype=weight_dtype) + audio_noise = torch.randn(audio_latents.size(), device=latents.device, generator=torch_rng, dtype=weight_dtype) + + if not args.uniform_sampling: + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + else: + # Sample a random timestep for each image + indices = idx_sampling(bsz, generator=torch_rng, device=latents.device) + indices = indices.long().cpu() + timesteps = noise_scheduler.timesteps[indices].to(device=latents.device) + + def 
get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + # Get transformer config for patch sizes + patch_size = getattr(accelerator.unwrap_model(transformer3d).config, 'patch_size', 1) + patch_size_t = getattr(accelerator.unwrap_model(transformer3d).config, 'patch_size_t', 1) + + # ------------------ I2V Conditioning Mask ------------------ + # Create conditioning mask for I2V training + # conditioning_mask: 1 for condition frames (first frame), 0 for frames to generate + conditioning_mask = None + + if args.i2v_ratio > 0: + # Randomly select samples for I2V based on i2v_ratio + i2v_prob = torch.rand(bsz, generator=torch_rng, device=latents.device) + is_i2v_sample = i2v_prob < args.i2v_ratio + + if is_i2v_sample.any(): + # Create conditioning mask: [B, 1, F, H, W] + # First frame is condition (mask=1), rest are to generate (mask=0) + conditioning_mask = torch.zeros( + (bsz, 1, latent_num_frames, latent_height, latent_width), + device=latents.device, dtype=latents.dtype + ) + conditioning_mask[is_i2v_sample, :, 0, :, :] = 1.0 + + # ------------------ Video Latents ------------------ + # Normalize video latents + latents = _normalize_latents(latents, vae.latents_mean, vae.latents_std, vae.config.scaling_factor) + # Add noise according to flow matching + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype) + + if conditioning_mask is not None: + # I2V mode: first frame is condition, apply noise differently + # For condition frames: keep clean (or add slight noise if i2v_noise_scale > 0) + # For frames to generate: apply full noise + if args.i2v_noise_scale > 0: + # Add slight noise to condition frame + noisy_latents = latents * conditioning_mask * (1 - args.i2v_noise_scale) + \ + noise * conditioning_mask * args.i2v_noise_scale + \ + ((1.0 - sigmas) * latents + sigmas * noise) * (1 - conditioning_mask) + else: + # Keep condition frame clean + noisy_latents = latents * conditioning_mask + \ + ((1.0 - sigmas) * latents + sigmas * noise) * (1 - conditioning_mask) + else: + # T2V mode: all frames get noise + noisy_latents = (1.0 - sigmas) * latents + sigmas * noise + + target = noise - latents + noisy_latents_packed = _pack_latents(noisy_latents, patch_size, patch_size_t) + + # Pack conditioning mask if present + if conditioning_mask is not None: + conditioning_mask_packed = _pack_latents(conditioning_mask, patch_size, patch_size_t).squeeze(-1) + else: + conditioning_mask_packed = None + + # ------------------ Audio Latents ------------------ + # Add noise to audio latents for training (flow matching) + audio_sigmas = get_sigmas(timesteps, n_dim=audio_latents.ndim, dtype=audio_latents.dtype) + noisy_audio_latents = (1.0 - audio_sigmas) * audio_latents + audio_sigmas * audio_noise + audio_target = audio_noise - audio_latents + + # -------- Timesteps Process and RoPE Process -------- + # Prepare timestep + # For T2V: use batch-level timestep (same as inference) + # For I2V: condition frames have timestep=0, generate frames have normal timestep + if conditioning_mask_packed is not None: + # I2V mode: video_timestep has shape [B, S] where S is sequence 
length + # condition frames (mask=1) get timestep 0, generate frames get normal timestep + video_timestep = timesteps.unsqueeze(-1) * (1 - conditioning_mask_packed) + else: + # T2V mode: use batch-level timestep + video_timestep = timesteps + audio_timestep = timesteps + + # Prepare RoPE coordinates + video_coords = accelerator.unwrap_model(transformer3d).rope.prepare_video_coords( + bsz, latent_num_frames, latent_height, latent_width, latents.device, fps=fps + ) + audio_coords = accelerator.unwrap_model(transformer3d).audio_rope.prepare_audio_coords( + bsz, audio_num_frames, audio_latents.device + ) + + # -------- Forward -------- + # Predict the noise residual + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=accelerator.device): + noise_pred_video, noise_pred_audio = transformer3d( + hidden_states=noisy_latents_packed, + audio_hidden_states=noisy_audio_latents, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=audio_timestep, + sigma=timesteps, # LTX-2.3 uses sigma + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=fps, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, + isolate_modalities=False, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + use_cross_timestep=False, + return_dict=False, + ) + + # Unpack predictions for loss computation + noise_pred = _unpack_latents( + noise_pred_video, + latent_num_frames, + latent_height, + latent_width, + patch_size, + patch_size_t, + ) + + def custom_mse_loss(noise_pred, target, weighting=None, threshold=50): + noise_pred = noise_pred.float() + target = target.float() + diff = noise_pred - target + mse_loss = F.mse_loss(noise_pred, target, reduction='none') + mask = (diff.abs() <= threshold).float() + masked_loss = mse_loss * mask + if weighting is not None: + masked_loss = masked_loss * weighting + final_loss = masked_loss.mean() + return final_loss + + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # Video loss + video_loss = custom_mse_loss(noise_pred.float(), target.float(), weighting.float()) + + if args.motion_sub_loss and noise_pred.size()[2] > 2: + gt_sub_noise = noise_pred[:, :, 1:].float() - noise_pred[:, :, :-1].float() + pre_sub_noise = target[:, :, 1:].float() - target[:, :, :-1].float() + sub_loss = F.mse_loss(gt_sub_noise, pre_sub_noise, reduction="mean") + video_loss = video_loss * (1 - args.motion_sub_loss_ratio) + sub_loss * args.motion_sub_loss_ratio + + # Audio loss + audio_weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=audio_sigmas) + audio_loss = F.mse_loss(noise_pred_audio.float(), audio_target.float(), reduction='none') + if audio_weighting is not None: + # Expand weighting to match audio shape + while audio_weighting.ndim < audio_loss.ndim: + audio_weighting = audio_weighting.unsqueeze(-1) + audio_loss = audio_loss * audio_weighting + audio_loss = audio_loss.mean() + + # Combined loss (equal weighting for video and audio) + loss = 0.5 * video_loss + 0.5 * audio_loss + + # Gather the losses across all processes for logging (if we use distributed training). 
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(trainable_params, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if args.use_deepspeed or args.use_fsdp or accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + if not args.save_state: + if args.use_peft_lora: + safetensor_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.safetensors") + network_state_dict = get_peft_model_state_dict(accelerator.unwrap_model(transformer3d)) + save_model(safetensor_save_path, network_state_dict) + + safetensor_kohya_format_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}-compatible_with_comfyui.safetensors") + network_state_dict_kohya = convert_peft_lora_to_kohya_lora(network_state_dict) + save_model(safetensor_kohya_format_save_path, network_state_dict_kohya) + logger.info(f"Saved safetensor to {safetensor_save_path}") + else: + safetensor_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.safetensors") + save_model(safetensor_save_path, accelerator.unwrap_model(network)) + logger.info(f"Saved safetensor to {safetensor_save_path}") + else: + accelerator_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(accelerator_save_path) + logger.info(f"Saved state to {accelerator_save_path}") + + if args.validation_prompts is not None and global_step % args.validation_steps == 0: + log_validation( + vae, + audio_vae, + text_encoder, + tokenizer, + processor, + connectors, + vocoder, + transformer3d, + network, + args, + accelerator, + weight_dtype, + global_step, + ) + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.validation_prompts is not None and epoch % args.validation_epochs == 0: + log_validation( + vae, + audio_vae, + text_encoder, + tokenizer, + processor, + connectors, + vocoder, + transformer3d, + network, + args, + accelerator, + weight_dtype, 
+ global_step, + ) + + # Create the pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + if args.use_deepspeed or args.use_fsdp or accelerator.is_main_process: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + if not args.save_state: + if args.use_peft_lora: + safetensor_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.safetensors") + network_state_dict = get_peft_model_state_dict(accelerator.unwrap_model(transformer3d)) + save_model(safetensor_save_path, network_state_dict) + + safetensor_kohya_format_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}-compatible_with_comfyui.safetensors") + network_state_dict_kohya = convert_peft_lora_to_kohya_lora(network_state_dict) + save_model(safetensor_kohya_format_save_path, network_state_dict_kohya) + logger.info(f"Saved safetensor to {safetensor_save_path}") + else: + safetensor_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.safetensors") + save_model(safetensor_save_path, accelerator.unwrap_model(network)) + logger.info(f"Saved safetensor to {safetensor_save_path}") + else: + accelerator_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(accelerator_save_path) + logger.info(f"Saved state to {accelerator_save_path}") + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/scripts/ltx2.3/train_lora.sh b/scripts/ltx2.3/train_lora.sh new file mode 100644 index 00000000..6f1a3b92 --- /dev/null +++ b/scripts/ltx2.3/train_lora.sh @@ -0,0 +1,41 @@ +export MODEL_NAME="models/Diffusion_Transformer/LTX-2.3-Diffusers" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata_control.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
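+# (For reference: NCCL_IB_DISABLE=1 disables the InfiniBand/RoCE transport, so inter-node traffic falls
+#  back to TCP sockets, and NCCL_P2P_DISABLE=1 disables direct GPU peer-to-peer copies within a node.)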
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" scripts/ltx2/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --image_sample_size=640 \ + --video_sample_size=640 \ + --token_sample_size=640 \ + --video_sample_stride=1 \ + --video_sample_n_frames=121 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_ltx2.3_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --rank=64 \ + --network_alpha=32 \ + --target_name="to_q,to_k,to_v,ff.0,ff.2,audio_ff.0,audio_ff.2" \ + --use_peft_lora \ + --low_vram \ No newline at end of file diff --git a/scripts/ltx2/train.py b/scripts/ltx2/train.py index ca9e8fcf..06d0d3b2 100644 --- a/scripts/ltx2/train.py +++ b/scripts/ltx2/train.py @@ -102,50 +102,7 @@ def generate_timestep_with_lognorm(low, high, shape, device="cpu", generator=Non t = 1 / (1 + torch.exp(-u)) * (high - low) + low return torch.clip(t.to(torch.int32), low, high - 1) -# LTX2 helper functions for packing text embeddings and latents -def _pack_text_embeds( - text_hidden_states: torch.Tensor, - sequence_lengths: torch.Tensor, - device, - padding_side: str = "left", - scale_factor: int = 8, - eps: float = 1e-6, -) -> torch.Tensor: - """Packs and normalizes text encoder hidden states, respecting padding.""" - batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape - original_dtype = text_hidden_states.dtype - - # Create padding mask - token_indices = torch.arange(seq_len, device=device).unsqueeze(0) - if padding_side == "right": - mask = token_indices < sequence_lengths[:, None] - elif padding_side == "left": - start_indices = seq_len - sequence_lengths[:, None] - mask = token_indices >= start_indices - else: - raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") - mask = mask[:, :, None, None] - - # Compute masked mean - masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) - num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) - masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) - - # Compute min/max - x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) - x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) - - # Normalization - normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) - normalized_hidden_states = normalized_hidden_states * scale_factor - - # Pack the hidden states to 3D tensor - normalized_hidden_states = normalized_hidden_states.flatten(2) - mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) - normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) - normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) - return normalized_hidden_states - +# LTX2 helper functions for packing latents def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: """Packs latents [B, C, F, H, W] 
into token sequence [B, S, D].""" batch_size, num_channels, num_frames, height, width = latents.shape @@ -245,17 +202,19 @@ def log_validation(vae, audio_vae, text_encoder, tokenizer, connectors, vocoder, for i in range(len(args.validation_prompts)): output = pipeline( args.validation_prompts[i], - num_frames = args.video_sample_n_frames, - negative_prompt = "bad detailed", - height = args.video_sample_size, - width = args.video_sample_size, - generator = generator, + negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted, static, low quality, artifacts", + height = args.video_sample_size, + width = args.video_sample_size, + num_frames = args.video_sample_n_frames, + frame_rate = 24, num_inference_steps = 25, guidance_scale = 4.5, + generator = generator, ) sample = output.videos audio = output.audio os.makedirs(os.path.join(args.output_dir, "sample"), exist_ok=True) + sr = getattr(pipeline.vocoder.config, "output_sampling_rate", 24000) save_videos_with_audio_grid( sample, audio, @@ -264,7 +223,7 @@ def log_validation(vae, audio_vae, text_encoder, tokenizer, connectors, vocoder, f"sample/sample-{global_step}-rank{accelerator.process_index}-image-{i}.mp4" ), fps=24, - audio_sample_rate=24000, + audio_sample_rate=sr, ) del pipeline @@ -1407,15 +1366,8 @@ def _create_special_list(length): text_encoder_hidden_states = text_encoder_outputs.hidden_states text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) - # Pack text embeddings (normalized and flattened) - sequence_lengths = prompt_ids.attention_mask.sum(dim=-1) - prompt_embeds = _pack_text_embeds( - text_encoder_hidden_states, - sequence_lengths, - device=text_encoder_hidden_states.device, - padding_side=tokenizer.padding_side, - scale_factor=8, - ) + # Pack text embeddings (flatten to 3D, same as pipeline) + prompt_embeds = text_encoder_hidden_states.flatten(2, 3) new_examples['encoder_attention_mask'] = prompt_ids.attention_mask new_examples['encoder_hidden_states'] = prompt_embeds @@ -1767,22 +1719,13 @@ def _batch_encode_vae(pixel_values): text_encoder_hidden_states = text_encoder_outputs.hidden_states text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) - # Pack text embeddings (normalized and flattened) - sequence_lengths = prompt_attention_mask.sum(dim=-1) - prompt_embeds = _pack_text_embeds( - text_encoder_hidden_states, - sequence_lengths, - device=latents.device, - padding_side=tokenizer.padding_side, - scale_factor=8, - ) - prompt_embeds = prompt_embeds.to(dtype=weight_dtype) + # Pack text embeddings (flatten to 3D, same as pipeline) + prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=weight_dtype) # Use connectors to process prompt embeddings with torch.no_grad(): - additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.device, prompt_embeds.dtype)) * -1000000.0 connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = connectors( - prompt_embeds, additive_attention_mask, additive_mask=True + prompt_embeds, prompt_attention_mask, padding_side=tokenizer.padding_side ) if args.low_vram and not args.enable_text_encoder_in_dataloader: @@ -1903,17 +1846,28 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): audio_coords = accelerator.unwrap_model(transformer3d).audio_rope.prepare_audio_coords( bsz, audio_num_frames, audio_latents.device ) - # -------- Forward -------- # Predict the noise residual + # Expand timestep to match batch dimension (same as pipeline) + if video_timestep.ndim == 1: + 
video_timestep_expanded = video_timestep.expand(noisy_latents_packed.shape[0]) + else: + video_timestep_expanded = video_timestep + if audio_timestep.ndim == 1: + audio_timestep_expanded = audio_timestep.expand(noisy_audio_latents.shape[0]) + else: + audio_timestep_expanded = audio_timestep + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=accelerator.device): noise_pred_video, noise_pred_audio = transformer3d( hidden_states=noisy_latents_packed, audio_hidden_states=noisy_audio_latents, encoder_hidden_states=connector_prompt_embeds, audio_encoder_hidden_states=connector_audio_prompt_embeds, - timestep=video_timestep, - audio_timestep=audio_timestep, + timestep=video_timestep_expanded, + audio_timestep=audio_timestep_expanded, + sigma=video_timestep_expanded, # LTX-2.3 uses sigma for cross attention modulation + audio_sigma=audio_timestep_expanded, encoder_attention_mask=connector_attention_mask, audio_encoder_attention_mask=connector_attention_mask, num_frames=latent_num_frames, @@ -1923,6 +1877,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): audio_num_frames=audio_num_frames, video_coords=video_coords, audio_coords=audio_coords, + isolate_modalities=False, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + use_cross_timestep=False, + attention_kwargs=None, return_dict=False, ) diff --git a/scripts/ltx2/train.sh b/scripts/ltx2/train.sh index 5575af39..d65426ed 100644 --- a/scripts/ltx2/train.sh +++ b/scripts/ltx2/train.sh @@ -14,7 +14,7 @@ accelerate launch --mixed_precision="bf16" scripts/ltx2/train.py \ --video_sample_size=640 \ --token_sample_size=640 \ --video_sample_stride=1 \ - --video_sample_n_frames=81 \ + --video_sample_n_frames=121 \ --train_batch_size=1 \ --video_repeat=1 \ --gradient_accumulation_steps=1 \ diff --git a/scripts/ltx2/train_lora.py b/scripts/ltx2/train_lora.py index 90bb799b..e87817d9 100644 --- a/scripts/ltx2/train_lora.py +++ b/scripts/ltx2/train_lora.py @@ -103,50 +103,7 @@ def generate_timestep_with_lognorm(low, high, shape, device="cpu", generator=Non t = 1 / (1 + torch.exp(-u)) * (high - low) + low return torch.clip(t.to(torch.int32), low, high - 1) -# LTX2 helper functions for packing text embeddings and latents -def _pack_text_embeds( - text_hidden_states: torch.Tensor, - sequence_lengths: torch.Tensor, - device, - padding_side: str = "left", - scale_factor: int = 8, - eps: float = 1e-6, -) -> torch.Tensor: - """Packs and normalizes text encoder hidden states, respecting padding.""" - batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape - original_dtype = text_hidden_states.dtype - - # Create padding mask - token_indices = torch.arange(seq_len, device=device).unsqueeze(0) - if padding_side == "right": - mask = token_indices < sequence_lengths[:, None] - elif padding_side == "left": - start_indices = seq_len - sequence_lengths[:, None] - mask = token_indices >= start_indices - else: - raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") - mask = mask[:, :, None, None] - - # Compute masked mean - masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) - num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) - masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) - - # Compute min/max - x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) - x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) - - 
# Normalization - normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) - normalized_hidden_states = normalized_hidden_states * scale_factor - - # Pack the hidden states to 3D tensor - normalized_hidden_states = normalized_hidden_states.flatten(2) - mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) - normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) - normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) - return normalized_hidden_states - +# LTX2 helper functions for packing latents def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: """Packs latents [B, C, F, H, W] into token sequence [B, S, D].""" batch_size, num_channels, num_frames, height, width = latents.shape @@ -246,17 +203,19 @@ def log_validation(vae, audio_vae, text_encoder, tokenizer, connectors, vocoder, for i in range(len(args.validation_prompts)): output = pipeline( args.validation_prompts[i], - num_frames = args.video_sample_n_frames, - negative_prompt = "bad detailed", - height = args.video_sample_size, - width = args.video_sample_size, - generator = generator, + negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted, static, low quality, artifacts", + height = args.video_sample_size, + width = args.video_sample_size, + num_frames = args.video_sample_n_frames, + frame_rate = 24, num_inference_steps = 25, guidance_scale = 4.5, + generator = generator, ) sample = output.videos audio = output.audio os.makedirs(os.path.join(args.output_dir, "sample"), exist_ok=True) + sr = getattr(pipeline.vocoder.config, "output_sampling_rate", 24000) save_videos_with_audio_grid( sample, audio, @@ -265,7 +224,7 @@ def log_validation(vae, audio_vae, text_encoder, tokenizer, connectors, vocoder, f"sample/sample-{global_step}-rank{accelerator.process_index}-image-{i}.mp4" ), fps=24, - audio_sample_rate=24000, + audio_sample_rate=sr, ) del pipeline @@ -1386,15 +1345,8 @@ def _create_special_list(length): text_encoder_hidden_states = text_encoder_outputs.hidden_states text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) - # Pack text embeddings (normalized and flattened) - sequence_lengths = prompt_ids.attention_mask.sum(dim=-1) - prompt_embeds = _pack_text_embeds( - text_encoder_hidden_states, - sequence_lengths, - device=text_encoder_hidden_states.device, - padding_side=tokenizer.padding_side, - scale_factor=8, - ) + # Pack text embeddings (flatten to 3D, same as pipeline) + prompt_embeds = text_encoder_hidden_states.flatten(2, 3) new_examples['encoder_attention_mask'] = prompt_ids.attention_mask new_examples['encoder_hidden_states'] = prompt_embeds @@ -1817,22 +1769,13 @@ def _batch_encode_vae(pixel_values): text_encoder_hidden_states = text_encoder_outputs.hidden_states text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) - # Pack text embeddings (normalized and flattened) - sequence_lengths = prompt_attention_mask.sum(dim=-1) - prompt_embeds = _pack_text_embeds( - text_encoder_hidden_states, - sequence_lengths, - device=latents.device, - padding_side=tokenizer.padding_side, - scale_factor=8, - ) - prompt_embeds = prompt_embeds.to(dtype=weight_dtype) + # Pack text embeddings (flatten to 3D, same as pipeline) + prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=weight_dtype) # Use connectors to process prompt embeddings with torch.no_grad(): - additive_attention_mask = (1 - 
prompt_attention_mask.to(prompt_embeds.device, prompt_embeds.dtype)) * -1000000.0 connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = connectors( - prompt_embeds, additive_attention_mask, additive_mask=True + prompt_embeds, prompt_attention_mask, padding_side=tokenizer.padding_side ) if args.low_vram and not args.enable_text_encoder_in_dataloader: @@ -1954,14 +1897,26 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # -------- Forward -------- # Predict the noise residual + # Expand timestep to match batch dimension (same as pipeline) + if video_timestep.ndim == 1: + video_timestep_expanded = video_timestep.expand(noisy_latents_packed.shape[0]) + else: + video_timestep_expanded = video_timestep + if audio_timestep.ndim == 1: + audio_timestep_expanded = audio_timestep.expand(noisy_audio_latents.shape[0]) + else: + audio_timestep_expanded = audio_timestep + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=accelerator.device): noise_pred_video, noise_pred_audio = transformer3d( hidden_states=noisy_latents_packed, audio_hidden_states=noisy_audio_latents, encoder_hidden_states=connector_prompt_embeds, audio_encoder_hidden_states=connector_audio_prompt_embeds, - timestep=video_timestep, - audio_timestep=audio_timestep, + timestep=video_timestep_expanded, + audio_timestep=audio_timestep_expanded, + sigma=video_timestep_expanded, # LTX-2.3 uses sigma for cross attention modulation + audio_sigma=audio_timestep_expanded, encoder_attention_mask=connector_attention_mask, audio_encoder_attention_mask=connector_attention_mask, num_frames=latent_num_frames, @@ -1971,6 +1926,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): audio_num_frames=audio_num_frames, video_coords=video_coords, audio_coords=audio_coords, + isolate_modalities=False, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + use_cross_timestep=False, + attention_kwargs=None, return_dict=False, ) diff --git a/scripts/ltx2/train_lora.sh b/scripts/ltx2/train_lora.sh index b3f21e0f..125fa16a 100644 --- a/scripts/ltx2/train_lora.sh +++ b/scripts/ltx2/train_lora.sh @@ -14,7 +14,7 @@ accelerate launch --mixed_precision="bf16" scripts/ltx2/train_lora.py \ --video_sample_size=640 \ --token_sample_size=640 \ --video_sample_stride=1 \ - --video_sample_n_frames=81 \ + --video_sample_n_frames=121 \ --train_batch_size=1 \ --video_repeat=1 \ --gradient_accumulation_steps=1 \ diff --git a/scripts/process_json_add_width_and_height.py b/scripts/process_json_add_width_and_height.py new file mode 100644 index 00000000..68025a0e --- /dev/null +++ b/scripts/process_json_add_width_and_height.py @@ -0,0 +1,114 @@ +import json +import argparse +import multiprocessing as mp +from pathlib import Path +from PIL import Image +from decord import VideoReader, cpu + +# Supported file extensions +IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tiff', '.gif'} +VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm'} + + +def process_media_sample(sample): + """Extract width and height from image or video files.""" + try: + file_path = sample.get('file_path') + if not file_path: + return sample + + ext = Path(file_path).suffix.lower() + + if ext in IMAGE_EXTENSIONS: + # Extract dimensions from image + with Image.open(file_path) as img: + width, height = img.size + elif ext in VIDEO_EXTENSIONS: + # Extract dimensions from video using decord + vr = VideoReader(file_path, ctx=cpu(0)) + width, height = vr.width, vr.height + else: + print(f"Warning: 
Unsupported format '{ext}' for {file_path}") + return sample + + # Update sample with extracted dimensions + sample['height'] = height + sample['width'] = width + return sample + except Exception as e: + print(f"Error processing {sample.get('file_path')}: {e}") + return sample + + +def process_json_with_multiprocessing(input_json_path, output_json_path, num_processes=None): + # Load the JSON file + with open(input_json_path, 'r', encoding='utf-8') as f: + data = json.load(f) + + # Extract samples based on data structure + if isinstance(data, list): + samples = data + elif isinstance(data, dict) and 'samples' in data: + samples = data['samples'] + else: + samples = [data] + + # Set number of processes + if num_processes is None: + num_processes = mp.cpu_count() + + print(f"Starting processing {len(samples)} samples using {num_processes} processes...") + + # Process samples using multiprocessing + with mp.Pool(processes=num_processes) as pool: + processed_samples = pool.map(process_media_sample, samples) + + # Reconstruct output data preserving original structure + if isinstance(data, list): + output_data = processed_samples + elif isinstance(data, dict) and 'samples' in data: + output_data = data.copy() + output_data['samples'] = processed_samples + else: + output_data = processed_samples[0] + + print(f"Successfully processed {len(processed_samples)} samples.") + with open(output_json_path, 'w', encoding='utf-8') as f: + json.dump(output_data, f, ensure_ascii=False, indent=2) + + print(f"Processing complete! Results saved to {output_json_path}") + + +if __name__ == '__main__': + # Use 'spawn' start method to avoid potential deadlocks with decord/FFmpeg in multiprocessing + mp.set_start_method('spawn', force=True) + + parser = argparse.ArgumentParser( + description="Add width and height fields to image/video metadata in JSON files." + ) + parser.add_argument( + "--input_file", + type=str, + default="datasets/X-Fun-Images-Demo/metadata.json", + help="Path to the input JSON file." + ) + parser.add_argument( + "--output_file", + type=str, + default="datasets/X-Fun-Images-Demo/metadata_add_width_height.json", + help="Path to the output JSON file." + ) + parser.add_argument( + "--num_processes", + type=int, + default=None, + help="Number of parallel processes to use. Defaults to CPU core count." + ) + + args = parser.parse_args() + + process_json_with_multiprocessing( + input_json_path=args.input_file, + output_json_path=args.output_file, + num_processes=args.num_processes + ) diff --git a/scripts/qwenimage/README_TRAIN.md b/scripts/qwenimage/README_TRAIN.md index 6b2de6b9..5b8011dd 100755 --- a/scripts/qwenimage/README_TRAIN.md +++ b/scripts/qwenimage/README_TRAIN.md @@ -1,38 +1,180 @@ -## Training Code +# Qwen-Image Full Parameter Training Guide -We can choose whether to use deepspeed or fsdp in qwen-image, which can save a lot of video memory. +This document provides a complete workflow for full parameter training of Qwen-Image Diffusion Transformer, including environment configuration, data preparation, distributed training, and inference testing. -Some parameters in the sh file can be confusing, and they are explained in this document: +--- -- `enable_bucket` is used to enable bucket training. When enabled, the model does not crop the images at the center, but instead, it trains the entire images after grouping them into buckets based on resolution. -- `random_hw_adapt` is used to enable automatic height and width scaling for images. 
When `random_hw_adapt` is enabled, the training images will have their height and width set to `image_sample_size` as the maximum and `512` as the minimum. - - For example, when `random_hw_adapt` is enabled, `image_sample_size=1024`, the resolution of image inputs for training is `512x512` to `1024x1024` -- `resume_from_checkpoint` is used to set the training should be resumed from a previous checkpoint. Use a path or `"latest"` to automatically select the last available checkpoint. +## Table of Contents +- [1. Environment Configuration](#1-environment-configuration) +- [2. Data Preparation](#2-data-preparation) + - [2.1 Quick Test Dataset](#21-quick-test-dataset) + - [2.2 Dataset Structure](#22-dataset-structure) + - [2.3 metadata.json Format](#23-metadatajson-format) + - [2.4 Relative vs Absolute Path Usage](#24-relative-vs-absolute-path-usage) +- [3. Full Parameter Training](#3-full-parameter-training) + - [3.1 Download Pretrained Model](#31-download-pretrained-model) + - [3.2 Quick Start (DeepSpeed-Zero-2)](#32-quick-start-deepspeed-zero-2) + - [3.3 Common Training Parameters](#33-common-training-parameters) + - [3.4 Training with FSDP](#34-training-with-fsdp) + - [3.5 Other Backends](#35-other-backends) + - [3.6 Multi-Machine Distributed Training](#36-multi-machine-distributed-training) +- [4. Inference Testing](#4-inference-testing) + - [4.1 Inference Parameter Parsing](#41-inference-parameter-parsing) + - [4.2 Single GPU Inference](#42-single-gpu-inference) + - [4.3 Multi-GPU Parallel Inference](#43-multi-gpu-parallel-inference) +- [5. Additional Resources](#5-additional-resources) -When train model with multi machines, please set the params as follows: -```sh -export MASTER_ADDR="your master address" -export MASTER_PORT=10086 -export WORLD_SIZE=1 # The number of machines -export NUM_PROCESS=8 # The number of processes, such as WORLD_SIZE * 8 -export RANK=0 # The rank of this machine +--- + +## 1. Environment Configuration -accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK scripts/xxx/xxx.py +**Method 1: Using requirements.txt** + +```bash +pip install -r requirements.txt ``` -Without deepspeed: +**Method 2: Manual Dependency Installation** -Training qwen-image without DeepSpeed may result in insufficient GPU memory. 
-```sh -export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" +```bash +pip install Pillow einops safetensors timm tomesd librosa "torch>=2.1.2" torchdiffeq torchsde decord datasets numpy scikit-image +pip install omegaconf SentencePiece imageio[ffmpeg] imageio[pyav] tensorboard beautifulsoup4 ftfy func_timeout onnxruntime +pip install "peft>=0.17.0" "accelerate>=0.25.0" "gradio>=3.41.2" "diffusers>=0.30.1" "transformers>=4.46.2" +pip install yunchang xfuser modelscope openpyxl deepspeed==0.17.0 numpy==1.26.4 +pip uninstall opencv-python opencv-contrib-python opencv-python-headless -y +pip install opencv-python-headless +``` + +**Method 3: Using Docker** + +When using Docker, please ensure that the GPU driver and CUDA environment are correctly installed on your machine, then execute the following commands: + +``` +# pull image +docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun + +# enter image +docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun +``` + +--- + +## 2. Data Preparation + +### 2.1 Quick Test Dataset + +We provide a test dataset containing several training samples. + +```bash +# Download official demo dataset +modelscope download --dataset PAI/X-Fun-Images-Demo --local_dir ./datasets/X-Fun-Images-Demo +``` + +### 2.2 Dataset Structure + +``` +📦 datasets/ +├── 📂 my_dataset/ +│ ├── 📂 train/ +│ │ ├── 📄 image001.jpg +│ │ ├── 📄 image002.png +│ │ └── 📄 ... +│ └── 📄 metadata.json +``` + +### 2.3 metadata.json Format + +**Relative Path Format** (example): +```json +[ + { + "file_path": "train/image001.jpg", + "text": "A beautiful sunset over the ocean, golden hour lighting", + "width": 1024, + "height": 1024 + }, + { + "file_path": "train/image002.png", + "text": "Portrait of a young woman, studio lighting, high quality", + "width": 1328, + "height": 1328 + } +] +``` + +**Absolute Path Format**: +```json +[ + { + "file_path": "/mnt/data/images/sunset.jpg", + "text": "A beautiful sunset over the ocean", + "width": 1024, + "height": 1024 + } +] +``` + +**Key Fields Description**: +- `file_path`: Image path (relative or absolute) +- `text`: Image description (English prompt) +- `width` / `height`: Image dimensions (**recommended** to provide for bucket training; if not provided, they will be automatically read during training, which may slow down training when data is stored on slow systems like OSS) + - You can use `scripts/process_json_add_width_and_height.py` to add width and height fields to JSON files without these fields, supporting both images and videos + - Usage: `python scripts/process_json_add_width_and_height.py --input_file datasets/X-Fun-Images-Demo/metadata.json --output_file datasets/X-Fun-Images-Demo/metadata_add_width_height.json` + +### 2.4 Relative vs Absolute Path Usage + +**Relative Paths**: + +If your data uses relative paths, configure the training script as follows: + +```bash export DATASET_NAME="datasets/internal_datasets/" export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +``` + +**Absolute Paths**: + +If your data uses absolute paths, configure the training script as follows: + +```bash +export DATASET_NAME="" +export DATASET_META_NAME="/mnt/data/metadata.json" +``` + +> 💡 **Recommendation**: If the dataset is small and stored locally, use relative paths. 
If the dataset is stored on external storage (e.g., NAS, OSS) or shared across multiple machines, use absolute paths. + +--- + +## 3. Full Parameter Training + +### 3.1 Download Pretrained Model + +```bash +# Create model directory +mkdir -p models/Diffusion_Transformer + +# Download Qwen-Image official weights +modelscope download --model Qwen/Qwen-Image --local_dir models/Diffusion_Transformer/Qwen-Image +``` + +### 3.2 Quick Start (DeepSpeed-Zero-2) + +If you have downloaded the data as per **2.1 Quick Test Dataset** and the weights as per **3.1 Download Pretrained Model**, you can directly copy and run the quick start command. + +DeepSpeed-Zero-2 and FSDP are recommended for training. Here we use DeepSpeed-Zero-2 as an example. + +The difference between DeepSpeed-Zero-2 and FSDP lies in whether the model weights are sharded. **If VRAM is insufficient when using multiple GPUs with DeepSpeed-Zero-2**, you can switch to FSDP. + +```bash +export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" # NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. # export NCCL_IB_DISABLE=1 # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --mixed_precision="bf16" scripts/qwenimage/train.py \ +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/qwenimage/train.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ @@ -46,7 +188,7 @@ accelerate launch --mixed_precision="bf16" scripts/qwenimage/train.py \ --lr_scheduler="constant_with_warmup" \ --lr_warmup_steps=100 \ --seed=42 \ - --output_dir="output_dir" \ + --output_dir="output_dir_qwenimage" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -58,7 +200,42 @@ accelerate launch --mixed_precision="bf16" scripts/qwenimage/train.py \ --trainable_modules "." 
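+# Optional (see "Common Training Parameters" in 3.3): to resume an interrupted run, append
+#   --resume_from_checkpoint="latest"
+# to the command above, or pass an explicit checkpoint path instead of "latest".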
``` -With Deepspeed Zero-2: +### 3.3 Common Training Parameters + +**Key Parameter Descriptions**: + +| Parameter | Description | Example Value | +|-----|------|-------| +| `--pretrained_model_name_or_path` | Path to pretrained model | `models/Diffusion_Transformer/Qwen-Image` | +| `--train_data_dir` | Training data directory | `datasets/internal_datasets/` | +| `--train_data_meta` | Training data metadata file | `datasets/internal_datasets/metadata.json` | +| `--train_batch_size` | Samples per batch | 1 | +| `--image_sample_size` | Maximum training resolution, auto bucketing | 1328 | +| `--gradient_accumulation_steps` | Gradient accumulation steps (equivalent to larger batch) | 1 | +| `--dataloader_num_workers` | DataLoader subprocesses | 8 | +| `--num_train_epochs` | Number of training epochs | 100 | +| `--checkpointing_steps` | Save checkpoint every N steps | 50 | +| `--learning_rate` | Initial learning rate | 2e-05 | +| `--lr_scheduler` | Learning rate scheduler | `constant_with_warmup` | +| `--lr_warmup_steps` | Learning rate warmup steps | 100 | +| `--seed` | Random seed | 42 | +| `--output_dir` | Output directory | `output_dir` | +| `--gradient_checkpointing` | Enable activation checkpointing | - | +| `--mixed_precision` | Mixed precision: `fp16/bf16` | `bf16` | +| `--adam_weight_decay` | AdamW weight decay | 3e-2 | +| `--adam_epsilon` | AdamW epsilon value | 1e-10 | +| `--vae_mini_batch` | Mini-batch size for VAE encoding | 1 | +| `--max_grad_norm` | Gradient clipping threshold | 0.05 | +| `--enable_bucket` | Enable bucket training: trains entire images grouped by resolution without center cropping | - | +| `--random_hw_adapt` | Auto-scale images to random size in range `[512, image_sample_size]` | - | +| `--resume_from_checkpoint` | Resume training from checkpoint path, use `"latest"` to auto-select latest | None | +| `--uniform_sampling` | Uniform timestep sampling | - | +| `--trainable_modules` | Trainable modules (`"."` means all modules) | `"."` | + + +### 3.4 Training with FSDP + +**If VRAM is insufficient when using multiple GPUs with DeepSpeed-Zero-2**, you can switch to FSDP. ```sh export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" @@ -69,7 +246,7 @@ export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/qwenimage/train.py \ +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap=QwenImageTransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/qwenimage/train.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ @@ -83,7 +260,7 @@ accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_con --lr_scheduler="constant_with_warmup" \ --lr_warmup_steps=100 \ --seed=42 \ - --output_dir="output_dir" \ + --output_dir="output_dir_qwenimage" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -95,16 +272,21 @@ accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_con --trainable_modules "." ``` +### 3.5 Other Backends + +#### 3.5.1 Training with DeepSpeed-Zero-3 + DeepSpeed Zero-3 is not highly recommended at the moment. 
In this repository, using FSDP has fewer errors and is more stable. DeepSpeed Zero-3: After training, you can use the following command to get the final model: + ```sh python scripts/zero_to_bf16.py output_dir/checkpoint-{our-num-steps} output_dir/checkpoint-{your-num-steps}-outputs --max_shard_size 80GB --safe_serialization ``` -Training shell command is as follows: +Training shell command: ```sh export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" export DATASET_NAME="datasets/internal_datasets/" @@ -140,7 +322,9 @@ accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag --trainable_modules "." ``` -With FSDP: +#### 3.5.2 Training Without DeepSpeed or FSDP + +**This approach is not recommended as it lacks VRAM-saving backends and may easily cause out-of-memory errors**. This is provided for reference only. ```sh export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" @@ -151,7 +335,7 @@ export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap=QwenImageTransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/qwenimage/train.py \ +accelerate launch --mixed_precision="bf16" scripts/qwenimage/train.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ @@ -165,7 +349,56 @@ accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TR --lr_scheduler="constant_with_warmup" \ --lr_warmup_steps=100 \ --seed=42 \ - --output_dir="output_dir" \ + --output_dir="output_dir_qwenimage" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` + +### 3.6 Multi-Machine Distributed Training + +**Suitable for**: Ultra-large-scale datasets, faster training speed + +#### 3.6.1 Environment Configuration + +Assuming 2 machines with 8 GPUs each: + +**Machine 0 (Master)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Master machine IP +export MASTER_PORT=10086 +export WORLD_SIZE=2 # Total number of machines +export NUM_PROCESS=16 # Total processes = machines × 8 +export RANK=0 # Current machine rank (0 or 1) +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
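+# Note: the bare NCCL_DEBUG=INFO assignment a few lines below only sets a shell variable
+# and is not exported to the launched processes; use `export NCCL_DEBUG=INFO` if you want
+# NCCL debug logs from the training run.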
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/qwenimage/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_qwenimage" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -175,4 +408,147 @@ accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TR --enable_bucket \ --uniform_sampling \ --trainable_modules "." -``` \ No newline at end of file +``` + +**Machine 1 (Worker)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Same as Master +export MASTER_PORT=10086 +export WORLD_SIZE=2 +export NUM_PROCESS=16 +export RANK=1 # Note this is 1 +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +# Use the same accelerate launch command as Machine 0 +``` + +#### 3.6.2 Multi-Machine Training Notes + +- **Network Requirements**: + - RDMA/InfiniBand recommended (high performance) + - Without RDMA, add environment variables: + ```bash + export NCCL_IB_DISABLE=1 + export NCCL_P2P_DISABLE=1 + ``` + +- **Data Synchronization**: All machines must be able to access the same data paths (NFS/shared storage) + +## 4. 
Inference Testing + +### 4.1 Inference Parameter Parsing + +**Key Parameter Descriptions**: + +| Parameter | Description | Example Value | +|------|------|-------| +| `GPU_memory_mode` | VRAM management mode, see table below for options | `model_group_offload` | +| `ulysses_degree` | Head dimension parallelism degree, set to 1 for single GPU | 1 | +| `ring_degree` | Sequence dimension parallelism degree, set to 1 for single GPU | 1 | +| `fsdp_dit` | Use FSDP for Transformer during multi-GPU inference to save VRAM | `False` | +| `fsdp_text_encoder` | Use FSDP for text encoder during multi-GPU inference | `False` | +| `compile_dit` | Compile Transformer for faster inference (effective at fixed resolution) | `False` | +| `enable_teacache` | Enable TeaCache for faster inference | `True` | +| `teacache_threshold` | TeaCache threshold, recommended 0.05~0.30, higher is faster but may reduce quality | 0.25 | +| `num_skip_start_steps` | Number of steps to skip at inference start, reduces impact on generation quality | 5 | +| `teacache_offload` | Offload TeaCache tensors to CPU to save VRAM | `False` | +| `cfg_skip_ratio` | Skip some CFG steps for faster inference, recommended 0.00~0.25 | 0 | +| `model_name` | Model path | `models/Diffusion_Transformer/Qwen-Image` | +| `sampler_name` | Sampler type: `Flow`, `Flow_Unipc`, `Flow_DPM++` | `Flow` | +| `transformer_path` | Path to load trained Transformer weights | `None` | +| `vae_path` | Path to load trained VAE weights | `None` | +| `lora_path` | LoRA weights path | `None` | +| `sample_size` | Generated image resolution `[height, width]` | `[1344, 768]` | +| `weight_dtype` | Model weight precision, use `torch.float16` for GPUs without bf16 support | `torch.bfloat16` | +| `prompt` | Positive prompt describing the generation content | `"1girl, black_hair..."` | +| `negative_prompt` | Negative prompt for content to avoid | `" "` | +| `guidance_scale` | Guidance strength | 4.0 | +| `seed` | Random seed for reproducible results | 43 | +| `num_inference_steps` | Number of inference steps | 50 | +| `lora_weight` | LoRA weight strength | 0.55 | +| `save_path` | Path to save generated images | `samples/qwenimage-t2i` | + +**VRAM Management Mode Description**: + +| Mode | Description | VRAM Usage | +|------|------|---------| +| `model_full_load` | Load entire model to GPU | Highest | +| `model_full_load_and_qfloat8` | Full load + FP8 quantization | High | +| `model_cpu_offload` | Offload model to CPU after use | Medium | +| `model_cpu_offload_and_qfloat8` | CPU offload + FP8 quantization | Medium-Low | +| `model_group_offload` | Layer groups switch between CPU/CUDA | Low | +| `sequential_cpu_offload` | Sequential layer offload (slowest) | Lowest | + +### 4.2 Single GPU Inference + +#### Quick Start + +Run the following command for single GPU inference: + +```bash +python examples/qwenimage/predict_t2i.py +``` + +Edit `examples/qwenimage/predict_t2i.py` according to your needs. For first-time inference, focus on these parameters. For other parameters, refer to the inference parameter parsing above. + +```python +# Choose based on GPU VRAM +GPU_memory_mode = "model_group_offload" +# Based on actual model path +model_name = "models/Diffusion_Transformer/Qwen-Image" +# Path to trained weights, e.g., "output_dir_qwenimage/checkpoint-xxx/diffusion_pytorch_model.safetensors" +transformer_path = None +# Write based on generation content +prompt = "a young girl with flowing long hair, wearing a white halter dress" +# ... 
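+# Other commonly adjusted values, shown with the documented example values from the
+# parameter table in section 4.1 (illustrative only):
+# sample_size = [1344, 768]      # [height, width]
+# num_inference_steps = 50
+# seed = 43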
+``` + +### 4.3 Multi-GPU Parallel Inference + +**Suitable for**: High-resolution generation, accelerated inference + +#### Install Parallel Inference Dependencies + +```bash +pip install xfuser==0.4.2 yunchang==0.6.2 +``` + +#### Configure Parallel Strategy + +Edit `examples/qwenimage/predict_t2i.py`: + +```python +# Ensure ulysses_degree × ring_degree = number of GPUs +# For example, using 2 GPUs: +ulysses_degree = 2 # Head dimension parallelization +ring_degree = 1 # Sequence dimension parallelization +``` + +**Configuration Principles**: +- `ulysses_degree` must evenly divide the model's number of heads +- `ring_degree` splits on sequence dimension, affecting communication overhead; avoid using it when heads can be divided + +**Example Configurations**: + +| GPU Count | ulysses_degree | ring_degree | Description | +|---------|---------------|-------------|------| +| 1 | 1 | 1 | Single GPU | +| 4 | 4 | 1 | Head parallelization | +| 8 | 8 | 1 | Head parallelization | +| 8 | 4 | 2 | Hybrid parallelization | + +#### Run Multi-GPU Inference + +```bash +torchrun --nproc-per-node=2 examples/qwenimage/predict_t2i.py +``` + +## 5. Additional Resources + +- **Official GitHub**: https://github.com/aigc-apps/VideoX-Fun diff --git a/scripts/qwenimage/README_TRAIN_LORA.md b/scripts/qwenimage/README_TRAIN_LORA.md index 628ec41c..e8a9e7be 100755 --- a/scripts/qwenimage/README_TRAIN_LORA.md +++ b/scripts/qwenimage/README_TRAIN_LORA.md @@ -1,42 +1,180 @@ -## Lora Training Code +# Qwen-Image LoRA Fine-Tuning Training Guide -We can choose whether to use deepspeed or fsdp in qwen-image, which can save a lot of video memory. +This document provides a complete workflow for Qwen-Image LoRA fine-tuning training, including environment configuration, data preparation, multiple distributed training strategies, and inference testing. -Some parameters in the sh file can be confusing, and they are explained in this document: +--- -- `enable_bucket` is used to enable bucket training. When enabled, the model does not crop the images at the center, but instead, it trains the entire images after grouping them into buckets based on resolution. -- `random_hw_adapt` is used to enable automatic height and width scaling for images. When `random_hw_adapt` is enabled, the training images will have their height and width set to `image_sample_size` as the maximum and `512` as the minimum. - - For example, when `random_hw_adapt` is enabled, `image_sample_size=1024`, the resolution of image inputs for training is `512x512` to `1024x1024` -- `resume_from_checkpoint` is used to set the training should be resumed from a previous checkpoint. Use a path or `"latest"` to automatically select the last available checkpoint. -- `target_name` represents the components/modules to which LoRA will be applied, separated by commas. -- `use_peft_lora` indicates whether to use the PEFT module for adding LoRA. Using this module will be more memory-efficient. -- `rank` means the dimension of the LoRA update matrices. -- `network_alpha` means the scale of the LoRA update matrices. +## Table of Contents +- [1. Environment Configuration](#1-environment-configuration) +- [2. Data Preparation](#2-data-preparation) + - [2.1 Quick Test Dataset](#21-quick-test-dataset) + - [2.2 Dataset Structure](#22-dataset-structure) + - [2.3 metadata.json Format](#23-metadatajson-format) + - [2.4 Relative vs Absolute Path Usage](#24-relative-vs-absolute-path-usage) +- [3. 
LoRA Training](#3-lora-training) + - [3.1 Download Pretrained Model](#31-download-pretrained-model) + - [3.2 Quick Start (DeepSpeed-Zero-2)](#32-quick-start-deepspeed-zero-2) + - [3.3 LoRA-Specific Parameters](#33-lora-specific-parameters) + - [3.4 Training with FSDP](#34-training-with-fsdp) + - [3.5 Other Backends](#35-other-backends) + - [3.6 Multi-Machine Distributed Training](#36-multi-machine-distributed-training) +- [4. Inference Testing](#4-inference-testing) + - [4.1 Inference Parameter Parsing](#41-inference-parameter-parsing) + - [4.2 Single GPU Inference](#42-single-gpu-inference) + - [4.3 Multi-GPU Parallel Inference](#43-multi-gpu-parallel-inference) +- [5. Additional Resources](#5-additional-resources) -When train model with multi machines, please set the params as follows: -```sh -export MASTER_ADDR="your master address" -export MASTER_PORT=10086 -export WORLD_SIZE=1 # The number of machines -export NUM_PROCESS=8 # The number of processes, such as WORLD_SIZE * 8 -export RANK=0 # The rank of this machine +--- + +## 1. Environment Configuration -accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK scripts/xxx/xxx.py +**Method 1: Using requirements.txt** + +```bash +pip install -r requirements.txt ``` -Without deepspeed: +**Method 2: Manual Dependency Installation** -Training qwen-image without DeepSpeed may result in insufficient GPU memory. -```sh -export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" +```bash +pip install Pillow einops safetensors timm tomesd librosa "torch>=2.1.2" torchdiffeq torchsde decord datasets numpy scikit-image +pip install omegaconf SentencePiece imageio[ffmpeg] imageio[pyav] tensorboard beautifulsoup4 ftfy func_timeout onnxruntime +pip install "peft>=0.17.0" "accelerate>=0.25.0" "gradio>=3.41.2" "diffusers>=0.30.1" "transformers>=4.46.2" +pip install yunchang xfuser modelscope openpyxl deepspeed==0.17.0 numpy==1.26.4 +pip uninstall opencv-python opencv-contrib-python opencv-python-headless -y +pip install opencv-python-headless +``` + +**Method 3: Using Docker** + +When using Docker, please ensure that the GPU driver and CUDA environment are correctly installed on your machine, then execute the following commands: + +``` +# pull image +docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun + +# enter image +docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun +``` + +--- + +## 2. Data Preparation + +### 2.1 Quick Test Dataset + +We provide a test dataset containing several training samples. + +```bash +# Download official demo dataset +modelscope download --dataset PAI/X-Fun-Images-Demo --local_dir ./datasets/X-Fun-Images-Demo +``` + +### 2.2 Dataset Structure + +``` +📦 datasets/ +├── 📂 my_dataset/ +│ ├── 📂 train/ +│ │ ├── 📄 image001.jpg +│ │ ├── 📄 image002.png +│ │ └── 📄 ... 
+│ └── 📄 metadata.json +``` + +### 2.3 metadata.json Format + +**Relative Path Format** (example): +```json +[ + { + "file_path": "train/image001.jpg", + "text": "A beautiful sunset over the ocean, golden hour lighting", + "width": 1024, + "height": 1024 + }, + { + "file_path": "train/image002.png", + "text": "Portrait of a young woman, studio lighting, high quality", + "width": 1328, + "height": 1328 + } +] +``` + +**Absolute Path Format**: +```json +[ + { + "file_path": "/mnt/data/images/sunset.jpg", + "text": "A beautiful sunset over the ocean", + "width": 1024, + "height": 1024 + } +] +``` + +**Key Fields Description**: +- `file_path`: Image path (relative or absolute) +- `text`: Image description (English prompt) +- `width` / `height`: Image dimensions (**recommended** to provide for bucket training; if not provided, they will be automatically read during training, which may slow down training when data is stored on slow systems like OSS) + - You can use `scripts/process_json_add_width_and_height.py` to add width and height fields to JSON files without these fields, supporting both images and videos + - Usage: `python scripts/process_json_add_width_and_height.py --input_file datasets/X-Fun-Images-Demo/metadata.json --output_file datasets/X-Fun-Images-Demo/metadata_add_width_height.json` + +### 2.4 Relative vs Absolute Path Usage + +**Relative Paths**: + +If your data uses relative paths, configure the training script as follows: + +```bash export DATASET_NAME="datasets/internal_datasets/" export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +``` + +**Absolute Paths**: + +If your data uses absolute paths, configure the training script as follows: + +```bash +export DATASET_NAME="" +export DATASET_META_NAME="/mnt/data/metadata.json" +``` + +> 💡 **Recommendation**: If the dataset is small and stored locally, use relative paths. If the dataset is stored on external storage (e.g., NAS, OSS) or shared across multiple machines, use absolute paths. + +--- + +## 3. LoRA Training + +### 3.1 Download Pretrained Model + +```bash +# Create model directory +mkdir -p models/Diffusion_Transformer + +# Download Qwen-Image official weights +modelscope download --model Qwen/Qwen-Image --local_dir models/Diffusion_Transformer/Qwen-Image +``` + +### 3.2 Quick Start (DeepSpeed-Zero-2) + +If you have downloaded the data as per **2.1 Quick Test Dataset** and the weights as per **3.1 Download Pretrained Model**, you can directly copy and run the quick start command. + +DeepSpeed-Zero-2 and FSDP are recommended for training. Here we use DeepSpeed-Zero-2 as an example. + +The difference between DeepSpeed-Zero-2 and FSDP lies in whether the model weights are sharded. **If VRAM is insufficient when using multiple GPUs with DeepSpeed-Zero-2**, you can switch to FSDP. + +```bash +export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" # NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
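+# LoRA capacity in the command below is set by --rank=128 / --network_alpha=64;
+# a higher rank is more expressive but uses more VRAM (see the table in section 3.3).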
# export NCCL_IB_DISABLE=1 # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --mixed_precision="bf16" scripts/qwenimage/train_lora.py \ +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/qwenimage/train_lora.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ @@ -48,7 +186,7 @@ accelerate launch --mixed_precision="bf16" scripts/qwenimage/train_lora.py \ --checkpointing_steps=50 \ --learning_rate=1e-04 \ --seed=42 \ - --output_dir="output_dir_lora" \ + --output_dir="output_dir_qwenimage_lora" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -56,14 +194,47 @@ accelerate launch --mixed_precision="bf16" scripts/qwenimage/train_lora.py \ --vae_mini_batch=1 \ --max_grad_norm=0.05 \ --enable_bucket \ - --rank=64 \ - --network_alpha=32 \ + --rank=128 \ + --network_alpha=64 \ --target_name="to_q,to_k,to_v,img_mod.1,txt_mod.1,img_mlp.0,img_mlp.2,txt_mlp.0,txt_mlp.2" \ --use_peft_lora \ --uniform_sampling ``` -With Deepspeed Zero-2: +### 3.3 LoRA-Specific Parameters + +**LoRA Key Parameter Descriptions**: + +| Parameter | Description | Example Value | +|-----|------|-------| +| `--pretrained_model_name_or_path` | Path to pretrained model | `models/Diffusion_Transformer/Qwen-Image` | +| `--train_data_dir` | Training data directory | `datasets/internal_datasets/` | +| `--train_data_meta` | Training data metadata file | `datasets/internal_datasets/metadata.json` | +| `--train_batch_size` | Samples per batch | 1 | +| `--image_sample_size` | Maximum training resolution, auto bucketing | 1328 | +| `--gradient_accumulation_steps` | Gradient accumulation steps (equivalent to larger batch) | 1 | +| `--dataloader_num_workers` | DataLoader subprocesses | 8 | +| `--num_train_epochs` | Number of training epochs | 100 | +| `--checkpointing_steps` | Save checkpoint every N steps | 50 | +| `--learning_rate` | Initial learning rate (recommended for LoRA) | 1e-04 | +| `--lr_warmup_steps` | Learning rate warmup steps | 100 | +| `--seed` | Random seed (for reproducible training) | 42 | +| `--output_dir` | Output directory | `output_dir_qwenimage_lora` | +| `--gradient_checkpointing` | Enable activation checkpointing | - | +| `--mixed_precision` | Mixed precision: `fp16/bf16` | `bf16` | +| `--enable_bucket` | Enable bucket training: trains entire images grouped by resolution without center cropping | - | +| `--uniform_sampling` | Uniform timestep sampling (recommended) | - | +| `--resume_from_checkpoint` | Resume training from checkpoint path, use `"latest"` to auto-select latest | None | +| `--rank` | Dimension of LoRA update matrices (higher rank = stronger expressiveness but more VRAM usage) | 128 | +| `--network_alpha` | Scaling factor of LoRA update matrices (typically set to half of rank) | 64 | +| `--target_name` | Components/modules to apply LoRA, separated by commas | `to_q,to_k,to_v,img_mod.1,txt_mod.1,img_mlp.0,img_mlp.2,txt_mlp.0,txt_mlp.2` | +| `--use_peft_lora` | Use PEFT module for adding LoRA (more VRAM-efficient) | - | + +### 3.4 Training with FSDP + +**If VRAM is insufficient when using multiple GPUs with DeepSpeed-Zero-2**, you can switch to FSDP. + +> ✅ **Recommended**: FSDP has been thoroughly tested in this repository, with fewer errors and greater stability. 
```sh export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" @@ -74,7 +245,7 @@ export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/qwenimage/train_lora.py \ +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap=QwenImageTransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/qwenimage/train_lora.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ @@ -86,7 +257,7 @@ accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_con --checkpointing_steps=50 \ --learning_rate=1e-04 \ --seed=42 \ - --output_dir="output_dir_lora" \ + --output_dir="output_dir_qwenimage_lora" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -94,25 +265,28 @@ accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_con --vae_mini_batch=1 \ --max_grad_norm=0.05 \ --enable_bucket \ - --rank=64 \ - --network_alpha=32 \ + --rank=128 \ + --network_alpha=64 \ --target_name="to_q,to_k,to_v,img_mod.1,txt_mod.1,img_mlp.0,img_mlp.2,txt_mlp.0,txt_mlp.2" \ --use_peft_lora \ --uniform_sampling ``` -DeepSpeed Zero-3 is not highly recommended at the moment. In this repository, using FSDP has fewer errors and is more stable. +### 3.5 Other Backends -It is known that DeepSpeed Zero-3 is not compatible with PEFT. +#### 3.5.1 Training with DeepSpeed-Zero-3 + +DeepSpeed Zero-3 is not highly recommended at the moment. In this repository, using FSDP has fewer errors and is more stable. DeepSpeed Zero-3: After training, you can use the following command to get the final model: + ```sh python scripts/zero_to_bf16.py output_dir/checkpoint-{our-num-steps} output_dir/checkpoint-{your-num-steps}-outputs --max_shard_size 80GB --safe_serialization ``` -Training shell command is as follows: +Training shell command: ```sh export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" export DATASET_NAME="datasets/internal_datasets/" @@ -134,7 +308,7 @@ accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag --checkpointing_steps=50 \ --learning_rate=1e-04 \ --seed=42 \ - --output_dir="output_dir_lora" \ + --output_dir="output_dir_qwenimage_lora" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -142,13 +316,19 @@ accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag --vae_mini_batch=1 \ --max_grad_norm=0.05 \ --enable_bucket \ + --rank=128 \ + --network_alpha=64 \ + --target_name="to_q,to_k,to_v,img_mod.1,txt_mod.1,img_mlp.0,img_mlp.2,txt_mlp.0,txt_mlp.2" \ + --use_peft_lora \ --uniform_sampling ``` -With FSDP: +#### 3.5.2 Training Without DeepSpeed or FSDP + +**This approach is not recommended as it lacks VRAM-saving backends and may easily cause out-of-memory errors**. This is provided for reference only. 
```sh -export MODEL_NAME="models/Diffusion_Transformer/Wan2.2-Fun-A14B-InP" +export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" export DATASET_NAME="datasets/internal_datasets/" export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. @@ -156,7 +336,7 @@ export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap=QwenImageTransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/qwenimage/train_lora.py \ +accelerate launch --mixed_precision="bf16" scripts/qwenimage/train_lora.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ @@ -168,7 +348,7 @@ accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TR --checkpointing_steps=50 \ --learning_rate=1e-04 \ --seed=42 \ - --output_dir="output_dir_lora" \ + --output_dir="output_dir_qwenimage_lora" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -176,9 +356,206 @@ accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TR --vae_mini_batch=1 \ --max_grad_norm=0.05 \ --enable_bucket \ - --rank=64 \ - --network_alpha=32 \ + --rank=128 \ + --network_alpha=64 \ --target_name="to_q,to_k,to_v,img_mod.1,txt_mod.1,img_mlp.0,img_mlp.2,txt_mlp.0,txt_mlp.2" \ --use_peft_lora \ --uniform_sampling -``` \ No newline at end of file +``` + +### 3.6 Multi-Machine Distributed Training + +**Suitable for**: Ultra-large-scale datasets, faster training speed + +#### 3.6.1 Environment Configuration + +Assuming 2 machines with 8 GPUs each: + +**Machine 0 (Master)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Master machine IP +export MASTER_PORT=10086 +export WORLD_SIZE=2 # Total number of machines +export NUM_PROCESS=16 # Total processes = machines × 8 +export RANK=0 # Current machine rank (0 or 1) +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/qwenimage/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_qwenimage_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=128 \ + --network_alpha=64 \ + --target_name="to_q,to_k,to_v,img_mod.1,txt_mod.1,img_mlp.0,img_mlp.2,txt_mlp.0,txt_mlp.2" \ + --use_peft_lora \ + --uniform_sampling +``` + +**Machine 1 (Worker)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Same as Master +export MASTER_PORT=10086 +export WORLD_SIZE=2 +export NUM_PROCESS=16 +export RANK=1 # Note this is 1 +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +# Use the same accelerate launch command as Machine 0 +``` + +#### 3.6.2 Multi-Machine Training Notes + +- **Network Requirements**: + - RDMA/InfiniBand recommended (high performance) + - Without RDMA, add environment variables: + ```bash + export NCCL_IB_DISABLE=1 + export NCCL_P2P_DISABLE=1 + ``` + +- **Data Synchronization**: All machines must be able to access the same data paths (NFS/shared storage) + +--- + +## 4. 
Inference Testing + +### 4.1 Inference Parameter Parsing + +**Key Parameter Descriptions**: + +| Parameter | Description | Example Value | +|------|------|-------| +| `GPU_memory_mode` | VRAM management mode, see table below for options | `model_group_offload` | +| `ulysses_degree` | Head dimension parallelism degree, set to 1 for single GPU | 1 | +| `ring_degree` | Sequence dimension parallelism degree, set to 1 for single GPU | 1 | +| `fsdp_dit` | Use FSDP for Transformer during multi-GPU inference to save VRAM | `False` | +| `fsdp_text_encoder` | Use FSDP for text encoder during multi-GPU inference | `False` | +| `compile_dit` | Compile Transformer for faster inference (effective at fixed resolution) | `False` | +| `enable_teacache` | Enable TeaCache for faster inference | `True` | +| `teacache_threshold` | TeaCache threshold, recommended 0.05~0.30, higher is faster but may reduce quality | 0.25 | +| `num_skip_start_steps` | Number of steps to skip at inference start, reduces impact on generation quality | 5 | +| `teacache_offload` | Offload TeaCache tensors to CPU to save VRAM | `False` | +| `cfg_skip_ratio` | Skip some CFG steps for faster inference, recommended 0.00~0.25 | 0 | +| `model_name` | Model path | `models/Diffusion_Transformer/Qwen-Image` | +| `sampler_name` | Sampler type: `Flow`, `Flow_Unipc`, `Flow_DPM++` | `Flow` | +| `transformer_path` | Path to load trained Transformer weights | `None` | +| `vae_path` | Path to load trained VAE weights | `None` | +| `lora_path` | LoRA weights path | `None` | +| `sample_size` | Generated image resolution `[height, width]` | `[1344, 768]` | +| `weight_dtype` | Model weight precision, use `torch.float16` for GPUs without bf16 support | `torch.bfloat16` | +| `prompt` | Positive prompt describing the generation content | `"1girl, black_hair..."` | +| `negative_prompt` | Negative prompt for content to avoid | `" "` | +| `guidance_scale` | Guidance strength | 4.0 | +| `seed` | Random seed for reproducible results | 43 | +| `num_inference_steps` | Number of inference steps | 50 | +| `lora_weight` | LoRA weight strength | 0.55 | +| `save_path` | Path to save generated images | `samples/qwenimage-t2i` | + +**VRAM Management Mode Description**: + +| Mode | Description | VRAM Usage | +|------|------|---------| +| `model_full_load` | Load entire model to GPU | Highest | +| `model_full_load_and_qfloat8` | Full load + FP8 quantization | High | +| `model_cpu_offload` | Offload model to CPU after use | Medium | +| `model_cpu_offload_and_qfloat8` | CPU offload + FP8 quantization | Medium-Low | +| `model_group_offload` | Layer groups switch between CPU/CUDA | Low | +| `sequential_cpu_offload` | Sequential layer offload (slowest) | Lowest | + +### 4.2 Single GPU Inference + +#### Quick Start + +Run the following command for single GPU inference: + +```bash +python examples/qwenimage/predict_t2i.py +``` + +Edit `examples/qwenimage/predict_t2i.py` according to your needs. For first-time inference, focus on these parameters. For other parameters, refer to the inference parameter parsing above. + +```python +# Choose based on GPU VRAM +GPU_memory_mode = "model_group_offload" +# Based on actual model path +model_name = "models/Diffusion_Transformer/Qwen-Image" +# LoRA weights path, e.g., "output_dir_qwenimage_lora/checkpoint-xxx/lora_weights.safetensors" +lora_path = None +# LoRA weight strength +lora_weight = 0.55 +# Write based on generation content +prompt = "a young girl with flowing long hair, wearing a white halter dress" +# ... 
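+# Other commonly adjusted values, using the documented examples from the parameter
+# table in section 4.1 (illustrative only):
+# guidance_scale = 4.0
+# num_inference_steps = 50
+# save_path = "samples/qwenimage-t2i"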
+``` + +### 4.3 Multi-GPU Parallel Inference + +**Suitable for**: High-resolution generation, accelerated inference + +#### Install Parallel Inference Dependencies + +```bash +pip install xfuser==0.4.2 yunchang==0.6.2 +``` + +#### Configure Parallel Strategy + +Edit `examples/qwenimage/predict_t2i.py`: + +```python +# Ensure ulysses_degree × ring_degree = number of GPUs +# For example, using 2 GPUs: +ulysses_degree = 2 # Head dimension parallelization +ring_degree = 1 # Sequence dimension parallelization +``` + +**Configuration Principles**: +- `ulysses_degree` must evenly divide the model's number of heads +- `ring_degree` splits on sequence dimension, affecting communication overhead; avoid using it when heads can be divided + +**Example Configurations**: + +| GPU Count | ulysses_degree | ring_degree | Description | +|---------|---------------|-------------|------| +| 1 | 1 | 1 | Single GPU | +| 4 | 4 | 1 | Head parallelization | +| 8 | 8 | 1 | Head parallelization | +| 8 | 4 | 2 | Hybrid parallelization | + +#### Run Multi-GPU Inference + +```bash +torchrun --nproc-per-node=2 examples/qwenimage/predict_t2i.py +``` + +## 5. Additional Resources + +- **Official GitHub**: https://github.com/aigc-apps/VideoX-Fun diff --git a/scripts/qwenimage/README_TRAIN_LORA_zh-CN.md b/scripts/qwenimage/README_TRAIN_LORA_zh-CN.md new file mode 100644 index 00000000..57d0aee1 --- /dev/null +++ b/scripts/qwenimage/README_TRAIN_LORA_zh-CN.md @@ -0,0 +1,561 @@ +# Qwen-Image LoRA 微调训练指南 + +本文档提供 Qwen-Image LoRA 微调训练的完整流程,包括环境配置、数据准备、多种分布式训练策略和推理测试。 + +--- + +## 目录 +- [一、环境配置](#一环境配置) +- [二、数据准备](#二数据准备) + - [2.1 快速测试数据集](#21-快速测试数据集) + - [2.2 数据集结构](#22-数据集结构) + - [2.3 metadata.json 格式](#23-metadatajson-格式) + - [2.4 相对路径与绝对路径使用方案](#24-相对路径与绝对路径使用方案) +- [三、LoRA 训练](#三lora-训练) + - [3.1 下载预训练模型](#31-下载预训练模型) + - [3.2 快速开始(DeepSpeed-Zero-2)](#32-快速开始deepspeed-zero-2) + - [3.3 LoRA 专用参数解析](#33-lora-专用参数解析) + - [3.4 使用 FSDP 训练](#34-使用-fsdp-训练) + - [3.5 其他后端](#35-其他后端) + - [3.6 多机分布式训练](#36-多机分布式训练) +- [四、推理测试](#四推理测试) + - [4.1 推理参数解析](#41-推理参数解析) + - [4.2 单卡推理](#42-单卡推理) + - [4.3 多卡并行推理](#43-多卡并行推理) +- [五、更多资源](#五更多资源) + +--- + +## 一、环境配置 + +**方式 1:使用requirements.txt** + +```bash +pip install -r requirements.txt +``` + +**方式 2:手动安装依赖** + +```bash +pip install Pillow einops safetensors timm tomesd librosa "torch>=2.1.2" torchdiffeq torchsde decord datasets numpy scikit-image +pip install omegaconf SentencePiece imageio[ffmpeg] imageio[pyav] tensorboard beautifulsoup4 ftfy func_timeout onnxruntime +pip install "peft>=0.17.0" "accelerate>=0.25.0" "gradio>=3.41.2" "diffusers>=0.30.1" "transformers>=4.46.2" +pip install yunchang xfuser modelscope openpyxl deepspeed==0.17.0 numpy==1.26.4 +pip uninstall opencv-python opencv-contrib-python opencv-python-headless -y +pip install opencv-python-headless +``` + +**方式 3:使用docker** + +使用docker的情况下,请保证机器中已经正确安装显卡驱动与CUDA环境,然后以此执行以下命令: + +``` +# pull image +docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun + +# enter image +docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun +``` + +--- + +## 二、数据准备 + +### 2.1 快速测试数据集 + +我们提供了一个测试的数据集,其中包含若干训练数据。 + +```bash +# 下载官方示例数据集 +modelscope download --dataset PAI/X-Fun-Images-Demo --local_dir ./datasets/X-Fun-Images-Demo +``` + +### 2.2 数据集结构 + +``` +📦 datasets/ +├── 📂 my_dataset/ +│ ├── 📂 train/ +│ │ ├── 📄 image001.jpg +│ │ ├── 📄 image002.png +│ │ └── 📄 ... 
+│ └── 📄 metadata.json +``` + +### 2.3 metadata.json 格式 + +**相对路径格式**(示例格式): +```json +[ + { + "file_path": "train/image001.jpg", + "text": "A beautiful sunset over the ocean, golden hour lighting", + "width": 1024, + "height": 1024 + }, + { + "file_path": "train/image002.png", + "text": "Portrait of a young woman, studio lighting, high quality", + "width": 1328, + "height": 1328 + } +] +``` + +**绝对路径格式**: +```json +[ + { + "file_path": "/mnt/data/images/sunset.jpg", + "text": "A beautiful sunset over the ocean", + "width": 1024, + "height": 1024 + } +] +``` + +**关键字段说明**: +- `file_path`:图片路径(相对或绝对路径) +- `text`:图片描述(英文提示词) +- `width` / `height`:图片宽高(**最好提供**,用于分桶训练,如果不提供则自动在训练时读取,当数据存储在如oss这样的速度较慢的系统上时,可能会影响训练速度)。 + - 可以使用`scripts/process_json_add_width_and_height.py`文件对无width与height字段的json进行提取,支持处理图片与视频。 + - 使用方案为`python scripts/process_json_add_width_and_height.py --input_file datasets/X-Fun-Images-Demo/metadata.json --output_file datasets/X-Fun-Images-Demo/metadata_add_width_height.json`。 + +### 2.4 相对路径与绝对路径使用方案 + +**相对路径**: + +如果数据的路径为相对路径,则在训练脚本中设置: + +```bash +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +``` + +**绝对路径**: + +如果数据的路径为绝对路径,则在训练脚本中设置: + +```bash +export DATASET_NAME="" +export DATASET_META_NAME="/mnt/data/metadata.json" +``` + +> 💡 **建议**:如果数据集较小且存储在本地,推荐使用相对路径;如果数据集存储在外部存储(如 NAS、OSS)或多个机器共享存储,推荐使用绝对路径。 + +--- + +## 三、LoRA 训练 + +### 3.1 下载预训练模型 + +```bash +# 创建模型目录 +mkdir -p models/Diffusion_Transformer + +# 下载 Qwen-Image 官方权重 +modelscope download --model Qwen/Qwen-Image --local_dir models/Diffusion_Transformer/Qwen-Image +``` + +### 3.2 快速开始(DeepSpeed-Zero-2) + +如果按照 **2.1 快速测试数据集下载数据** 与 **3.1 下载预训练模型下载权重**后,直接复制快速开始的启动指令进行启动。 + +推荐使用 DeepSpeed-Zero-2 与 FSDP 方案进行训练。这里使用 DeepSpeed-Zero-2 为例配置 shell 文件。 + +本文中 DeepSpeed-Zero-2 与 FSDP 的差别在于是否对模型权重进行分片,**如果使用多卡且使用 DeepSpeed-Zero-2 的情况下显存不足**,可以切换使用 FSDP 进行训练。 + +```bash +export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/qwenimage/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_qwenimage_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=128 \ + --network_alpha=64 \ + --target_name="to_q,to_k,to_v,img_mod.1,txt_mod.1,img_mlp.0,img_mlp.2,txt_mlp.0,txt_mlp.2" \ + --use_peft_lora \ + --uniform_sampling +``` + +### 3.3 LoRA 专用参数解析 + +**LoRA 关键参数说明**: + +| 参数 | 说明 | 示例值 | +|-----|------|-------| +| `--pretrained_model_name_or_path` | 预训练模型路径 | `models/Diffusion_Transformer/Qwen-Image` | +| `--train_data_dir` | 训练数据目录 | `datasets/internal_datasets/` | +| `--train_data_meta` | 训练数据元文件 | `datasets/internal_datasets/metadata.json` | +| `--train_batch_size` | 每批次样本数 | 1 | +| `--image_sample_size` | 最大训练分辨率,代码会自动分桶 | 1328 | +| `--gradient_accumulation_steps` | 梯度累积步数(等效增大 batch) | 1 | +| `--dataloader_num_workers` | DataLoader 子进程数 | 8 | +| `--num_train_epochs` | 训练 epoch 数 | 100 | +| `--checkpointing_steps` | 每 N 步保存 checkpoint | 50 | +| `--learning_rate` | 初始学习率(LoRA 推荐值) | 1e-04 | +| `--lr_warmup_steps` | 学习率预热步数 | 100 | +| `--seed` | 随机种子(可复现训练) | 42 | +| `--output_dir` | 输出目录 | `output_dir_qwenimage_lora` | +| `--gradient_checkpointing` | 激活重计算 | - | +| `--mixed_precision` | 混合精度:`fp16/bf16` | `bf16` | +| `--enable_bucket` | 启用分桶训练,不裁剪图片,按分辨率分组训练整个图像 | - | +| `--uniform_sampling` | 均匀采样 timestep(推荐启用) | - | +| `--resume_from_checkpoint` | 恢复训练路径,使用 `"latest"` 自动选择最新 checkpoint | None | +| `--rank` | LoRA 更新矩阵的维度(rank 越大表达能力越强,但显存占用越高) | 128 | +| `--network_alpha` | LoRA 更新矩阵的缩放系数(通常设置为 rank 的一半) | 64 | +| `--target_name` | 应用 LoRA 的组件/模块,用逗号分隔 | `to_q,to_k,to_v,img_mod.1,txt_mod.1,img_mlp.0,img_mlp.2,txt_mlp.0,txt_mlp.2` | +| `--use_peft_lora` | 使用 PEFT 模块添加 LoRA(更节省显存) | - | + +### 3.4 使用 FSDP 训练 + +**如果使用多卡且使用 DeepSpeed-Zero-2 的情况下显存不足**,可以切换使用 FSDP 进行训练。 + +> ✅ **推荐**:FSDP 在当前仓库中经过充分测试,错误更少、更稳定。 + +```sh +export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap=QwenImageTransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/qwenimage/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_qwenimage_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=128 \ + --network_alpha=64 \ + --target_name="to_q,to_k,to_v,img_mod.1,txt_mod.1,img_mlp.0,img_mlp.2,txt_mlp.0,txt_mlp.2" \ + --use_peft_lora \ + --uniform_sampling +``` + +### 3.5 其他后端 + +#### 3.5.1 使用DeepSpeed-Zero-3进行训练 + +目前不太推荐使用 DeepSpeed Zero-3。在本仓库中,使用 FSDP 出错更少且更稳定。 + +DeepSpeed Zero-3: + +训练完成后,您可以使用以下命令获取最终模型: + +```sh +python scripts/zero_to_bf16.py output_dir/checkpoint-{our-num-steps} output_dir/checkpoint-{your-num-steps}-outputs --max_shard_size 80GB --safe_serialization +``` + +执行命令为: +```sh +export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag true --use_deepspeed --deepspeed_config_file config/zero_stage3_config.json --deepspeed_multinode_launcher standard scripts/qwenimage/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_qwenimage_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=128 \ + --network_alpha=64 \ + --target_name="to_q,to_k,to_v,img_mod.1,txt_mod.1,img_mlp.0,img_mlp.2,txt_mlp.0,txt_mlp.2" \ + --use_peft_lora \ + --uniform_sampling +``` + +#### 3.5.2 不使用 DeepSpeed 与 FSDP 训练 + +**该方案并不被推荐,因为没有显存节约后端,容易造成显存不足**。这里仅提供训练 Shell 用于参考训练。 + +```sh +export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" scripts/qwenimage/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_qwenimage_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=128 \ + --network_alpha=64 \ + --target_name="to_q,to_k,to_v,img_mod.1,txt_mod.1,img_mlp.0,img_mlp.2,txt_mlp.0,txt_mlp.2" \ + --use_peft_lora \ + --uniform_sampling +``` + +### 3.6 多机分布式训练 + +**适合场景**:超大规模数据集、需要更快的训练速度 + +#### 3.6.1 环境配置 + +假设有 2 台机器,每台 8 张 GPU: + +**机器 0(Master)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Master 机器 IP +export MASTER_PORT=10086 +export WORLD_SIZE=2 # 机器总数 +export NUM_PROCESS=16 # 总进程数 = 机器数 × 8 +export RANK=0 # 当前机器 rank(0 或 1) +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/qwenimage/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_qwenimage_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=128 \ + --network_alpha=64 \ + --target_name="to_q,to_k,to_v,img_mod.1,txt_mod.1,img_mlp.0,img_mlp.2,txt_mlp.0,txt_mlp.2" \ + --use_peft_lora \ + --uniform_sampling +``` + +**机器 1(Worker)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # 与 Master 相同 +export MASTER_PORT=10086 +export WORLD_SIZE=2 +export NUM_PROCESS=16 +export RANK=1 # 注意这里是 1 +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +# 使用与机器 0 相同的 accelerate launch 命令 +``` + +#### 3.6.2 多机训练注意事项 + +- **网络要求**: + - 推荐 RDMA/InfiniBand(高性能) + - 无 RDMA 时添加环境变量: + ```bash + export NCCL_IB_DISABLE=1 + export NCCL_P2P_DISABLE=1 + ``` + +- **数据同步**:所有机器必须能够访问相同的数据路径(NFS/共享存储) + +--- + +## 四、推理测试 + +### 4.1 推理参数解析 + +**关键参数说明**: + +| 参数 | 说明 | 示例值 | +|------|------|-------| +| `GPU_memory_mode` | 显存管理模式,可选值见下表 | `model_group_offload` | +| `ulysses_degree` | Head 维度并行度,单卡时为 1 | 1 | +| `ring_degree` | Sequence 维度并行度,单卡时为 1 | 1 | +| `fsdp_dit` | 多卡推理时对 Transformer 使用 FSDP 节省显存 | `False` | +| `fsdp_text_encoder` | 多卡推理时对文本编码器使用 FSDP | `False` | +| `compile_dit` | 编译 Transformer 加速推理(固定分辨率下有效) | `False` | +| `enable_teacache` | 启用 TeaCache 加速推理 | `True` | +| `teacache_threshold` | TeaCache 阈值,建议 0.05~0.30,越大越快但质量可能下降 | 0.25 | +| `num_skip_start_steps` | 推理开始跳过的步数,减少对生成质量的影响 | 5 | +| `teacache_offload` | 将 TeaCache 张量卸载到 CPU 节省显存 | `False` | +| `cfg_skip_ratio` | 跳过部分 CFG 步数加速推理,建议 0.00~0.25 | 0 | +| `model_name` | 模型路径 | `models/Diffusion_Transformer/Qwen-Image` | +| `sampler_name` | 采样器类型:`Flow`、`Flow_Unipc`、`Flow_DPM++` | `Flow` | +| `transformer_path` | 加载训练好的 Transformer 权重路径 | `None` | +| `vae_path` | 加载训练好的 VAE 权重路径 | `None` | +| `lora_path` | LoRA 权重路径 | `None` | +| `sample_size` | 生成图像分辨率 `[高度, 宽度]` | `[1344, 768]` | +| `weight_dtype` | 模型权重精度,不支持 bf16 的显卡使用 `torch.float16` | `torch.bfloat16` | +| `prompt` | 正向提示词,描述生成内容 | `"1girl, black_hair..."` | +| `negative_prompt` | 负向提示词,避免生成的内容 | `" "` | +| `guidance_scale` | 引导强度 | 4.0 | +| `seed` | 随机种子,用于复现结果 | 43 | +| `num_inference_steps` | 推理步数 | 50 | +| `lora_weight` | LoRA 权重强度 | 0.55 | +| `save_path` | 生成图像保存路径 | `samples/qwenimage-t2i` | + +**显存管理模式说明**: + +| 模式 | 说明 | 显存占用 | +|------|------|---------| +| `model_full_load` | 整个模型加载到 GPU | 最高 | +| `model_full_load_and_qfloat8` | 全量加载 + FP8 量化 | 高 | +| `model_cpu_offload` | 使用后将模型卸载到 CPU | 中等 | +| `model_cpu_offload_and_qfloat8` | CPU 卸载 + FP8 量化 | 中低 | +| `model_group_offload` | 层组在 CPU/CUDA 间切换 | 低 | +| `sequential_cpu_offload` | 逐层卸载(速度最慢) | 最低 | + +### 4.2 单卡推理 + +#### 快速开始 + +单卡推理运行如下命令: + +```bash +python examples/qwenimage/predict_t2i.py +``` + +根据需求修改编辑 `examples/qwenimage/predict_t2i.py`,初次推理重点关注如下参数,如果对其他参数感兴趣,请查看上方的推理参数解析。 + +```python +# 根据显卡显存选择 +GPU_memory_mode = "model_group_offload" +# 根据实际模型路径 +model_name = "models/Diffusion_Transformer/Qwen-Image" +# LoRA 权重路径,如 "output_dir_qwenimage_lora/checkpoint-xxx/lora_weights.safetensors" +lora_path = None +# LoRA 权重强度 +lora_weight = 0.55 +# 根据生成内容编写 +prompt = "a young girl with flowing long hair, wearing a white halter dress" +# ... 
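+# 其他常用参数,示例值取自 4.1 推理参数解析中的表格(仅作示意):
+# guidance_scale = 4.0
+# num_inference_steps = 50
+# seed = 43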
+``` + +### 4.3 多卡并行推理 + +**适合场景**:高分辨率生成、加速推理 + +#### 安装并行推理依赖 + +```bash +pip install xfuser==0.4.2 yunchang==0.6.2 +``` + +#### 配置并行策略 + +编辑 `examples/qwenimage/predict_t2i.py`: + +```python +# 确保 ulysses_degree × ring_degree = GPU 数量 +# 例如使用 2 张 GPU: +ulysses_degree = 2 # Head 维度并行 +ring_degree = 1 # Sequence 维度并行 +``` + +**配置原则**: +- `ulysses_degree` 必须能整除模型的head数。 +- `ring_degree` 会在sequence上切分,影响通信开销,在head数能切分的时候尽量不用。 + +**示例配置**: + +| GPU 数量 | ulysses_degree | ring_degree | 说明 | +|---------|---------------|-------------|------| +| 1 | 1 | 1 | 单卡 | +| 4 | 4 | 1 | Head 并行 | +| 8 | 8 | 1 | Head 并行 | +| 8 | 4 | 2 | 混合并行 | + +#### 运行多卡推理 + +```bash +torchrun --nproc-per-node=2 examples/qwenimage/predict_t2i.py +``` + +## 五、更多资源 + +- **官方 GitHub**:https://github.com/aigc-apps/VideoX-Fun \ No newline at end of file diff --git a/scripts/qwenimage/README_TRAIN_zh-CN.md b/scripts/qwenimage/README_TRAIN_zh-CN.md new file mode 100644 index 00000000..6ec60c00 --- /dev/null +++ b/scripts/qwenimage/README_TRAIN_zh-CN.md @@ -0,0 +1,554 @@ +# Qwen-Image 全量参数训练指南 + +本文档提供 Qwen-Image Diffusion Transformer 全量参数训练的完整流程,包括环境配置、数据准备、分布式训练和推理测试。 + +--- + +## 目录 +- [一、环境配置](#一环境配置) +- [二、数据准备](#二数据准备) + - [2.1 快速测试数据集](#21-快速测试数据集) + - [2.2 数据集结构](#22-数据集结构) + - [2.3 metadata.json 格式](#23-metadatajson-格式) + - [2.4 相对路径与绝对路径使用方案](#24-相对路径与绝对路径使用方案) +- [三、全量参数训练](#三全量参数训练) + - [3.1 下载预训练模型](#31-下载预训练模型) + - [3.2 快速开始(DeepSpeed-Zero-2)](#32-快速开始deepspeed-zero-2) + - [3.3 训练常用参数解析](#33-训练常用参数解析) + - [3.4 使用 FSDP 训练](#34-使用-fsdp-训练) + - [3.5 其他后端](#35-其他后端) + - [3.6 多机分布式训练](#36-多机分布式训练) +- [四、推理测试](#四推理测试) + - [4.1 推理参数解析](#41-推理参数解析) + - [4.2 单卡推理](#42-单卡推理) + - [4.3 多卡并行推理](#43-多卡并行推理) +- [五、更多资源](#五更多资源) + +--- + +## 一、环境配置 + +**方式 1:使用requirements.txt** + +```bash +pip install -r requirements.txt +``` + +**方式 2:手动安装依赖** + +```bash +pip install Pillow einops safetensors timm tomesd librosa "torch>=2.1.2" torchdiffeq torchsde decord datasets numpy scikit-image +pip install omegaconf SentencePiece imageio[ffmpeg] imageio[pyav] tensorboard beautifulsoup4 ftfy func_timeout onnxruntime +pip install "peft>=0.17.0" "accelerate>=0.25.0" "gradio>=3.41.2" "diffusers>=0.30.1" "transformers>=4.46.2" +pip install yunchang xfuser modelscope openpyxl deepspeed==0.17.0 numpy==1.26.4 +pip uninstall opencv-python opencv-contrib-python opencv-python-headless -y +pip install opencv-python-headless +``` + +**方式 3:使用docker** + +使用docker的情况下,请保证机器中已经正确安装显卡驱动与CUDA环境,然后以此执行以下命令: + +``` +# pull image +docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun + +# enter image +docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun +``` + +--- + +## 二、数据准备 + +### 2.1 快速测试数据集 + +我们提供了一个测试的数据集,其中包含若干训练数据。 + +```bash +# 下载官方示例数据集 +modelscope download --dataset PAI/X-Fun-Images-Demo --local_dir ./datasets/X-Fun-Images-Demo +``` + +### 2.2 数据集结构 + +``` +📦 datasets/ +├── 📂 my_dataset/ +│ ├── 📂 train/ +│ │ ├── 📄 image001.jpg +│ │ ├── 📄 image002.png +│ │ └── 📄 ... 
+│ └── 📄 metadata.json +``` + +### 2.3 metadata.json 格式 + +**相对路径格式**(示例格式): +```json +[ + { + "file_path": "train/image001.jpg", + "text": "A beautiful sunset over the ocean, golden hour lighting", + "width": 1024, + "height": 1024 + }, + { + "file_path": "train/image002.png", + "text": "Portrait of a young woman, studio lighting, high quality", + "width": 1328, + "height": 1328 + } +] +``` + +**绝对路径格式**: +```json +[ + { + "file_path": "/mnt/data/images/sunset.jpg", + "text": "A beautiful sunset over the ocean", + "width": 1024, + "height": 1024 + } +] +``` + +**关键字段说明**: +- `file_path`:图片路径(相对或绝对路径) +- `text`:图片描述(英文提示词) +- `width` / `height`:图片宽高(**最好提供**,用于分桶训练,如果不提供则自动在训练时读取,当数据存储在如oss这样的速度较慢的系统上时,可能会影响训练速度)。 + - 可以使用`scripts/process_json_add_width_and_height.py`文件对无width与height字段的json进行提取,支持处理图片与视频。 + - 使用方案为`python scripts/process_json_add_width_and_height.py --input_file datasets/X-Fun-Images-Demo/metadata.json --output_file datasets/X-Fun-Images-Demo/metadata_add_width_height.json`。 + +### 2.4 相对路径与绝对路径使用方案 + +**相对路径**: + +如果数据的路径为相对路径,则在训练脚本中设置: + +```bash +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +``` + +**绝对路径**: + +如果数据的路径为绝对路径,则在训练脚本中设置: + +```bash +export DATASET_NAME="" +export DATASET_META_NAME="/mnt/data/metadata.json" +``` + +> 💡 **建议**:如果数据集较小且存储在本地,推荐使用相对路径;如果数据集存储在外部存储(如 NAS、OSS)或多个机器共享存储,推荐使用绝对路径。 + +--- + +## 三、全量参数训练 + +### 3.1 下载预训练模型 + +```bash +# 创建模型目录 +mkdir -p models/Diffusion_Transformer + +# 下载 Qwen-Image 官方权重 +modelscope download --model Qwen/Qwen-Image --local_dir models/Diffusion_Transformer/Qwen-Image +``` + +### 3.2 快速开始(DeepSpeed-Zero-2) + +如果按照 **2.1 快速测试数据集下载数据** 与 **3.1 下载预训练模型下载权重**后,直接复制快速开始的启动指令进行启动。 + +推荐使用DeepSpeed-Zero-2与FSDP方案进行训练。这里使用DeepSpeed-Zero-2为例配置shell文件。 + +本文中DeepSpeed-Zero-2与FSDP的差别在于是否对模型权重进行分片,**如果使用多卡且使用DeepSpeed-Zero-2的情况下显存不足**,可以切换使用FSDP进行训练。 + +```bash +export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/qwenimage/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_qwenimage" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." 
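+  # 可选项(非原始命令的一部分):如需从 --output_dir 下最近的 checkpoint 恢复训练,
+  # 可追加参数 --resume_from_checkpoint="latest",详见 3.3 节参数表。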
+``` + +### 3.3 训练常用参数解析 + +**关键参数说明**: + +| 参数 | 说明 | 示例值 | +|-----|------|-------| +| `--pretrained_model_name_or_path` | 预训练模型路径 | `models/Diffusion_Transformer/Qwen-Image` | +| `--train_data_dir` | 训练数据目录 | `datasets/internal_datasets/` | +| `--train_data_meta` | 训练数据元文件 | `datasets/internal_datasets/metadata.json` | +| `--train_batch_size` | 每批次样本数 | 1 | +| `--image_sample_size` | 最大训练分辨率,代码会自动分桶 | 1328 | +| `--gradient_accumulation_steps` | 梯度累积步数(等效增大 batch) | 1 | +| `--dataloader_num_workers` | DataLoader 子进程数 | 8 | +| `--num_train_epochs` | 训练 epoch 数 | 100 | +| `--checkpointing_steps` | 每 N 步保存 checkpoint | 50 | +| `--learning_rate` | 初始学习率 | 2e-05 | +| `--lr_scheduler` | 学习率调度器 | `constant_with_warmup` | +| `--lr_warmup_steps` | 学习率预热步数 | 100 | +| `--seed` | 随机种子 | 42 | +| `--output_dir` | 输出目录 | `output_dir` | +| `--gradient_checkpointing` | 激活重计算 | - | +| `--mixed_precision` | 混合精度:`fp16/bf16` | `bf16` | +| `--adam_weight_decay` | AdamW 权重衰减 | 3e-2 | +| `--adam_epsilon` | AdamW epsilon 值 | 1e-10 | +| `--vae_mini_batch` | VAE 编码时的迷你批次大小 | 1 | +| `--max_grad_norm` | 梯度裁剪阈值 | 0.05 | +| `--enable_bucket` | 启用分桶训练,不裁剪图片,按分辨率分组训练整个图像 | - | +| `--random_hw_adapt` | 自动缩放图片到 `[512, image_sample_size]` 范围内的随机尺寸 | - | +| `--resume_from_checkpoint` | 恢复训练路径,使用 `"latest"` 自动选择最新 checkpoint | None | +| `--uniform_sampling` | 均匀采样 timestep | - | +| `--trainable_modules` | 可训练模块(`"."` 表示所有模块) | `"."` | + + +### 3.4 使用 FSDP 训练 + +**如果使用多卡且使用DeepSpeed-Zero-2的情况下显存不足**,可以切换使用FSDP进行训练。 + +```sh +export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap=QwenImageTransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/qwenimage/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_qwenimage" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` + +### 3.5 其他后端 + +#### 3.5.1 使用DeepSpeed-Zero-3进行训练 + +目前不太推荐使用 DeepSpeed Zero-3。在本仓库中,使用 FSDP 出错更少且更稳定。 + +DeepSpeed Zero-3: + +训练完成后,您可以使用以下命令获取最终模型: + +```sh +python scripts/zero_to_bf16.py output_dir/checkpoint-{our-num-steps} output_dir/checkpoint-{your-num-steps}-outputs --max_shard_size 80GB --safe_serialization +``` + +执行命令为: +```sh +export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag true --use_deepspeed --deepspeed_config_file config/zero_stage3_config.json --deepspeed_multinode_launcher standard scripts/qwenimage/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` + +#### 3.5.2 不使用 DeepSpeed 与 FSDP 训练 + +**该方案并不被推荐,因为没有显存节约后端,容易造成显存不足**。这里仅提供训练Shell用于参考训练。 + +```sh +export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" scripts/qwenimage/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_qwenimage" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` + +### 3.6 多机分布式训练 + +**适合场景**:超大规模数据集、需要更快的训练速度 + +#### 3.6.1 环境配置 + +假设有 2 台机器,每台 8 张 GPU: + +**机器 0(Master)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Master 机器 IP +export MASTER_PORT=10086 +export WORLD_SIZE=2 # 机器总数 +export NUM_PROCESS=16 # 总进程数 = 机器数 × 8 +export RANK=0 # 当前机器 rank(0 或 1) +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
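+# NUM_PROCESS 应等于 机器总数 × 每台机器的 GPU 数;若机器上可用 nvidia-smi,也可以自动推导:
+# export NUM_PROCESS=$(( WORLD_SIZE * $(nvidia-smi -L | wc -l) ))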
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/qwenimage/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_qwenimage" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` + +**机器 1(Worker)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/Qwen-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # 与 Master 相同 +export MASTER_PORT=10086 +export WORLD_SIZE=2 +export NUM_PROCESS=16 +export RANK=1 # 注意这里是 1 +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +# 使用与机器 0 相同的 accelerate launch 命令 +``` + +#### 3.6.2 多机训练注意事项 + +- **网络要求**: + - 推荐 RDMA/InfiniBand(高性能) + - 无 RDMA 时添加环境变量: + ```bash + export NCCL_IB_DISABLE=1 + export NCCL_P2P_DISABLE=1 + ``` + +- **数据同步**:所有机器必须能够访问相同的数据路径(NFS/共享存储) + +## 四、推理测试 + +### 4.1 推理参数解析 + +**关键参数说明**: + +| 参数 | 说明 | 示例值 | +|------|------|-------| +| `GPU_memory_mode` | 显存管理模式,可选值见下表 | `model_group_offload` | +| `ulysses_degree` | Head 维度并行度,单卡时为 1 | 1 | +| `ring_degree` | Sequence 维度并行度,单卡时为 1 | 1 | +| `fsdp_dit` | 多卡推理时对 Transformer 使用 FSDP 节省显存 | `False` | +| `fsdp_text_encoder` | 多卡推理时对文本编码器使用 FSDP | `False` | +| `compile_dit` | 编译 Transformer 加速推理(固定分辨率下有效) | `False` | +| `enable_teacache` | 启用 TeaCache 加速推理 | `True` | +| `teacache_threshold` | TeaCache 阈值,建议 0.05~0.30,越大越快但质量可能下降 | 0.25 | +| `num_skip_start_steps` | 推理开始跳过的步数,减少对生成质量的影响 | 5 | +| `teacache_offload` | 将 TeaCache 张量卸载到 CPU 节省显存 | `False` | +| `cfg_skip_ratio` | 跳过部分 CFG 步数加速推理,建议 0.00~0.25 | 0 | +| `model_name` | 模型路径 | `models/Diffusion_Transformer/Qwen-Image` | +| `sampler_name` | 采样器类型:`Flow`、`Flow_Unipc`、`Flow_DPM++` | `Flow` | +| `transformer_path` | 加载训练好的 Transformer 权重路径 | `None` | +| `vae_path` | 加载训练好的 VAE 权重路径 | `None` | +| `lora_path` | LoRA 权重路径 | `None` | +| `sample_size` | 生成图像分辨率 `[高度, 宽度]` | `[1344, 768]` | +| `weight_dtype` | 模型权重精度,不支持 bf16 的显卡使用 `torch.float16` | `torch.bfloat16` | +| `prompt` | 正向提示词,描述生成内容 | `"1girl, black_hair..."` | +| `negative_prompt` | 负向提示词,避免生成的内容 | `" "` | +| `guidance_scale` | 引导强度 | 4.0 | +| `seed` | 随机种子,用于复现结果 | 43 | +| `num_inference_steps` | 推理步数 | 50 | +| `lora_weight` | LoRA 权重强度 | 0.55 | +| `save_path` | 生成图像保存路径 | `samples/qwenimage-t2i` | + +**显存管理模式说明**: + +| 模式 | 说明 | 显存占用 | +|------|------|---------| +| `model_full_load` | 整个模型加载到 GPU | 最高 | +| `model_full_load_and_qfloat8` | 全量加载 + FP8 量化 | 高 | +| `model_cpu_offload` | 使用后将模型卸载到 CPU | 中等 | +| `model_cpu_offload_and_qfloat8` | CPU 卸载 + FP8 量化 | 
中低 | +| `model_group_offload` | 层组在 CPU/CUDA 间切换 | 低 | +| `sequential_cpu_offload` | 逐层卸载(速度最慢) | 最低 | + +### 4.2 单卡推理 + +#### 快速开始 + +单卡推理运行如下命令: + +```bash +python examples/qwenimage/predict_t2i.py +``` + +根据需求修改编辑 `examples/qwenimage/predict_t2i.py`,初次推理重点关注如下参数,如果对其他参数感兴趣,请查看上方的推理参数解析。 + +```python +# 根据显卡显存选择 +GPU_memory_mode = "model_group_offload" +# 根据实际模型路径 +model_name = "models/Diffusion_Transformer/Qwen-Image" +# 训练好的权重路径,如 "output_dir_qwenimage/checkpoint-xxx/diffusion_pytorch_model.safetensors" +transformer_path = None +# 根据生成内容编写 +prompt = "a young girl with flowing long hair, wearing a white halter dress" +# ... +``` + +### 4.3 多卡并行推理 + +**适合场景**:高分辨率生成、加速推理 + +#### 安装并行推理依赖 + +```bash +pip install xfuser==0.4.2 yunchang==0.6.2 +``` + +#### 配置并行策略 + +编辑 `examples/qwenimage/predict_t2i.py`: + +```python +# 确保 ulysses_degree × ring_degree = GPU 数量 +# 例如使用 2 张 GPU: +ulysses_degree = 2 # Head 维度并行 +ring_degree = 1 # Sequence 维度并行 +``` + +**配置原则**: +- `ulysses_degree` 必须能整除模型的head数。 +- `ring_degree` 会在sequence上切分,影响通信开销,在head数能切分的时候尽量不用。 + +**示例配置**: + +| GPU 数量 | ulysses_degree | ring_degree | 说明 | +|---------|---------------|-------------|------| +| 1 | 1 | 1 | 单卡 | +| 4 | 4 | 1 | Head 并行 | +| 8 | 8 | 1 | Head 并行 | +| 8 | 4 | 2 | 混合并行 | + +#### 运行多卡推理 + +```bash +torchrun --nproc-per-node=2 examples/qwenimage/predict_t2i.py +``` + +## 五、更多资源 + +- **官方 GitHub**:https://github.com/aigc-apps/VideoX-Fun \ No newline at end of file diff --git a/scripts/z_image/README_TRAIN.md b/scripts/z_image/README_TRAIN.md index 77686bae..8518395a 100644 --- a/scripts/z_image/README_TRAIN.md +++ b/scripts/z_image/README_TRAIN.md @@ -1,30 +1,249 @@ -## Training Code +# Z-Image Full Parameter Training Guide -We can choose whether to use deepspeed or fsdp in z_image, which can save a lot of video memory. +This document provides a complete workflow for full parameter training of Z-Image Diffusion Transformer, including environment configuration, data preparation, distributed training, and inference testing. -Some parameters in the sh file can be confusing, and they are explained in this document: +> **Note**: Z-Image has two model variants: `Z-Image` (standard) and `Z-Image-Turbo` (fast inference). This guide uses `Z-Image` as the default. For `Z-Image-Turbo`, simply replace the model path accordingly. -- `enable_bucket` is used to enable bucket training. When enabled, the model does not crop the images at the center, but instead, it trains the entire images after grouping them into buckets based on resolution. -- `random_hw_adapt` is used to enable automatic height and width scaling for images. When `random_hw_adapt` is enabled, the training images will have their height and width set to `image_sample_size` as the maximum and `512` as the minimum. - - For example, when `random_hw_adapt` is enabled, `image_sample_size=1024`, the resolution of image inputs for training is `512x512` to `1024x1024` -- `resume_from_checkpoint` is used to set the training should be resumed from a previous checkpoint. Use a path or `"latest"` to automatically select the last available checkpoint. +--- -When train model with multi machines, please set the params as follows: -```sh -export MASTER_ADDR="your master address" -export MASTER_PORT=10086 -export WORLD_SIZE=1 # The number of machines -export NUM_PROCESS=8 # The number of processes, such as WORLD_SIZE * 8 -export RANK=0 # The rank of this machine +## Table of Contents +- [1. Environment Configuration](#1-environment-configuration) +- [2. 
Data Preparation](#2-data-preparation) + - [2.1 Quick Test Dataset](#21-quick-test-dataset) + - [2.2 Dataset Structure](#22-dataset-structure) + - [2.3 metadata.json Format](#23-metadatajson-format) + - [2.4 Relative vs Absolute Path Usage](#24-relative-vs-absolute-path-usage) +- [3. Full Parameter Training](#3-full-parameter-training) + - [3.1 Download Pretrained Model](#31-download-pretrained-model) + - [3.2 Quick Start (DeepSpeed-Zero-2)](#32-quick-start-deepspeed-zero-2) + - [3.3 Common Training Parameters](#33-common-training-parameters) + - [3.4 Training with FSDP](#34-training-with-fsdp) + - [3.5 Other Backends](#35-other-backends) + - [3.6 Multi-Machine Distributed Training](#36-multi-machine-distributed-training) +- [4. Inference Testing](#4-inference-testing) + - [4.1 Inference Parameters](#41-inference-parameters) + - [4.2 Single GPU Inference](#42-single-gpu-inference) + - [4.3 Multi-GPU Parallel Inference](#43-multi-gpu-parallel-inference) +- [5. Additional Resources](#5-additional-resources) + +--- + +## 1. Environment Configuration + +**Method 1: Using requirements.txt** + +```bash +pip install -r requirements.txt +``` + +**Method 2: Manual Dependency Installation** + +```bash +pip install Pillow einops safetensors timm tomesd librosa "torch>=2.1.2" torchdiffeq torchsde decord datasets numpy scikit-image +pip install omegaconf SentencePiece imageio[ffmpeg] imageio[pyav] tensorboard beautifulsoup4 ftfy func_timeout onnxruntime +pip install "peft>=0.17.0" "accelerate>=0.25.0" "gradio>=3.41.2" "diffusers>=0.30.1" "transformers>=4.46.2" +pip install yunchang xfuser modelscope openpyxl deepspeed==0.17.0 numpy==1.26.4 +pip uninstall opencv-python opencv-contrib-python opencv-python-headless -y +pip install opencv-python-headless +``` + +**Method 3: Using Docker** + +When using Docker, please ensure that the GPU driver and CUDA environment are correctly installed on your machine, then execute the following commands: + +``` +# pull image +docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun + +# enter image +docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun +``` + +--- + +## 2. Data Preparation + +### 2.1 Quick Test Dataset + +We provide a test dataset containing several training samples. + +```bash +# Download official demo dataset +modelscope download --dataset PAI/X-Fun-Images-Demo --local_dir ./datasets/X-Fun-Images-Demo +``` + +### 2.2 Dataset Structure + +``` +📦 datasets/ +├── 📂 my_dataset/ +│ ├── 📂 train/ +│ │ ├── 📄 image001.jpg +│ │ ├── 📄 image002.png +│ │ └── 📄 ... 
+│ └── 📄 metadata.json +``` + +### 2.3 metadata.json Format + +**Relative Path Format** (example): +```json +[ + { + "file_path": "train/image001.jpg", + "text": "A beautiful sunset over the ocean, golden hour lighting", + "width": 1024, + "height": 1024 + }, + { + "file_path": "train/image002.png", + "text": "Portrait of a young woman, studio lighting, high quality", + "width": 1328, + "height": 1328 + } +] +``` + +**Absolute Path Format**: +```json +[ + { + "file_path": "/mnt/data/images/sunset.jpg", + "text": "A beautiful sunset over the ocean", + "width": 1024, + "height": 1024 + } +] +``` + +**Key Fields Description**: +- `file_path`: Image path (relative or absolute) +- `text`: Image description (English prompt) +- `width` / `height`: Image dimensions (**recommended** to provide for bucket training; if not provided, they will be automatically read during training, which may slow down training when data is stored on slow systems like OSS) + - You can use `scripts/process_json_add_width_and_height.py` to add width and height fields to JSON files without these fields, supporting both images and videos + - Usage: `python scripts/process_json_add_width_and_height.py --input_file datasets/X-Fun-Images-Demo/metadata.json --output_file datasets/X-Fun-Images-Demo/metadata_add_width_height.json` + +### 2.4 Relative vs Absolute Path Usage + +**Relative Paths**: + +If your data uses relative paths, configure the training script as follows: + +```bash +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +``` + +**Absolute Paths**: + +If your data uses absolute paths, configure the training script as follows: + +```bash +export DATASET_NAME="" +export DATASET_META_NAME="/mnt/data/metadata.json" +``` + +> 💡 **Recommendation**: If the dataset is small and stored locally, use relative paths. If the dataset is stored on external storage (e.g., NAS, OSS) or shared across multiple machines, use absolute paths. + +--- + +## 3. Full Parameter Training -accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK scripts/xxx/xxx.py +### 3.1 Download Pretrained Model + +```bash +# Create model directory +mkdir -p models/Diffusion_Transformer + +# Download Z-Image official weights +modelscope download --model PAI/Z-Image --local_dir models/Diffusion_Transformer/Z-Image + +# (Optional) Download Z-Image-Turbo for fast inference +modelscope download --model PAI/Z-Image-Turbo --local_dir models/Diffusion_Transformer/Z-Image-Turbo +``` + +### 3.2 Quick Start (DeepSpeed-Zero-2) + +If you have downloaded the data as per **2.1 Quick Test Dataset** and the weights as per **3.1 Download Pretrained Model**, you can directly copy and run the quick start command. + +DeepSpeed-Zero-2 and FSDP are recommended for training. Here we use DeepSpeed-Zero-2 as an example. + +The difference between DeepSpeed-Zero-2 and FSDP lies in whether the model weights are sharded. **If VRAM is insufficient when using multiple GPUs with DeepSpeed-Zero-2**, you can switch to FSDP. + +```bash +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
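+# metadata_add_width_height.json is generated from metadata.json by
+# scripts/process_json_add_width_and_height.py (see Section 2.3).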
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/z_image/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_z_image" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." ``` -Without deepspeed: +### 3.3 Common Training Parameters + +**Key Parameter Descriptions**: + +| Parameter | Description | Example Value | +|-----|------|-------| +| `--pretrained_model_name_or_path` | Path to pretrained model | `models/Diffusion_Transformer/Z-Image` | +| `--train_data_dir` | Training data directory | `datasets/internal_datasets/` | +| `--train_data_meta` | Training data metadata file | `datasets/internal_datasets/metadata.json` | +| `--train_batch_size` | Samples per batch | 1 | +| `--image_sample_size` | Maximum training resolution, auto bucketing | 1328 | +| `--gradient_accumulation_steps` | Gradient accumulation steps (equivalent to larger batch) | 1 | +| `--dataloader_num_workers` | DataLoader subprocesses | 8 | +| `--num_train_epochs` | Number of training epochs | 100 | +| `--checkpointing_steps` | Save checkpoint every N steps | 50 | +| `--learning_rate` | Initial learning rate | 2e-05 | +| `--lr_scheduler` | Learning rate scheduler | `constant_with_warmup` | +| `--lr_warmup_steps` | Learning rate warmup steps | 100 | +| `--seed` | Random seed | 42 | +| `--output_dir` | Output directory | `output_dir_z_image` | +| `--gradient_checkpointing` | Enable activation checkpointing | - | +| `--mixed_precision` | Mixed precision: `fp16/bf16` | `bf16` | +| `--adam_weight_decay` | AdamW weight decay | 3e-2 | +| `--adam_epsilon` | AdamW epsilon value | 1e-10 | +| `--vae_mini_batch` | Mini-batch size for VAE encoding | 1 | +| `--max_grad_norm` | Gradient clipping threshold | 0.05 | +| `--enable_bucket` | Enable bucket training: trains entire images grouped by resolution without center cropping | - | +| `--random_hw_adapt` | Auto-scale images to random size in range `[512, image_sample_size]` | - | +| `--resume_from_checkpoint` | Resume training from checkpoint path, use `"latest"` to auto-select latest | None | +| `--uniform_sampling` | Uniform timestep sampling | - | +| `--trainable_modules` | Trainable modules (`"."` means all modules) | `"."` | + + +### 3.4 Training with FSDP + +**If VRAM is insufficient when using multiple GPUs with DeepSpeed-Zero-2**, you can switch to FSDP. -Training z_image without DeepSpeed may result in insufficient GPU memory. ```sh -export MODEL_NAME="models/Diffusion_Transformer/Z-Image-Turbo" +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" export DATASET_NAME="datasets/internal_datasets/" export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
@@ -32,12 +251,12 @@ export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --mixed_precision="bf16" scripts/z_image/train.py \ +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap ZImageTransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/z_image/train.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ --train_batch_size=1 \ - --image_sample_size=1024 \ + --image_sample_size=1328 \ --gradient_accumulation_steps=1 \ --dataloader_num_workers=8 \ --num_train_epochs=100 \ @@ -46,7 +265,7 @@ accelerate launch --mixed_precision="bf16" scripts/z_image/train.py \ --lr_scheduler="constant_with_warmup" \ --lr_warmup_steps=100 \ --seed=42 \ - --output_dir="output_dir" \ + --output_dir="output_dir_z_image" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -58,10 +277,23 @@ accelerate launch --mixed_precision="bf16" scripts/z_image/train.py \ --trainable_modules "." ``` -With Deepspeed Zero-2: +### 3.5 Other Backends + +#### 3.5.1 Training with DeepSpeed-Zero-3 + +DeepSpeed Zero-3 is not highly recommended at the moment. In this repository, using FSDP has fewer errors and is more stable. + +DeepSpeed Zero-3: +After training, you can use the following command to get the final model: + +```sh +python scripts/zero_to_bf16.py output_dir/checkpoint-{our-num-steps} output_dir/checkpoint-{your-num-steps}-outputs --max_shard_size 80GB --safe_serialization +``` + +Training shell command: ```sh -export MODEL_NAME="models/Diffusion_Transformer/Z-Image-Turbo" +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" export DATASET_NAME="datasets/internal_datasets/" export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. @@ -69,12 +301,12 @@ export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/z_image/train.py \ +accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag true --use_deepspeed --deepspeed_config_file config/zero_stage3_config.json --deepspeed_multinode_launcher standard scripts/z_image/train.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ --train_batch_size=1 \ - --image_sample_size=1024 \ + --image_sample_size=1328 \ --gradient_accumulation_steps=1 \ --dataloader_num_workers=8 \ --num_train_epochs=100 \ @@ -83,7 +315,7 @@ accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_con --lr_scheduler="constant_with_warmup" \ --lr_warmup_steps=100 \ --seed=42 \ - --output_dir="output_dir" \ + --output_dir="output_dir_z_image" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -95,10 +327,12 @@ accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_con --trainable_modules "." ``` -With FSDP: +#### 3.5.2 Training Without DeepSpeed or FSDP + +**This approach is not recommended as it lacks VRAM-saving backends and may easily cause out-of-memory errors**. 
This is provided for reference only. ```sh -export MODEL_NAME="models/Diffusion_Transformer/Z-Image-Turbo" +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" export DATASET_NAME="datasets/internal_datasets/" export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. @@ -106,12 +340,61 @@ export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap ZImageTransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/z_image/train.py \ +accelerate launch --mixed_precision="bf16" scripts/z_image/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_z_image" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` + +### 3.6 Multi-Machine Distributed Training + +**Suitable for**: Ultra-large-scale datasets, faster training speed + +#### 3.6.1 Environment Configuration + +Assuming 2 machines with 8 GPUs each: + +**Machine 0 (Master)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Master machine IP +export MASTER_PORT=10086 +export WORLD_SIZE=2 # Total number of machines +export NUM_PROCESS=16 # Total processes = machines × 8 +export RANK=0 # Current machine rank (0 or 1) +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/z_image/train.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ --train_batch_size=1 \ - --image_sample_size=1024 \ + --image_sample_size=1328 \ --gradient_accumulation_steps=1 \ --dataloader_num_workers=8 \ --num_train_epochs=100 \ @@ -120,7 +403,7 @@ accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TR --lr_scheduler="constant_with_warmup" \ --lr_warmup_steps=100 \ --seed=42 \ - --output_dir="output_dir" \ + --output_dir="output_dir_z_image" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -130,4 +413,164 @@ accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TR --enable_bucket \ --uniform_sampling \ --trainable_modules "." 
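# Optional: long multi-machine runs can be resumed by adding
#   --resume_from_checkpoint="latest" \
# to the argument list above (see Section 3.3).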
-``` \ No newline at end of file +``` + +**Machine 1 (Worker)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Same as Master +export MASTER_PORT=10086 +export WORLD_SIZE=2 +export NUM_PROCESS=16 +export RANK=1 # Note this is 1 +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +# Use the same accelerate launch command as Machine 0 +``` + +#### 3.6.2 Multi-Machine Training Notes + +- **Network Requirements**: + - RDMA/InfiniBand recommended (high performance) + - Without RDMA, add environment variables: + ```bash + export NCCL_IB_DISABLE=1 + export NCCL_P2P_DISABLE=1 + ``` + +- **Data Synchronization**: All machines must be able to access the same data paths (NFS/shared storage) + +## 4. Inference Testing + +### 4.1 Inference Parameters + +**Key Parameter Descriptions**: + +| Parameter | Description | Example Value | +|------|------|-------| +| `GPU_memory_mode` | GPU memory mode, see table below for options | `model_cpu_offload` | +| `ulysses_degree` | Head dimension parallelization degree, 1 for single GPU | 1 | +| `ring_degree` | Sequence dimension parallelization degree, 1 for single GPU | 1 | +| `fsdp_dit` | Use FSDP for Transformer in multi-GPU inference to save VRAM | `False` | +| `fsdp_text_encoder` | Use FSDP for text encoder in multi-GPU inference | `False` | +| `compile_dit` | Compile Transformer to accelerate inference (effective at fixed resolution) | `False` | +| `model_name` | Model path | `models/Diffusion_Transformer/Z-Image` | +| `sampler_name` | Sampler type: `Flow`, `Flow_Unipc`, `Flow_DPM++` | `Flow` | +| `transformer_path` | Path to trained Transformer weights | `None` | +| `vae_path` | Path to trained VAE weights | `None` | +| `lora_path` | LoRA weights path | `None` | +| `sample_size` | Generated image resolution `[height, width]` | `[1728, 992]` | +| `weight_dtype` | Model weight precision, use `torch.float16` for GPUs without bf16 support | `torch.bfloat16` | +| `prompt` | Positive prompt describing the content to generate | `"a young girl..."` | +| `negative_prompt` | Negative prompt for content to avoid | `"low resolution, low quality..."` | +| `guidance_scale` | Guidance strength, recommended 0.0 for Turbo model | 4.0 / 0.0 | +| `seed` | Random seed for reproducibility | 43 | +| `num_inference_steps` | Inference steps, can be greatly reduced for Turbo model | 25 / 9 | +| `lora_weight` | LoRA weight strength | 0.55 | +| `save_path` | Generated image save path | `samples/z-image-t2i` | + +**GPU Memory Mode Description**: + +| Mode | Description | VRAM Usage | +|------|------|---------| +| `model_full_load` | Load entire model to GPU | Highest | +| `model_full_load_and_qfloat8` | Full load + FP8 quantization | High | +| `model_cpu_offload` | Offload model to CPU after use | Medium | +| `model_cpu_offload_and_qfloat8` | CPU offload + FP8 quantization | Medium-Low | +| `model_group_offload` | Layer group offload between CPU/CUDA | Low | +| `sequential_cpu_offload` | Offload each layer individually (slowest) | Lowest | + +### 4.2 Single GPU Inference + +#### Z-Image (Standard) + +Run single GPU inference with: + +```bash +python examples/z_image/predict_t2i.py +``` + +Edit `examples/z_image/predict_t2i.py` according to your needs. 
For first-time inference, focus on these parameters. For other parameters, see the Inference Parameters section above. + +```python +# Choose based on your GPU VRAM +GPU_memory_mode = "model_cpu_offload" +# Your actual model path +model_name = "models/Diffusion_Transformer/Z-Image" +# Trained weights path, e.g. "output_dir_z_image/checkpoint-xxx/diffusion_pytorch_model.safetensors" +transformer_path = None +# Write based on content to generate +prompt = "a young girl with flowing long hair, wearing a white halter dress" +# ... +``` + +#### Z-Image-Turbo (Fast) + +Run single GPU inference with: + +```bash +python examples/z_image/predict_turbo_t2i.py +``` + +Edit `examples/z_image/predict_turbo_t2i.py` according to your needs. For first-time inference, focus on these parameters. For other parameters, see the Inference Parameters section above. + +```python +# Choose based on your GPU VRAM +GPU_memory_mode = "model_cpu_offload" +# Your actual model path +model_name = "models/Diffusion_Transformer/Z-Image-Turbo" +# Trained weights path, e.g. "output_dir_z_image_turbo/checkpoint-xxx/diffusion_pytorch_model.safetensors" +transformer_path = None +# Write based on content to generate +prompt = "a young girl with flowing long hair, wearing a white halter dress" +# ... +``` + +### 4.3 Multi-GPU Parallel Inference + +**Suitable for**: High-resolution generation, accelerated inference + +#### Install Parallel Inference Dependencies + +```bash +pip install xfuser==0.4.2 yunchang==0.6.2 +``` + +#### Configure Parallel Strategy + +Edit `examples/z_image/predict_t2i.py`: + +```python +# Ensure ulysses_degree × ring_degree = number of GPUs +# For example, using 2 GPUs: +ulysses_degree = 2 # Head dimension parallelization +ring_degree = 1 # Sequence dimension parallelization +``` + +**Configuration Principles**: +- `ulysses_degree` must evenly divide the model's number of heads +- `ring_degree` splits on sequence dimension, affecting communication overhead; avoid using it when heads can be divided + +**Example Configurations**: + +| GPU Count | ulysses_degree | ring_degree | Description | +|---------|---------------|-------------|------| +| 1 | 1 | 1 | Single GPU | +| 4 | 4 | 1 | Head parallelization | +| 8 | 8 | 1 | Head parallelization | +| 8 | 4 | 2 | Hybrid parallelization | + +#### Run Multi-GPU Inference + +```bash +torchrun --nproc-per-node=2 examples/z_image/predict_t2i.py +``` + +## 5. Additional Resources + +- **Official GitHub**: https://github.com/aigc-apps/VideoX-Fun diff --git a/scripts/z_image/README_TRAIN_LORA.md b/scripts/z_image/README_TRAIN_LORA.md index 11bed67b..aab56f51 100644 --- a/scripts/z_image/README_TRAIN_LORA.md +++ b/scripts/z_image/README_TRAIN_LORA.md @@ -1,34 +1,248 @@ -## Lora Training Code +# Z-Image LoRA Fine-Tuning Training Guide -We can choose whether to use deepspeed or fsdp in z_image, which can save a lot of video memory. +This document provides a complete workflow for Z-Image LoRA fine-tuning training, including environment configuration, data preparation, multiple distributed training strategies, and inference testing. -Some parameters in the sh file can be confusing, and they are explained in this document: +> **Note**: Z-Image has two model variants: `Z-Image` (standard version) and `Z-Image-Turbo` (fast inference version). This guide uses `Z-Image` by default. To use `Z-Image-Turbo`, simply replace the model path accordingly. -- `enable_bucket` is used to enable bucket training. 
When enabled, the model does not crop the images at the center, but instead, it trains the entire images after grouping them into buckets based on resolution. -- `random_hw_adapt` is used to enable automatic height and width scaling for images. When `random_hw_adapt` is enabled, the training images will have their height and width set to `image_sample_size` as the maximum and `512` as the minimum. - - For example, when `random_hw_adapt` is enabled, `image_sample_size=1024`, the resolution of image inputs for training is `512x512` to `1024x1024` -- `resume_from_checkpoint` is used to set the training should be resumed from a previous checkpoint. Use a path or `"latest"` to automatically select the last available checkpoint. -- `target_name` represents the components/modules to which LoRA will be applied, separated by commas. -- `use_peft_lora` indicates whether to use the PEFT module for adding LoRA. Using this module will be more memory-efficient. -- `rank` means the dimension of the LoRA update matrices. -- `network_alpha` means the scale of the LoRA update matrices. +--- -When train model with multi machines, please set the params as follows: -```sh -export MASTER_ADDR="your master address" -export MASTER_PORT=10086 -export WORLD_SIZE=1 # The number of machines -export NUM_PROCESS=8 # The number of processes, such as WORLD_SIZE * 8 -export RANK=0 # The rank of this machine +## Table of Contents +- [1. Environment Configuration](#1-environment-configuration) +- [2. Data Preparation](#2-data-preparation) + - [2.1 Quick Test Dataset](#21-quick-test-dataset) + - [2.2 Dataset Structure](#22-dataset-structure) + - [2.3 metadata.json Format](#23-metadatajson-format) + - [2.4 Relative vs Absolute Path Usage](#24-relative-vs-absolute-path-usage) +- [3. LoRA Training](#3-lora-training) + - [3.1 Download Pretrained Model](#31-download-pretrained-model) + - [3.2 Quick Start (DeepSpeed-Zero-2)](#32-quick-start-deepspeed-zero-2) + - [3.3 LoRA-Specific Parameters](#33-lora-specific-parameters) + - [3.4 Training with FSDP](#34-training-with-fsdp) + - [3.5 Other Backends](#35-other-backends) + - [3.6 Multi-Machine Distributed Training](#36-multi-machine-distributed-training) +- [4. Inference Testing](#4-inference-testing) + - [4.1 Inference Parameter Parsing](#41-inference-parameter-parsing) + - [4.2 Single GPU Inference](#42-single-gpu-inference) + - [4.3 Multi-GPU Parallel Inference](#43-multi-gpu-parallel-inference) +- [5. Additional Resources](#5-additional-resources) + +--- + +## 1. 
Environment Configuration + +**Method 1: Using requirements.txt** + +```bash +pip install -r requirements.txt +``` + +**Method 2: Manual Dependency Installation** + +```bash +pip install Pillow einops safetensors timm tomesd librosa "torch>=2.1.2" torchdiffeq torchsde decord datasets numpy scikit-image +pip install omegaconf SentencePiece imageio[ffmpeg] imageio[pyav] tensorboard beautifulsoup4 ftfy func_timeout onnxruntime +pip install "peft>=0.17.0" "accelerate>=0.25.0" "gradio>=3.41.2" "diffusers>=0.30.1" "transformers>=4.46.2" +pip install yunchang xfuser modelscope openpyxl deepspeed==0.17.0 numpy==1.26.4 +pip uninstall opencv-python opencv-contrib-python opencv-python-headless -y +pip install opencv-python-headless +``` + +**Method 3: Using Docker** + +When using Docker, please ensure that the GPU driver and CUDA environment are correctly installed on your machine, then execute the following commands: + +``` +# pull image +docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun + +# enter image +docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun +``` + +--- + +## 2. Data Preparation + +### 2.1 Quick Test Dataset + +We provide a test dataset containing several training samples. + +```bash +# Download official demo dataset +modelscope download --dataset PAI/X-Fun-Images-Demo --local_dir ./datasets/X-Fun-Images-Demo +``` + +### 2.2 Dataset Structure -accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK scripts/xxx/xxx.py ``` +📦 datasets/ +├── 📂 my_dataset/ +│ ├── 📂 train/ +│ │ ├── 📄 image001.jpg +│ │ ├── 📄 image002.png +│ │ └── 📄 ... 
+│ └── 📄 metadata.json +``` + +### 2.3 metadata.json Format + +**Relative Path Format** (example): +```json +[ + { + "file_path": "train/image001.jpg", + "text": "A beautiful sunset over the ocean, golden hour lighting", + "width": 1024, + "height": 1024 + }, + { + "file_path": "train/image002.png", + "text": "Portrait of a young woman, studio lighting, high quality", + "width": 1328, + "height": 1328 + } +] +``` + +**Absolute Path Format**: +```json +[ + { + "file_path": "/mnt/data/images/sunset.jpg", + "text": "A beautiful sunset over the ocean", + "width": 1024, + "height": 1024 + } +] +``` + +**Key Fields Description**: +- `file_path`: Image path (relative or absolute) +- `text`: Image description (English prompt) +- `width` / `height`: Image dimensions (**recommended** to provide for bucket training; if not provided, they will be automatically read during training, which may slow down training when data is stored on slow systems like OSS) + - You can use `scripts/process_json_add_width_and_height.py` to add width and height fields to JSON files without these fields, supporting both images and videos + - Usage: `python scripts/process_json_add_width_and_height.py --input_file datasets/X-Fun-Images-Demo/metadata.json --output_file datasets/X-Fun-Images-Demo/metadata_add_width_height.json` + +### 2.4 Relative vs Absolute Path Usage + +**Relative Paths**: + +If your data uses relative paths, configure the training script as follows: + +```bash +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +``` + +**Absolute Paths**: + +If your data uses absolute paths, configure the training script as follows: + +```bash +export DATASET_NAME="" +export DATASET_META_NAME="/mnt/data/metadata.json" +``` + +> 💡 **Recommendation**: If the dataset is small and stored locally, use relative paths. If the dataset is stored on external storage (e.g., NAS, OSS) or shared across multiple machines, use absolute paths. + +--- + +## 3. LoRA Training + +### 3.1 Download Pretrained Model + +```bash +# Create model directory +mkdir -p models/Diffusion_Transformer -Without deepspeed: +# Download Z-Image official weights +modelscope download --model Tongyi-MAI/Z-Image --local_dir models/Diffusion_Transformer/Z-Image + +# (Optional) Download Z-Image-Turbo fast inference version +modelscope download --model Tongyi-MAI/Z-Image-Turbo --local_dir models/Diffusion_Transformer/Z-Image-Turbo +``` + +### 3.2 Quick Start (DeepSpeed-Zero-2) + +If you have downloaded the data as per **2.1 Quick Test Dataset** and the weights as per **3.1 Download Pretrained Model**, you can directly copy and run the quick start command. + +DeepSpeed-Zero-2 and FSDP are recommended for training. Here we use DeepSpeed-Zero-2 as an example. + +The difference between DeepSpeed-Zero-2 and FSDP lies in whether the model weights are sharded. **If VRAM is insufficient when using multiple GPUs with DeepSpeed-Zero-2**, you can switch to FSDP. + +```bash +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
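+# To train the LoRA against Z-Image-Turbo instead, point MODEL_NAME at
+# models/Diffusion_Transformer/Z-Image-Turbo (downloaded in Section 3.1).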
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/z_image/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_z_image_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=64 \ + --network_alpha=64 \ + --target_name="to_q,to_k,to_v,feed_forward.w1,feed_forward.w2,feed_forward.w3" \ + --use_peft_lora \ + --uniform_sampling +``` + +### 3.3 LoRA-Specific Parameters + +**LoRA Key Parameter Descriptions**: + +| Parameter | Description | Example Value | +|-----|------|-------| +| `--pretrained_model_name_or_path` | Path to pretrained model | `models/Diffusion_Transformer/Z-Image` | +| `--train_data_dir` | Training data directory | `datasets/internal_datasets/` | +| `--train_data_meta` | Training data metadata file | `datasets/internal_datasets/metadata.json` | +| `--train_batch_size` | Samples per batch | 1 | +| `--image_sample_size` | Maximum training resolution, auto bucketing | 1328 | +| `--gradient_accumulation_steps` | Gradient accumulation steps (equivalent to larger batch) | 1 | +| `--dataloader_num_workers` | DataLoader subprocesses | 8 | +| `--num_train_epochs` | Number of training epochs | 100 | +| `--checkpointing_steps` | Save checkpoint every N steps | 50 | +| `--learning_rate` | Initial learning rate (recommended for LoRA) | 1e-04 | +| `--lr_warmup_steps` | Learning rate warmup steps | 100 | +| `--seed` | Random seed (for reproducible training) | 42 | +| `--output_dir` | Output directory | `output_dir_z_image_lora` | +| `--gradient_checkpointing` | Enable activation checkpointing | - | +| `--mixed_precision` | Mixed precision: `fp16/bf16` | `bf16` | +| `--enable_bucket` | Enable bucket training: trains entire images grouped by resolution without center cropping | - | +| `--uniform_sampling` | Uniform timestep sampling (recommended) | - | +| `--resume_from_checkpoint` | Resume training from checkpoint path, use `"latest"` to auto-select latest | None | +| `--rank` | Dimension of LoRA update matrices (higher rank = stronger expressiveness but more VRAM usage) | 64 | +| `--network_alpha` | Scaling factor of LoRA update matrices (typically set to half of rank or same) | 64 | +| `--target_name` | Components/modules to apply LoRA, separated by commas | `to_q,to_k,to_v,feed_forward.w1,feed_forward.w2,feed_forward.w3` | +| `--use_peft_lora` | Use PEFT module for adding LoRA (more VRAM-efficient) | - | + +### 3.4 Training with FSDP + +**If VRAM is insufficient when using multiple GPUs with DeepSpeed-Zero-2**, you can switch to FSDP. + +> ✅ **Recommended**: FSDP has been thoroughly tested in this repository, with fewer errors and greater stability. -Training z_image without DeepSpeed may result in insufficient GPU memory. 
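+Regardless of whether you launch with DeepSpeed-Zero-2 (Section 3.2) or FSDP (below), the LoRA-specific flags `--rank`, `--network_alpha`, `--target_name` and `--use_peft_lora` map onto a standard PEFT LoRA configuration. The sketch below is illustrative only; it assumes the `peft` package from the requirements and is not the exact wiring used in `scripts/z_image/train_lora.py`.
+
+```python
+from peft import LoraConfig
+
+# Hypothetical sketch of the configuration implied by the training flags above.
+lora_config = LoraConfig(
+    r=64,             # --rank: dimension of the LoRA update matrices
+    lora_alpha=64,    # --network_alpha: scaling factor of the update
+    target_modules=[  # --target_name, split on commas
+        "to_q", "to_k", "to_v",
+        "feed_forward.w1", "feed_forward.w2", "feed_forward.w3",
+    ],
+)
+# In the standard LoRA formulation the update is applied as (lora_alpha / r) * B @ A,
+# so rank=64 with network_alpha=64 keeps the effective scale at 1.0, while
+# network_alpha=32 would halve it.
+```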
```sh -export MODEL_NAME="models/Diffusion_Transformer/Z-Image-Turbo" +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" export DATASET_NAME="datasets/internal_datasets/" export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. @@ -36,7 +250,7 @@ export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --mixed_precision="bf16" scripts/z_image/train_lora.py \ +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap=ZImageTransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/z_image/train_lora.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ @@ -48,7 +262,7 @@ accelerate launch --mixed_precision="bf16" scripts/z_image/train_lora.py \ --checkpointing_steps=50 \ --learning_rate=1e-04 \ --seed=42 \ - --output_dir="output_dir" \ + --output_dir="output_dir_z_image_lora" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -57,16 +271,29 @@ accelerate launch --mixed_precision="bf16" scripts/z_image/train_lora.py \ --max_grad_norm=0.05 \ --enable_bucket \ --rank=64 \ - --network_alpha=32 \ - --target_name="to_q,to_k,to_v,ff.0,ff.2,ff_context.0,ff_context.2" \ + --network_alpha=64 \ + --target_name="to_q,to_k,to_v,feed_forward.w1,feed_forward.w2,feed_forward.w3" \ --use_peft_lora \ --uniform_sampling ``` -With Deepspeed Zero-2: +### 3.5 Other Backends + +#### 3.5.1 Training with DeepSpeed-Zero-3 + +DeepSpeed Zero-3 is not highly recommended at the moment. In this repository, using FSDP has fewer errors and is more stable. + +DeepSpeed Zero-3: +After training, you can use the following command to get the final model: + +```sh +python scripts/zero_to_bf16.py output_dir/checkpoint-{our-num-steps} output_dir/checkpoint-{your-num-steps}-outputs --max_shard_size 80GB --safe_serialization +``` + +Training shell command: ```sh -export MODEL_NAME="models/Diffusion_Transformer/Z-Image-Turbo" +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" export DATASET_NAME="datasets/internal_datasets/" export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
@@ -74,7 +301,7 @@ export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/z_image/train_lora.py \ +accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag true --use_deepspeed --deepspeed_config_file config/zero_stage3_config.json --deepspeed_multinode_launcher standard scripts/z_image/train_lora.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ @@ -86,7 +313,7 @@ accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_con --checkpointing_steps=50 \ --learning_rate=1e-04 \ --seed=42 \ - --output_dir="output_dir" \ + --output_dir="output_dir_z_image_lora" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -95,16 +322,18 @@ accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_con --max_grad_norm=0.05 \ --enable_bucket \ --rank=64 \ - --network_alpha=32 \ - --target_name="to_q,to_k,to_v,ff.0,ff.2,ff_context.0,ff_context.2" \ + --network_alpha=64 \ + --target_name="to_q,to_k,to_v,feed_forward.w1,feed_forward.w2,feed_forward.w3" \ --use_peft_lora \ --uniform_sampling ``` -With FSDP: +#### 3.5.2 Training Without DeepSpeed or FSDP + +**This approach is not recommended as it lacks VRAM-saving backends and may easily cause out-of-memory errors**. This is provided for reference only. ```sh -export MODEL_NAME="models/Diffusion_Transformer/Z-Image-Turbo" +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" export DATASET_NAME="datasets/internal_datasets/" export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
@@ -112,7 +341,57 @@ export DATASET_META_NAME="datasets/internal_datasets/metadata.json" # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap ZImageTransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/z_image/train_lora.py \ +accelerate launch --mixed_precision="bf16" scripts/z_image/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_z_image_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=64 \ + --network_alpha=64 \ + --target_name="to_q,to_k,to_v,feed_forward.w1,feed_forward.w2,feed_forward.w3" \ + --use_peft_lora \ + --uniform_sampling +``` + +### 3.6 Multi-Machine Distributed Training + +**Suitable for**: Ultra-large-scale datasets, faster training speed + +#### 3.6.1 Environment Configuration + +Assuming 2 machines with 8 GPUs each: + +**Machine 0 (Master)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Master machine IP +export MASTER_PORT=10086 +export WORLD_SIZE=2 # Total number of machines +export NUM_PROCESS=16 # Total processes = machines × 8 +export RANK=0 # Current machine rank (0 or 1) +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
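+# Sanity check: NUM_PROCESS should equal WORLD_SIZE multiplied by the GPUs per machine;
+# with nvidia-smi available it can be derived as:
+# export NUM_PROCESS=$(( WORLD_SIZE * $(nvidia-smi -L | wc -l) ))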
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/z_image/train_lora.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ @@ -124,7 +403,7 @@ accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TR --checkpointing_steps=50 \ --learning_rate=1e-04 \ --seed=42 \ - --output_dir="output_dir" \ + --output_dir="output_dir_z_image_lora" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ @@ -133,8 +412,174 @@ accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TR --max_grad_norm=0.05 \ --enable_bucket \ --rank=64 \ - --network_alpha=32 \ - --target_name="to_q,to_k,to_v,ff.0,ff.2,ff_context.0,ff_context.2" \ + --network_alpha=64 \ + --target_name="to_q,to_k,to_v,feed_forward.w1,feed_forward.w2,feed_forward.w3" \ --use_peft_lora \ --uniform_sampling -``` \ No newline at end of file +``` + +**Machine 1 (Worker)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Same as Master +export MASTER_PORT=10086 +export WORLD_SIZE=2 +export NUM_PROCESS=16 +export RANK=1 # Note this is 1 +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +# Use the same accelerate launch command as Machine 0 +``` + +#### 3.6.2 Multi-Machine Training Notes + +- **Network Requirements**: + - RDMA/InfiniBand recommended (high performance) + - Without RDMA, add environment variables: + ```bash + export NCCL_IB_DISABLE=1 + export NCCL_P2P_DISABLE=1 + ``` + +- **Data Synchronization**: All machines must be able to access the same data paths (NFS/shared storage) + +--- + +## 4. 
Inference Testing + +### 4.1 Inference Parameter Parsing + +**Key Parameter Descriptions**: + +| Parameter | Description | Example Value | +|------|------|-------| +| `GPU_memory_mode` | VRAM management mode, see table below for options | `model_cpu_offload` | +| `ulysses_degree` | Head dimension parallelism degree, set to 1 for single GPU | 1 | +| `ring_degree` | Sequence dimension parallelism degree, set to 1 for single GPU | 1 | +| `fsdp_dit` | Use FSDP for Transformer during multi-GPU inference to save VRAM | `False` | +| `fsdp_text_encoder` | Use FSDP for text encoder during multi-GPU inference | `False` | +| `compile_dit` | Compile Transformer for faster inference (effective at fixed resolution) | `False` | +| `model_name` | Model path | `models/Diffusion_Transformer/Z-Image` | +| `sampler_name` | Sampler type: `Flow`, `Flow_Unipc`, `Flow_DPM++` | `Flow` | +| `transformer_path` | Path to load trained Transformer weights | `None` | +| `vae_path` | Path to load trained VAE weights | `None` | +| `lora_path` | LoRA weights path | `None` | +| `sample_size` | Generated image resolution `[height, width]` | `[1728, 992]` | +| `weight_dtype` | Model weight precision, use `torch.float16` for GPUs without bf16 support | `torch.bfloat16` | +| `prompt` | Positive prompt describing the generation content | `"A young woman..."` | +| `negative_prompt` | Negative prompt for content to avoid | `"low resolution, low quality..."` | +| `guidance_scale` | Guidance strength, recommended 0.0 for Turbo model | 4.0 / 0.0 | +| `seed` | Random seed for reproducible results | 43 | +| `num_inference_steps` | Number of inference steps, can be greatly reduced for Turbo model | 25 / 9 | +| `lora_weight` | LoRA weight strength | 0.55 | +| `save_path` | Path to save generated images | `samples/z-image-t2i` | + +**VRAM Management Mode Description**: + +| Mode | Description | VRAM Usage | +|------|------|---------| +| `model_full_load` | Load entire model to GPU | Highest | +| `model_full_load_and_qfloat8` | Full load + FP8 quantization | High | +| `model_cpu_offload` | Offload model to CPU after use | Medium | +| `model_cpu_offload_and_qfloat8` | CPU offload + FP8 quantization | Medium-Low | +| `model_group_offload` | Layer groups switch between CPU/CUDA | Low | +| `sequential_cpu_offload` | Sequential layer offload (slowest) | Lowest | + +### 4.2 Single GPU Inference + +#### Z-Image (Standard Version) + +Run the following command for single GPU inference: + +```bash +python examples/z_image/predict_t2i.py +``` + +Edit `examples/z_image/predict_t2i.py` according to your needs. For first-time inference, focus on these parameters. For other parameters, refer to the inference parameter parsing above. + +```python +# Choose based on GPU VRAM +GPU_memory_mode = "model_cpu_offload" +# Based on actual model path +model_name = "models/Diffusion_Transformer/Z-Image" +# LoRA weights path, e.g., "output_dir_z_image_lora/checkpoint-xxx/lora_weights.safetensors" +lora_path = None +# LoRA weight strength +lora_weight = 0.55 +# Write based on generation content +prompt = "A young woman standing on a sunny coastline, her white dress gently fluttering in the sea breeze." +# ... +``` + +#### Z-Image-Turbo (Fast Version) + +Run the following command for single GPU inference: + +```bash +python examples/z_image/predict_turbo_t2i.py +``` + +Edit `examples/z_image/predict_turbo_t2i.py` according to your needs. For first-time inference, focus on these parameters. For other parameters, refer to the inference parameter parsing above. 
+ +```python +# Choose based on GPU VRAM +GPU_memory_mode = "model_cpu_offload" +# Based on actual model path +model_name = "models/Diffusion_Transformer/Z-Image-Turbo" +# LoRA weights path, e.g., "output_dir_z_image_lora/checkpoint-xxx/lora_weights.safetensors" +lora_path = None +# LoRA weight strength +lora_weight = 0.55 +# Write based on generation content +prompt = "A young woman standing on a sunny coastline, her white dress gently fluttering in the sea breeze." +# ... +``` + +### 4.3 Multi-GPU Parallel Inference + +**Suitable for**: High-resolution generation, accelerated inference + +#### Install Parallel Inference Dependencies + +```bash +pip install xfuser==0.4.2 yunchang==0.6.2 +``` + +#### Configure Parallel Strategy + +Edit `examples/z_image/predict_t2i.py`: + +```python +# Ensure ulysses_degree × ring_degree = number of GPUs +# For example, using 2 GPUs: +ulysses_degree = 2 # Head dimension parallelization +ring_degree = 1 # Sequence dimension parallelization +``` + +**Configuration Principles**: +- `ulysses_degree` must evenly divide the model's number of heads +- `ring_degree` splits on sequence dimension, affecting communication overhead; avoid using it when heads can be divided + +**Example Configurations**: + +| GPU Count | ulysses_degree | ring_degree | Description | +|---------|---------------|-------------|------| +| 1 | 1 | 1 | Single GPU | +| 4 | 4 | 1 | Head parallelization | +| 8 | 8 | 1 | Head parallelization | +| 8 | 4 | 2 | Hybrid parallelization | + +#### Run Multi-GPU Inference + +```bash +torchrun --nproc-per-node=2 examples/z_image/predict_t2i.py +``` + +## 5. Additional Resources + +- **Official GitHub**: https://github.com/aigc-apps/VideoX-Fun diff --git a/scripts/z_image/README_TRAIN_LORA_zh-CN.md b/scripts/z_image/README_TRAIN_LORA_zh-CN.md new file mode 100644 index 00000000..5f59af35 --- /dev/null +++ b/scripts/z_image/README_TRAIN_LORA_zh-CN.md @@ -0,0 +1,585 @@ +# Z-Image LoRA 微调训练指南 + +本文档提供 Z-Image LoRA 微调训练的完整流程,包括环境配置、数据准备、多种分布式训练策略和推理测试。 + +> **说明**:Z-Image 有两个模型变体:`Z-Image`(标准版)和 `Z-Image-Turbo`(快速推理版)。本指南默认使用 `Z-Image`,如需使用 `Z-Image-Turbo`,替换对应的模型路径即可。 + +--- + +## 目录 +- [一、环境配置](#一环境配置) +- [二、数据准备](#二数据准备) + - [2.1 快速测试数据集](#21-快速测试数据集) + - [2.2 数据集结构](#22-数据集结构) + - [2.3 metadata.json 格式](#23-metadatajson-格式) + - [2.4 相对路径与绝对路径使用方案](#24-相对路径与绝对路径使用方案) +- [三、LoRA 训练](#三lora-训练) + - [3.1 下载预训练模型](#31-下载预训练模型) + - [3.2 快速开始(DeepSpeed-Zero-2)](#32-快速开始deepspeed-zero-2) + - [3.3 LoRA 专用参数解析](#33-lora-专用参数解析) + - [3.4 使用 FSDP 训练](#34-使用-fsdp-训练) + - [3.5 其他后端](#35-其他后端) + - [3.6 多机分布式训练](#36-多机分布式训练) +- [四、推理测试](#四推理测试) + - [4.1 推理参数解析](#41-推理参数解析) + - [4.2 单卡推理](#42-单卡推理) + - [4.3 多卡并行推理](#43-多卡并行推理) +- [五、更多资源](#五更多资源) + +--- + +## 一、环境配置 + +**方式 1:使用requirements.txt** + +```bash +pip install -r requirements.txt +``` + +**方式 2:手动安装依赖** + +```bash +pip install Pillow einops safetensors timm tomesd librosa "torch>=2.1.2" torchdiffeq torchsde decord datasets numpy scikit-image +pip install omegaconf SentencePiece imageio[ffmpeg] imageio[pyav] tensorboard beautifulsoup4 ftfy func_timeout onnxruntime +pip install "peft>=0.17.0" "accelerate>=0.25.0" "gradio>=3.41.2" "diffusers>=0.30.1" "transformers>=4.46.2" +pip install yunchang xfuser modelscope openpyxl deepspeed==0.17.0 numpy==1.26.4 +pip uninstall opencv-python opencv-contrib-python opencv-python-headless -y +pip install opencv-python-headless +``` + +**方式 3:使用docker** + +使用docker的情况下,请保证机器中已经正确安装显卡驱动与CUDA环境,然后以此执行以下命令: + +``` +# pull image +docker pull 
mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun + +# enter image +docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun +``` + +--- + +## 二、数据准备 + +### 2.1 快速测试数据集 + +我们提供了一个测试的数据集,其中包含若干训练数据。 + +```bash +# 下载官方示例数据集 +modelscope download --dataset PAI/X-Fun-Images-Demo --local_dir ./datasets/X-Fun-Images-Demo +``` + +### 2.2 数据集结构 + +``` +📦 datasets/ +├── 📂 my_dataset/ +│ ├── 📂 train/ +│ │ ├── 📄 image001.jpg +│ │ ├── 📄 image002.png +│ │ └── 📄 ... +│ └── 📄 metadata.json +``` + +### 2.3 metadata.json 格式 + +**相对路径格式**(示例格式): +```json +[ + { + "file_path": "train/image001.jpg", + "text": "A beautiful sunset over the ocean, golden hour lighting", + "width": 1024, + "height": 1024 + }, + { + "file_path": "train/image002.png", + "text": "Portrait of a young woman, studio lighting, high quality", + "width": 1328, + "height": 1328 + } +] +``` + +**绝对路径格式**: +```json +[ + { + "file_path": "/mnt/data/images/sunset.jpg", + "text": "A beautiful sunset over the ocean", + "width": 1024, + "height": 1024 + } +] +``` + +**关键字段说明**: +- `file_path`:图片路径(相对或绝对路径) +- `text`:图片描述(英文提示词) +- `width` / `height`:图片宽高(**最好提供**,用于分桶训练,如果不提供则自动在训练时读取,当数据存储在如oss这样的速度较慢的系统上时,可能会影响训练速度)。 + - 可以使用`scripts/process_json_add_width_and_height.py`文件对无width与height字段的json进行提取,支持处理图片与视频。 + - 使用方案为`python scripts/process_json_add_width_and_height.py --input_file datasets/X-Fun-Images-Demo/metadata.json --output_file datasets/X-Fun-Images-Demo/metadata_add_width_height.json`。 + +### 2.4 相对路径与绝对路径使用方案 + +**相对路径**: + +如果数据的路径为相对路径,则在训练脚本中设置: + +```bash +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +``` + +**绝对路径**: + +如果数据的路径为绝对路径,则在训练脚本中设置: + +```bash +export DATASET_NAME="" +export DATASET_META_NAME="/mnt/data/metadata.json" +``` + +> 💡 **建议**:如果数据集较小且存储在本地,推荐使用相对路径;如果数据集存储在外部存储(如 NAS、OSS)或多个机器共享存储,推荐使用绝对路径。 + +--- + +## 三、LoRA 训练 + +### 3.1 下载预训练模型 + +```bash +# 创建模型目录 +mkdir -p models/Diffusion_Transformer + +# 下载 Z-Image 官方权重 +modelscope download --model Tongyi-MAI/Z-Image --local_dir models/Diffusion_Transformer/Z-Image + +# (可选)下载 Z-Image-Turbo 快速推理版 +modelscope download --model Tongyi-MAI/Z-Image-Turbo --local_dir models/Diffusion_Transformer/Z-Image-Turbo +``` + +### 3.2 快速开始(DeepSpeed-Zero-2) + +如果按照 **2.1 快速测试数据集下载数据** 与 **3.1 下载预训练模型下载权重**后,直接复制快速开始的启动指令进行启动。 + +推荐使用 DeepSpeed-Zero-2 与 FSDP 方案进行训练。这里使用 DeepSpeed-Zero-2 为例配置 shell 文件。 + +本文中 DeepSpeed-Zero-2 与 FSDP 的差别在于是否对模型权重进行分片,**如果使用多卡且使用 DeepSpeed-Zero-2 的情况下显存不足**,可以切换使用 FSDP 进行训练。 + +```bash +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/z_image/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_z_image_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=64 \ + --network_alpha=64 \ + --target_name="to_q,to_k,to_v,feed_forward.w1,feed_forward.w2,feed_forward.w3" \ + --use_peft_lora \ + --uniform_sampling +``` + +### 3.3 LoRA 专用参数解析 + +**LoRA 关键参数说明**: + +| 参数 | 说明 | 示例值 | +|-----|------|-------| +| `--pretrained_model_name_or_path` | 预训练模型路径 | `models/Diffusion_Transformer/Z-Image` | +| `--train_data_dir` | 训练数据目录 | `datasets/internal_datasets/` | +| `--train_data_meta` | 训练数据元文件 | `datasets/internal_datasets/metadata.json` | +| `--train_batch_size` | 每批次样本数 | 1 | +| `--image_sample_size` | 最大训练分辨率,代码会自动分桶 | 1328 | +| `--gradient_accumulation_steps` | 梯度累积步数(等效增大 batch) | 1 | +| `--dataloader_num_workers` | DataLoader 子进程数 | 8 | +| `--num_train_epochs` | 训练 epoch 数 | 100 | +| `--checkpointing_steps` | 每 N 步保存 checkpoint | 50 | +| `--learning_rate` | 初始学习率(LoRA 推荐值) | 1e-04 | +| `--lr_warmup_steps` | 学习率预热步数 | 100 | +| `--seed` | 随机种子(可复现训练) | 42 | +| `--output_dir` | 输出目录 | `output_dir_z_image_lora` | +| `--gradient_checkpointing` | 激活重计算 | - | +| `--mixed_precision` | 混合精度:`fp16/bf16` | `bf16` | +| `--enable_bucket` | 启用分桶训练,不裁剪图片,按分辨率分组训练整个图像 | - | +| `--uniform_sampling` | 均匀采样 timestep(推荐启用) | - | +| `--resume_from_checkpoint` | 恢复训练路径,使用 `"latest"` 自动选择最新 checkpoint | None | +| `--rank` | LoRA 更新矩阵的维度(rank 越大表达能力越强,但显存占用越高) | 64 | +| `--network_alpha` | LoRA 更新矩阵的缩放系数(通常设置为 rank 的一半或相同) | 64 | +| `--target_name` | 应用 LoRA 的组件/模块,用逗号分隔 | `to_q,to_k,to_v,feed_forward.w1,feed_forward.w2,feed_forward.w3` | +| `--use_peft_lora` | 使用 PEFT 模块添加 LoRA(更节省显存) | - | + +### 3.4 使用 FSDP 训练 + +**如果使用多卡且使用 DeepSpeed-Zero-2 的情况下显存不足**,可以切换使用 FSDP 进行训练。 + +> ✅ **推荐**:FSDP 在当前仓库中经过充分测试,错误更少、更稳定。 + +```sh +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap=ZImageTransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/z_image/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_z_image_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=64 \ + --network_alpha=64 \ + --target_name="to_q,to_k,to_v,feed_forward.w1,feed_forward.w2,feed_forward.w3" \ + --use_peft_lora \ + --uniform_sampling +``` + +### 3.5 其他后端 + +#### 3.5.1 使用DeepSpeed-Zero-3进行训练 + +目前不太推荐使用 DeepSpeed Zero-3。在本仓库中,使用 FSDP 出错更少且更稳定。 + +DeepSpeed Zero-3: + +训练完成后,您可以使用以下命令获取最终模型: + +```sh +python scripts/zero_to_bf16.py output_dir/checkpoint-{our-num-steps} output_dir/checkpoint-{your-num-steps}-outputs --max_shard_size 80GB --safe_serialization +``` + +执行命令为: +```sh +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag true --use_deepspeed --deepspeed_config_file config/zero_stage3_config.json --deepspeed_multinode_launcher standard scripts/z_image/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_z_image_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=64 \ + --network_alpha=64 \ + --target_name="to_q,to_k,to_v,feed_forward.w1,feed_forward.w2,feed_forward.w3" \ + --use_peft_lora \ + --uniform_sampling +``` + +#### 3.5.2 不使用 DeepSpeed 与 FSDP 训练 + +**该方案并不被推荐,因为没有显存节约后端,容易造成显存不足**。这里仅提供训练 Shell 用于参考训练。 + +```sh +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" scripts/z_image/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_z_image_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=64 \ + --network_alpha=64 \ + --target_name="to_q,to_k,to_v,feed_forward.w1,feed_forward.w2,feed_forward.w3" \ + --use_peft_lora \ + --uniform_sampling +``` + +### 3.6 多机分布式训练 + +**适合场景**:超大规模数据集、需要更快的训练速度 + +#### 3.6.1 环境配置 + +假设有 2 台机器,每台 8 张 GPU: + +**机器 0(Master)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Master 机器 IP +export MASTER_PORT=10086 +export WORLD_SIZE=2 # 机器总数 +export NUM_PROCESS=16 # 总进程数 = 机器数 × 8 +export RANK=0 # 当前机器 rank(0 或 1) +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/z_image/train_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_z_image_lora" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=64 \ + --network_alpha=64 \ + --target_name="to_q,to_k,to_v,feed_forward.w1,feed_forward.w2,feed_forward.w3" \ + --use_peft_lora \ + --uniform_sampling +``` + +**机器 1(Worker)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # 与 Master 相同 +export MASTER_PORT=10086 +export WORLD_SIZE=2 +export NUM_PROCESS=16 +export RANK=1 # 注意这里是 1 +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +# 使用与机器 0 相同的 accelerate launch 命令 +``` + +#### 3.6.2 多机训练注意事项 + +- **网络要求**: + - 推荐 RDMA/InfiniBand(高性能) + - 无 RDMA 时添加环境变量: + ```bash + export NCCL_IB_DISABLE=1 + export NCCL_P2P_DISABLE=1 + ``` + +- **数据同步**:所有机器必须能够访问相同的数据路径(NFS/共享存储) + +--- + +## 四、推理测试 + +### 4.1 推理参数解析 + +**关键参数说明**: + +| 参数 | 说明 | 示例值 | +|------|------|-------| +| `GPU_memory_mode` | 显存管理模式,可选值见下表 | `model_cpu_offload` | +| `ulysses_degree` | Head 维度并行度,单卡时为 1 | 1 | +| `ring_degree` | Sequence 维度并行度,单卡时为 1 | 1 | +| `fsdp_dit` | 多卡推理时对 Transformer 使用 FSDP 节省显存 | `False` | +| `fsdp_text_encoder` | 多卡推理时对文本编码器使用 FSDP | `False` | +| `compile_dit` | 编译 Transformer 加速推理(固定分辨率下有效) | `False` | +| `model_name` | 模型路径 | `models/Diffusion_Transformer/Z-Image` | +| `sampler_name` | 采样器类型:`Flow`、`Flow_Unipc`、`Flow_DPM++` | `Flow` | +| `transformer_path` | 加载训练好的 Transformer 权重路径 | `None` | +| `vae_path` | 加载训练好的 VAE 权重路径 | `None` | +| `lora_path` | LoRA 权重路径 | `None` | +| `sample_size` | 生成图像分辨率 `[高度, 宽度]` | `[1728, 992]` | +| `weight_dtype` | 模型权重精度,不支持 bf16 的显卡使用 `torch.float16` | `torch.bfloat16` | +| `prompt` | 正向提示词,描述生成内容 | `"一位年轻女子..."` | +| `negative_prompt` | 负向提示词,避免生成的内容 | `"低分辨率,低画质..."` | +| `guidance_scale` | 引导强度,Turbo 模型建议设为 0.0 | 4.0 / 0.0 | +| `seed` | 随机种子,用于复现结果 | 43 | +| `num_inference_steps` | 推理步数,Turbo 模型可大幅减少 | 25 / 9 | +| `lora_weight` | LoRA 权重强度 | 0.55 | +| `save_path` | 生成图像保存路径 | `samples/z-image-t2i` | + +**显存管理模式说明**: + +| 模式 | 说明 | 显存占用 | +|------|------|---------| +| `model_full_load` | 整个模型加载到 GPU | 最高 | +| `model_full_load_and_qfloat8` | 全量加载 + FP8 量化 | 高 | +| `model_cpu_offload` | 使用后将模型卸载到 CPU | 中等 | +| `model_cpu_offload_and_qfloat8` | CPU 卸载 + FP8 量化 | 中低 | +| `model_group_offload` | 层组在 CPU/CUDA 间切换 | 低 | +| `sequential_cpu_offload` | 逐层卸载(速度最慢) | 最低 | + +### 4.2 单卡推理 + +#### Z-Image(标准版) + +单卡推理运行如下命令: + +```bash +python examples/z_image/predict_t2i.py +``` + +根据需求修改编辑 `examples/z_image/predict_t2i.py`,初次推理重点关注如下参数,如果对其他参数感兴趣,请查看上方的推理参数解析。 + +```python +# 根据显卡显存选择 +GPU_memory_mode = "model_cpu_offload" +# 根据实际模型路径 +model_name = "models/Diffusion_Transformer/Z-Image" +# LoRA 权重路径,如 "output_dir_z_image_lora/checkpoint-xxx/lora_weights.safetensors" +lora_path = None +# LoRA 权重强度 +lora_weight = 0.55 +# 根据生成内容编写 +prompt = "一位年轻女子站在阳光明媚的海岸线上,白裙在轻拂的海风中微微飘动。" +# ... +``` + +#### Z-Image-Turbo(快速版) + +单卡推理运行如下命令: + +```bash +python examples/z_image/predict_turbo_t2i.py +``` + +根据需求修改编辑 `examples/z_image/predict_turbo_t2i.py`,初次推理重点关注如下参数,如果对其他参数感兴趣,请查看上方的推理参数解析。 + +```python +# 根据显卡显存选择 +GPU_memory_mode = "model_cpu_offload" +# 根据实际模型路径 +model_name = "models/Diffusion_Transformer/Z-Image-Turbo" +# LoRA 权重路径,如 "output_dir_z_image_lora/checkpoint-xxx/lora_weights.safetensors" +lora_path = None +# LoRA 权重强度 +lora_weight = 0.55 +# 根据生成内容编写 +prompt = "一位年轻女子站在阳光明媚的海岸线上,白裙在轻拂的海风中微微飘动。" +# ... 
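+# Turbo-specific settings, taken from the parameter table in Section 4.1 (assumption:
+# the script exposes these variable names exactly as listed in that table): the Turbo
+# model is meant to run with guidance disabled and far fewer sampling steps.
+guidance_scale = 0.0
+num_inference_steps = 9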
+``` + +### 4.3 多卡并行推理 + +**适合场景**:高分辨率生成、加速推理 + +#### 安装并行推理依赖 + +```bash +pip install xfuser==0.4.2 yunchang==0.6.2 +``` + +#### 配置并行策略 + +编辑 `examples/z_image/predict_t2i.py`: + +```python +# 确保 ulysses_degree × ring_degree = GPU 数量 +# 例如使用 2 张 GPU: +ulysses_degree = 2 # Head 维度并行 +ring_degree = 1 # Sequence 维度并行 +``` + +**配置原则**: +- `ulysses_degree` 必须能整除模型的head数。 +- `ring_degree` 会在sequence上切分,影响通信开销,在head数能切分的时候尽量不用。 + +**示例配置**: + +| GPU 数量 | ulysses_degree | ring_degree | 说明 | +|---------|---------------|-------------|------| +| 1 | 1 | 1 | 单卡 | +| 4 | 4 | 1 | Head 并行 | +| 8 | 8 | 1 | Head 并行 | +| 8 | 4 | 2 | 混合并行 | + +#### 运行多卡推理 + +```bash +torchrun --nproc-per-node=2 examples/z_image/predict_t2i.py +``` + +## 五、更多资源 + +- **官方 GitHub**:https://github.com/aigc-apps/VideoX-Fun diff --git a/scripts/z_image/README_TRAIN_zh-CN.md b/scripts/z_image/README_TRAIN_zh-CN.md new file mode 100644 index 00000000..492b1839 --- /dev/null +++ b/scripts/z_image/README_TRAIN_zh-CN.md @@ -0,0 +1,576 @@ +# Z-Image 全量参数训练指南 + +本文档提供 Z-Image Diffusion Transformer 全量参数训练的完整流程,包括环境配置、数据准备、分布式训练和推理测试。 + +> **说明**:Z-Image 有两个模型变体:`Z-Image`(标准版)和 `Z-Image-Turbo`(快速推理版)。本指南默认使用 `Z-Image`,如需使用 `Z-Image-Turbo`,替换对应的模型路径即可。 + +--- + +## 目录 +- [一、环境配置](#一环境配置) +- [二、数据准备](#二数据准备) + - [2.1 快速测试数据集](#21-快速测试数据集) + - [2.2 数据集结构](#22-数据集结构) + - [2.3 metadata.json 格式](#23-metadatajson-格式) + - [2.4 相对路径与绝对路径使用方案](#24-相对路径与绝对路径使用方案) +- [三、全量参数训练](#三全量参数训练) + - [3.1 下载预训练模型](#31-下载预训练模型) + - [3.2 快速开始(DeepSpeed-Zero-2)](#32-快速开始deepspeed-zero-2) + - [3.3 训练常用参数解析](#33-训练常用参数解析) + - [3.4 使用 FSDP 训练](#34-使用-fsdp-训练) + - [3.5 其他后端](#35-其他后端) + - [3.6 多机分布式训练](#36-多机分布式训练) +- [四、推理测试](#四推理测试) + - [4.1 推理参数解析](#41-推理参数解析) + - [4.2 单卡推理](#42-单卡推理) + - [4.3 多卡并行推理](#43-多卡并行推理) +- [五、更多资源](#五更多资源) + +--- + +## 一、环境配置 + +**方式 1:使用requirements.txt** + +```bash +pip install -r requirements.txt +``` + +**方式 2:手动安装依赖** + +```bash +pip install Pillow einops safetensors timm tomesd librosa "torch>=2.1.2" torchdiffeq torchsde decord datasets numpy scikit-image +pip install omegaconf SentencePiece imageio[ffmpeg] imageio[pyav] tensorboard beautifulsoup4 ftfy func_timeout onnxruntime +pip install "peft>=0.17.0" "accelerate>=0.25.0" "gradio>=3.41.2" "diffusers>=0.30.1" "transformers>=4.46.2" +pip install yunchang xfuser modelscope openpyxl deepspeed==0.17.0 numpy==1.26.4 +pip uninstall opencv-python opencv-contrib-python opencv-python-headless -y +pip install opencv-python-headless +``` + +**方式 3:使用docker** + +使用docker的情况下,请保证机器中已经正确安装显卡驱动与CUDA环境,然后以此执行以下命令: + +``` +# pull image +docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun + +# enter image +docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun +``` + +--- + +## 二、数据准备 + +### 2.1 快速测试数据集 + +我们提供了一个测试的数据集,其中包含若干训练数据。 + +```bash +# 下载官方示例数据集 +modelscope download --dataset PAI/X-Fun-Images-Demo --local_dir ./datasets/X-Fun-Images-Demo +``` + +### 2.2 数据集结构 + +``` +📦 datasets/ +├── 📂 my_dataset/ +│ ├── 📂 train/ +│ │ ├── 📄 image001.jpg +│ │ ├── 📄 image002.png +│ │ └── 📄 ... 
+│ └── 📄 metadata.json +``` + +### 2.3 metadata.json 格式 + +**相对路径格式**(示例格式): +```json +[ + { + "file_path": "train/image001.jpg", + "text": "A beautiful sunset over the ocean, golden hour lighting", + "width": 1024, + "height": 1024 + }, + { + "file_path": "train/image002.png", + "text": "Portrait of a young woman, studio lighting, high quality", + "width": 1328, + "height": 1328 + } +] +``` + +**绝对路径格式**: +```json +[ + { + "file_path": "/mnt/data/images/sunset.jpg", + "text": "A beautiful sunset over the ocean", + "width": 1024, + "height": 1024 + } +] +``` + +**关键字段说明**: +- `file_path`:图片路径(相对或绝对路径) +- `text`:图片描述(英文提示词) +- `width` / `height`:图片宽高(**最好提供**,用于分桶训练,如果不提供则自动在训练时读取,当数据存储在如oss这样的速度较慢的系统上时,可能会影响训练速度)。 + - 可以使用`scripts/process_json_add_width_and_height.py`文件对无width与height字段的json进行提取,支持处理图片与视频。 + - 使用方案为`python scripts/process_json_add_width_and_height.py --input_file datasets/X-Fun-Images-Demo/metadata.json --output_file datasets/X-Fun-Images-Demo/metadata_add_width_height.json`。 + +### 2.4 相对路径与绝对路径使用方案 + +**相对路径**: + +如果数据的路径为相对路径,则在训练脚本中设置: + +```bash +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +``` + +**绝对路径**: + +如果数据的路径为绝对路径,则在训练脚本中设置: + +```bash +export DATASET_NAME="" +export DATASET_META_NAME="/mnt/data/metadata.json" +``` + +> 💡 **建议**:如果数据集较小且存储在本地,推荐使用相对路径;如果数据集存储在外部存储(如 NAS、OSS)或多个机器共享存储,推荐使用绝对路径。 + +--- + +## 三、全量参数训练 + +### 3.1 下载预训练模型 + +```bash +# 创建模型目录 +mkdir -p models/Diffusion_Transformer + +# 下载 Z-Image 官方权重 +modelscope download --model Tongyi-MAI/Z-Image --local_dir models/Diffusion_Transformer/Z-Image + +# (可选)下载 Z-Image-Turbo 快速推理版 +modelscope download --model Tongyi-MAI/Z-Image-Turbo --local_dir models/Diffusion_Transformer/Z-Image-Turbo +``` + +### 3.2 快速开始(DeepSpeed-Zero-2) + +如果按照 **2.1 快速测试数据集下载数据** 与 **3.1 下载预训练模型下载权重**后,直接复制快速开始的启动指令进行启动。 + +推荐使用DeepSpeed-Zero-2与FSDP方案进行训练。这里使用DeepSpeed-Zero-2为例配置shell文件。 + +本文中DeepSpeed-Zero-2与FSDP的差别在于是否对模型权重进行分片,**如果使用多卡且使用DeepSpeed-Zero-2的情况下显存不足**,可以切换使用FSDP进行训练。 + +```bash +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/z_image/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_z_image" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." 
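+# Optional flags listed in the parameter table in Section 3.3: append
+#   --resume_from_checkpoint="latest"
+# to resume an interrupted run from the most recent checkpoint, or
+#   --random_hw_adapt
+# to randomly rescale images within [512, image_sample_size] during training.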
+``` + +### 3.3 训练常用参数解析 + +**关键参数说明**: + +| 参数 | 说明 | 示例值 | +|-----|------|-------| +| `--pretrained_model_name_or_path` | 预训练模型路径 | `models/Diffusion_Transformer/Z-Image` | +| `--train_data_dir` | 训练数据目录 | `datasets/internal_datasets/` | +| `--train_data_meta` | 训练数据元文件 | `datasets/internal_datasets/metadata.json` | +| `--train_batch_size` | 每批次样本数 | 1 | +| `--image_sample_size` | 最大训练分辨率,代码会自动分桶 | 1328 | +| `--gradient_accumulation_steps` | 梯度累积步数(等效增大 batch) | 1 | +| `--dataloader_num_workers` | DataLoader 子进程数 | 8 | +| `--num_train_epochs` | 训练 epoch 数 | 100 | +| `--checkpointing_steps` | 每 N 步保存 checkpoint | 50 | +| `--learning_rate` | 初始学习率 | 2e-05 | +| `--lr_scheduler` | 学习率调度器 | `constant_with_warmup` | +| `--lr_warmup_steps` | 学习率预热步数 | 100 | +| `--seed` | 随机种子 | 42 | +| `--output_dir` | 输出目录 | `output_dir_z_image` | +| `--gradient_checkpointing` | 激活重计算 | - | +| `--mixed_precision` | 混合精度:`fp16/bf16` | `bf16` | +| `--adam_weight_decay` | AdamW 权重衰减 | 3e-2 | +| `--adam_epsilon` | AdamW epsilon 值 | 1e-10 | +| `--vae_mini_batch` | VAE 编码时的迷你批次大小 | 1 | +| `--max_grad_norm` | 梯度裁剪阈值 | 0.05 | +| `--enable_bucket` | 启用分桶训练,不裁剪图片,按分辨率分组训练整个图像 | - | +| `--random_hw_adapt` | 自动缩放图片到 `[512, image_sample_size]` 范围内的随机尺寸 | - | +| `--resume_from_checkpoint` | 恢复训练路径,使用 `"latest"` 自动选择最新 checkpoint | None | +| `--uniform_sampling` | 均匀采样 timestep | - | +| `--trainable_modules` | 可训练模块(`"."` 表示所有模块) | `"."` | + + +### 3.4 使用 FSDP 训练 + +**如果使用多卡且使用DeepSpeed-Zero-2的情况下显存不足**,可以切换使用FSDP进行训练。 + +```sh +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap ZImageTransformerBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/z_image/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_z_image" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` + +### 3.5 其他后端 + +#### 3.5.1 使用DeepSpeed-Zero-3进行训练 + +目前不太推荐使用 DeepSpeed Zero-3。在本仓库中,使用 FSDP 出错更少且更稳定。 + +DeepSpeed Zero-3: + +训练完成后,您可以使用以下命令获取最终模型: + +```sh +python scripts/zero_to_bf16.py output_dir/checkpoint-{our-num-steps} output_dir/checkpoint-{your-num-steps}-outputs --max_shard_size 80GB --safe_serialization +``` + +执行命令为: +```sh +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag true --use_deepspeed --deepspeed_config_file config/zero_stage3_config.json --deepspeed_multinode_launcher standard scripts/z_image/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_z_image" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` + +#### 3.5.2 不使用 DeepSpeed 与 FSDP 训练 + +**该方案并不被推荐,因为没有显存节约后端,容易造成显存不足**。这里仅提供训练Shell用于参考训练。 + +```sh +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" scripts/z_image/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_z_image" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` + +### 3.6 多机分布式训练 + +**适合场景**:超大规模数据集、需要更快的训练速度 + +#### 3.6.1 环境配置 + +假设有 2 台机器,每台 8 张 GPU: + +**机器 0(Master)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Master 机器 IP +export MASTER_PORT=10086 +export WORLD_SIZE=2 # 机器总数 +export NUM_PROCESS=16 # 总进程数 = 机器数 × 8 +export RANK=0 # 当前机器 rank(0 或 1) +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/z_image/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_z_image" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` + +**机器 1(Worker)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/Z-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # 与 Master 相同 +export MASTER_PORT=10086 +export WORLD_SIZE=2 +export NUM_PROCESS=16 +export RANK=1 # 注意这里是 1 +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +# 使用与机器 0 相同的 accelerate launch 命令 +``` + +#### 3.6.2 多机训练注意事项 + +- **网络要求**: + - 推荐 RDMA/InfiniBand(高性能) + - 无 RDMA 时添加环境变量: + ```bash + export NCCL_IB_DISABLE=1 + export NCCL_P2P_DISABLE=1 + ``` + +- **数据同步**:所有机器必须能够访问相同的数据路径(NFS/共享存储) + +## 四、推理测试 + +### 4.1 推理参数解析 + +**关键参数说明**: + +| 参数 | 说明 | 示例值 | +|------|------|-------| +| `GPU_memory_mode` | 显存管理模式,可选值见下表 | `model_cpu_offload` | +| `ulysses_degree` | Head 维度并行度,单卡时为 1 | 1 | +| `ring_degree` | Sequence 维度并行度,单卡时为 1 | 1 | +| `fsdp_dit` | 多卡推理时对 Transformer 使用 FSDP 节省显存 | `False` | +| `fsdp_text_encoder` | 多卡推理时对文本编码器使用 FSDP | `False` | +| `compile_dit` | 编译 Transformer 加速推理(固定分辨率下有效) | `False` | +| `model_name` | 模型路径 | `models/Diffusion_Transformer/Z-Image` | +| `sampler_name` | 采样器类型:`Flow`、`Flow_Unipc`、`Flow_DPM++` | `Flow` | +| `transformer_path` | 加载训练好的 Transformer 权重路径 | `None` | +| `vae_path` | 加载训练好的 VAE 权重路径 | `None` | +| `lora_path` | LoRA 权重路径 | `None` | +| `sample_size` | 生成图像分辨率 `[高度, 宽度]` | `[1728, 992]` | +| `weight_dtype` | 模型权重精度,不支持 bf16 的显卡使用 `torch.float16` | `torch.bfloat16` | +| `prompt` | 正向提示词,描述生成内容 | `"一位年轻女子..."` | +| `negative_prompt` | 负向提示词,避免生成的内容 | `"低分辨率,低画质..."` | +| `guidance_scale` | 引导强度,Turbo 模型建议设为 0.0 | 4.0 / 0.0 | +| `seed` | 随机种子,用于复现结果 | 43 | +| `num_inference_steps` | 推理步数,Turbo 模型可大幅减少 | 25 / 9 | +| `lora_weight` | LoRA 权重强度 | 0.55 | +| `save_path` | 生成图像保存路径 | `samples/z-image-t2i` | + +**显存管理模式说明**: + +| 模式 | 说明 | 显存占用 | +|------|------|---------| +| `model_full_load` | 整个模型加载到 GPU | 最高 | +| `model_full_load_and_qfloat8` | 全量加载 + FP8 量化 | 高 | +| `model_cpu_offload` | 使用后将模型卸载到 CPU | 中等 | +| `model_cpu_offload_and_qfloat8` | CPU 卸载 + FP8 量化 | 中低 | +| `model_group_offload` | 层组在 CPU/CUDA 间切换 | 低 | +| `sequential_cpu_offload` | 逐层卸载(速度最慢) | 最低 | + +### 4.2 单卡推理 + +#### Z-Image(标准版) + +单卡推理运行如下命令: + +```bash +python examples/z_image/predict_t2i.py +``` + +根据需求修改编辑 
`examples/z_image/predict_t2i.py`,初次推理重点关注如下参数,如果对其他参数感兴趣,请查看上方的推理参数解析。 + +```python +# 根据显卡显存选择 +GPU_memory_mode = "model_cpu_offload" +# 根据实际模型路径 +model_name = "models/Diffusion_Transformer/Z-Image-Turbo" +# 训练好的权重路径,如 "output_dir_z_image/checkpoint-xxx/diffusion_pytorch_model.safetensors" +transformer_path = None +# 根据生成内容编写 +prompt = "一位年轻女子站在阳光明媚的海岸线上,白裙在轻拂的海风中微微飘动。" +# ... +``` + +#### Z-Image-Turbo(快速版) + +单卡推理运行如下命令: + +```bash +python examples/z_image/predict_turbo_t2i.py +``` + +根据需求修改编辑 `examples/z_image/predict_turbo_t2i.py`,初次推理重点关注如下参数,如果对其他参数感兴趣,请查看上方的推理参数解析。 + +```python +# 根据显卡显存选择 +GPU_memory_mode = "model_cpu_offload" +# 根据实际模型路径 +model_name = "models/Diffusion_Transformer/Z-Image-Turbo" +# 训练好的权重路径,如 "output_dir_z_image/checkpoint-xxx/diffusion_pytorch_model.safetensors" +transformer_path = None +# 根据生成内容编写 +prompt = "一位年轻女子站在阳光明媚的海岸线上,白裙在轻拂的海风中微微飘动。" +# ... +``` + +### 4.3 多卡并行推理 + +**适合场景**:高分辨率生成、加速推理 + +#### 安装并行推理依赖 + +```bash +pip install xfuser==0.4.2 yunchang==0.6.2 +``` + +#### 配置并行策略 + +编辑 `examples/z_image/predict_t2i.py`: + +```python +# 确保 ulysses_degree × ring_degree = GPU 数量 +# 例如使用 2 张 GPU: +ulysses_degree = 2 # Head 维度并行 +ring_degree = 1 # Sequence 维度并行 +``` + +**配置原则**: +- `ulysses_degree` 必须能整除模型的head数。 +- `ring_degree` 会在sequence上切分,影响通信开销,在head数能切分的时候尽量不用。 + +**示例配置**: + +| GPU 数量 | ulysses_degree | ring_degree | 说明 | +|---------|---------------|-------------|------| +| 1 | 1 | 1 | 单卡 | +| 4 | 4 | 1 | Head 并行 | +| 8 | 8 | 1 | Head 并行 | +| 8 | 4 | 2 | 混合并行 | + +#### 运行多卡推理 + +```bash +torchrun --nproc-per-node=2 examples/z_image/predict_t2i.py +``` + +## 五、更多资源 + +- **官方 GitHub**:https://github.com/aigc-apps/VideoX-Fun diff --git a/videox_fun/dist/__init__.py b/videox_fun/dist/__init__.py index e69409e2..1fb11c84 100755 --- a/videox_fun/dist/__init__.py +++ b/videox_fun/dist/__init__.py @@ -13,6 +13,8 @@ xFuserLongContextAttention) from .hunyuanvideo_xfuser import HunyuanVideoMultiGPUsAttnProcessor2_0 from .infinitalk_xfuser import usp_attn_infinitetalk_forward +from .ltx2_xfuser import (LTX2MultiGPUsAttnProcessor, + LTX2PerturbedMultiGPUsAttnProcessor) from .qwen_xfuser import QwenImageMultiGPUsAttnProcessor2_0 from .wan_xfuser import usp_attn_forward, usp_attn_s2v_forward from .z_image_xfuser import ZMultiGPUsSingleStreamAttnProcessor diff --git a/videox_fun/dist/ltx2_xfuser.py b/videox_fun/dist/ltx2_xfuser.py new file mode 100644 index 00000000..f825bc6f --- /dev/null +++ b/videox_fun/dist/ltx2_xfuser.py @@ -0,0 +1,375 @@ +# Copyright 2025 The VideoX-Fun Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Multi-GPU sequence parallel attention processors for LTX2 transformer. +""" + +from typing import Tuple + +import torch + +from .fuser import xFuserLongContextAttention +from ..models.attention_utils import attention + + +class LTX2MultiGPUsAttnProcessor: + """ + Multi-GPU sequence parallel attention processor for LTX2. + Uses xFuserLongContextAttention for distributed attention computation. 
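+    For cross-attention (full-length audio queries over sequence-chunked video keys/values), the key and
+    value tensors are all-gathered across sequence-parallel ranks before running standard attention.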
+ """ + + def __init__(self): + if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): + raise ImportError( + "LTX2MultiGPUsAttnProcessor requires PyTorch 2.0 or later. " + "Please upgrade your PyTorch installation." + ) + + def __call__( + self, + attn, + hidden_states: torch.Tensor, + encoder_hidden_states=None, + attention_mask=None, + query_rotary_emb=None, + key_rotary_emb=None, + perturbation_mask=None, + all_perturbed=None, + ) -> torch.Tensor: + # Get sequence parallel info from attn module + all_gather = getattr(attn, 'all_gather', None) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + # Determine if this is self-attention or cross-attention + is_self_attn = encoder_hidden_states is None + + if is_self_attn: + # Self-attention: use hidden_states for both Q and KV + encoder_hidden_states = hidden_states + + # Calculate gate logits on original hidden_states if needed + if attn.to_gate_logits is not None: + gate_logits = attn.to_gate_logits(hidden_states) + + # Project to Q, K, V + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # Apply normalization + query = attn.norm_q(query) + key = attn.norm_k(key) + + # Apply rotary embeddings if provided + # Note: RoPE freqs are already generated for the chunked sequence in transformer forward + if query_rotary_emb is not None: + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + + if attn.rope_type == "interleaved": + # freqs shape: [B, S_chunked, D] + cos, sin = query_rotary_emb + + # Apply interleaved RoPE + query_real, query_imag = query.unflatten(3, (-1, 2)).unbind(-1) + query_rotated = torch.stack([-query_imag, query_real], dim=-1).flatten(3) + query = (query.float() * cos.unsqueeze(2) + query_rotated.float() * sin.unsqueeze(2)).to(query.dtype) + + # Same for key (fall back to query_rotary_emb for self-attention) + cos_k, sin_k = key_rotary_emb if key_rotary_emb is not None else query_rotary_emb + key_real, key_imag = key.unflatten(3, (-1, 2)).unbind(-1) + key_rotated = torch.stack([-key_imag, key_real], dim=-1).flatten(3) + key = (key.float() * cos_k.unsqueeze(2) + key_rotated.float() * sin_k.unsqueeze(2)).to(key.dtype) + + elif attn.rope_type == "split": + # Apply split RoPE using the same logic as apply_split_rotary_emb + # x: [B, S, H, D], freqs: [B, H, S, D//2] + cos, sin = query_rotary_emb + + # Save original dtype + query_dtype = query.dtype + + # Reshape query to match freqs dimensions + b, s, h, d = query.shape + # cos is (b, h, s, d//2) -> reshape query to (b, h, s, d) + query = query.reshape(b, s, h, -1).transpose(1, 2) # [B, H, S, D] + + # Split last dim into pairs + r = d // 2 + split_query = query.reshape(b, h, s, 2, r).float() # [B, H, S, 2, r] + first_x = split_query[..., :1, :] # [B, H, S, 1, r] + second_x = split_query[..., 1:, :] # [B, H, S, 1, r] + + # Apply rotation + cos_u = cos.unsqueeze(-2) # [B, H, S, 1, r//2] + sin_u = sin.unsqueeze(-2) + + first_out = first_x * cos_u - second_x * sin_u + second_out = second_x * cos_u + first_x * sin_u + + query = torch.cat([first_out, second_out], dim=-2).reshape(b, h, s, d) + query = query.transpose(1, 2).reshape(b, s, h, d).to(query_dtype) # [B, S, H, 
D] + + # Same for key (fall back to query_rotary_emb for self-attention) + cos_k, sin_k = key_rotary_emb if key_rotary_emb is not None else query_rotary_emb + key_dtype = key.dtype + b_k, s_k, h_k, d_k = key.shape + key = key.reshape(b_k, s_k, h_k, -1).transpose(1, 2) # [B, H, S, D] + + r_k = d_k // 2 + split_key = key.reshape(b_k, h_k, s_k, 2, r_k).float() + first_k = split_key[..., :1, :] + second_k = split_key[..., 1:, :] + + cos_k_u = cos_k.unsqueeze(-2) + sin_k_u = sin_k.unsqueeze(-2) + + first_k_out = first_k * cos_k_u - second_k * sin_k_u + second_k_out = second_k * cos_k_u + first_k * sin_k_u + + key = torch.cat([first_k_out, second_k_out], dim=-2).reshape(b_k, h_k, s_k, d_k) + key = key.transpose(1, 2).reshape(b_k, s_k, h_k, d_k).to(key_dtype) + else: + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + + value = value.unflatten(2, (attn.heads, -1)) + + # Use xFuserLongContextAttention for distributed attention + half_dtypes = (torch.float16, torch.bfloat16) + def half(x): + return x if x.dtype in half_dtypes else x.to(torch.bfloat16) + + if is_self_attn: + # Self-attention: Q, K, V are all chunked, use xFuser for communication + hidden_states = xFuserLongContextAttention()( + None, + half(query), half(key), half(value), + dropout_p=0.0, + causal=False, + ) + else: + # Video-to-audio cross-attention: Q=audio(full), K,V=video(chunked). + # Need to all_gather K,V across ranks before attention. + if all_gather is not None: + key = all_gather(key.contiguous(), dim=1) + value = all_gather(value.contiguous(), dim=1) + + # Regular attention with [B, S, H, D] layout + hidden_states = attention( + half(query), + half(key), + half(value), + dropout_p=0.0, + ) # [B, S_q, H, D] + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + # Apply gating if present + if attn.to_gate_logits is not None: + hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) # [B, T, H, D] + gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H] + hidden_states = hidden_states * gates.unsqueeze(-1) + hidden_states = hidden_states.flatten(2, 3) + + # Output projection + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class LTX2PerturbedMultiGPUsAttnProcessor: + """ + Multi-GPU sequence parallel attention processor with perturbation support for LTX2. + """ + + def __init__(self): + if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): + raise ImportError( + "LTX2PerturbedMultiGPUsAttnProcessor requires PyTorch 2.0 or later. " + "Please upgrade your PyTorch installation." 
+ ) + + def __call__( + self, + attn, + hidden_states: torch.Tensor, + encoder_hidden_states=None, + attention_mask=None, + query_rotary_emb=None, + key_rotary_emb=None, + perturbation_mask=None, + all_perturbed=None, + ) -> torch.Tensor: + # Get sequence parallel info + all_gather = getattr(attn, 'all_gather', None) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + is_self_attn = encoder_hidden_states is None + + if is_self_attn: + encoder_hidden_states = hidden_states + + # Calculate gate logits + if attn.to_gate_logits is not None: + gate_logits = attn.to_gate_logits(hidden_states) + + value = attn.to_v(encoder_hidden_states) + + # Check if all tokens are perturbed + if all_perturbed is None: + all_perturbed = torch.all(perturbation_mask == 0) if perturbation_mask is not None else False + + if all_perturbed: + # Skip attention, use value directly + hidden_states = value + else: + # Project to Q, K + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + + # Apply normalization + query = attn.norm_q(query) + key = attn.norm_k(key) + + # Apply RoPE (RoPE freqs are already generated for chunked sequence) + if query_rotary_emb is not None: + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + + if attn.rope_type == "interleaved": + cos, sin = query_rotary_emb + + query_real, query_imag = query.unflatten(3, (-1, 2)).unbind(-1) + query_rotated = torch.stack([-query_imag, query_real], dim=-1).flatten(3) + query = (query.float() * cos.unsqueeze(2) + query_rotated.float() * sin.unsqueeze(2)).to(query.dtype) + + # Fall back to query_rotary_emb for self-attention + cos_k, sin_k = key_rotary_emb if key_rotary_emb is not None else query_rotary_emb + key_real, key_imag = key.unflatten(3, (-1, 2)).unbind(-1) + key_rotated = torch.stack([-key_imag, key_real], dim=-1).flatten(3) + key = (key.float() * cos_k.unsqueeze(2) + key_rotated.float() * sin_k.unsqueeze(2)).to(key.dtype) + elif attn.rope_type == "split": + # Apply split RoPE + cos, sin = query_rotary_emb + + # Save original dtype + query_dtype = query.dtype + + b, s, h, d = query.shape + query = query.reshape(b, s, h, -1).transpose(1, 2) # [B, H, S, D] + + r = d // 2 + split_query = query.reshape(b, h, s, 2, r).float() + first_x = split_query[..., :1, :] + second_x = split_query[..., 1:, :] + + cos_u = cos.unsqueeze(-2) + sin_u = sin.unsqueeze(-2) + + first_out = first_x * cos_u - second_x * sin_u + second_out = second_x * cos_u + first_x * sin_u + + query = torch.cat([first_out, second_out], dim=-2).reshape(b, h, s, d) + query = query.transpose(1, 2).reshape(b, s, h, d).to(query_dtype) + + # Fall back to query_rotary_emb for self-attention + cos_k, sin_k = key_rotary_emb if key_rotary_emb is not None else query_rotary_emb + key_dtype = key.dtype + b_k, s_k, h_k, d_k = key.shape + key = key.reshape(b_k, s_k, h_k, -1).transpose(1, 2) + + r_k = d_k // 2 + split_key = key.reshape(b_k, h_k, s_k, 2, r_k).float() + first_k = split_key[..., :1, :] + second_k = split_key[..., 1:, :] + + cos_k_u = cos_k.unsqueeze(-2) + sin_k_u = sin_k.unsqueeze(-2) + + first_k_out = first_k * cos_k_u - second_k * sin_k_u + second_k_out = second_k * cos_k_u + first_k * sin_k_u + + key = torch.cat([first_k_out, 
second_k_out], dim=-2).reshape(b_k, h_k, s_k, d_k) + key = key.transpose(1, 2).reshape(b_k, s_k, h_k, d_k).to(key_dtype) + else: + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + + value = value.unflatten(2, (attn.heads, -1)) + + # Use xFuserLongContextAttention + half_dtypes = (torch.float16, torch.bfloat16) + def half(x): + return x if x.dtype in half_dtypes else x.to(torch.bfloat16) + + if is_self_attn: + hidden_states = xFuserLongContextAttention()( + None, + half(query), half(key), half(value), + dropout_p=0.0, + causal=False, + ) + else: + # Video-to-audio cross-attention: Q=audio(full), K,V=video(chunked). + # Need to all_gather K,V across ranks before attention. + if all_gather is not None: + key = all_gather(key.contiguous(), dim=1) + value = all_gather(value.contiguous(), dim=1) + + # Regular attention with [B, S, H, D] layout + hidden_states = attention( + half(query), + half(key), + half(value), + dropout_p=0.0, + ) # [B, S_q, H, D] + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + # Apply perturbation masking + if perturbation_mask is not None: + value = value.flatten(2, 3) + hidden_states = torch.lerp(value, hidden_states, perturbation_mask) + + # Apply gating + if attn.to_gate_logits is not None: + hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) + gates = 2.0 * torch.sigmoid(gate_logits) + hidden_states = hidden_states * gates.unsqueeze(-1) + hidden_states = hidden_states.flatten(2, 3) + + # Output projection + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states diff --git a/videox_fun/models/__init__.py b/videox_fun/models/__init__.py index 23eed907..b040d1bb 100755 --- a/videox_fun/models/__init__.py +++ b/videox_fun/models/__init__.py @@ -4,9 +4,9 @@ from transformers import (AutoProcessor, AutoTokenizer, CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, - Gemma3ForConditionalGeneration, GemmaTokenizer, - GemmaTokenizerFast, LlamaModel, LlamaTokenizerFast, - LlavaForConditionalGeneration, + Gemma3ForConditionalGeneration, Gemma3Processor, + GemmaTokenizer, GemmaTokenizerFast, LlamaModel, + LlamaTokenizerFast, LlavaForConditionalGeneration, Mistral3ForConditionalGeneration, PixtralProcessor, Qwen3Config, Qwen3ForCausalLM, T5EncoderModel, T5Tokenizer, T5TokenizerFast, UMT5EncoderModel, @@ -31,8 +31,8 @@ from .cogvideox_vae import AutoencoderKLCogVideoX from .fantasytalking_audio_encoder import FantasyTalkingAudioEncoder from .fantasytalking_transformer3d import FantasyTalkingTransformer3DModel -from .flashhead_transformer3d import FlashHeadTransformer3DModel from .flashhead_audio_encoder import FlashHeadAudioEncoder +from .flashhead_transformer3d import FlashHeadTransformer3DModel from .flux2_image_processor import Flux2ImageProcessor from .flux2_transformer2d import Flux2Transformer2DModel from .flux2_transformer2d_control import Flux2ControlTransformer2DModel @@ -52,7 +52,7 @@ from .ltx2_transformer3d import LTX2VideoTransformer3DModel from .ltx2_vae import AutoencoderKLLTX2Video from .ltx2_vae_audio import AutoencoderKLLTX2Audio -from .ltx2_vocoder import LTX2Vocoder +from .ltx2_vocoder import LTX2Vocoder, LTX2VocoderWithBWE from .mova_audio_transformer3d import WanAudioTransformer3DModel from .mova_interactionv2 import MOVADualTowerConditionalBridge from .mova_model import MOVAModel diff --git a/videox_fun/models/ltx2_connecter.py b/videox_fun/models/ltx2_connecter.py 
index 985583f9..f3ee0a54 100644 --- a/videox_fun/models/ltx2_connecter.py +++ b/videox_fun/models/ltx2_connecter.py @@ -1,4 +1,6 @@ # Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ltx2/connectors.py +import math + import torch import torch.nn as nn import torch.nn.functional as F @@ -9,6 +11,78 @@ from .ltx2_transformer3d import LTX2Attention, LTX2AudioVideoAttnProcessor +def per_layer_masked_mean_norm( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: str | torch.device, + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, +): + """ + Performs per-batch per-layer normalization using a masked mean and range on per-layer text encoder hidden_states. + Respects the padding of the hidden states. + + Args: + text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): + Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). + sequence_lengths (`torch.Tensor of shape `(batch_size,)`): + The number of valid (non-padded) tokens for each batch instance. + device: (`str` or `torch.device`, *optional*): + torch device to place the resulting embeddings on + padding_side: (`str`, *optional*, defaults to `"left"`): + Whether the text tokenizer performs padding on the `"left"` or `"right"`. + scale_factor (`int`, *optional*, defaults to `8`): + Scaling factor to multiply the normalized hidden states by. + eps (`float`, *optional*, defaults to `1e-6`): + A small positive value for numerical stability when performing normalization. + + Returns: + `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: + Normed and flattened text encoder hidden states. + """ + batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + # Create padding mask + token_indices = torch.arange(seq_len, device=device).unsqueeze(0) + if padding_side == "right": + # For right padding, valid tokens are from 0 to sequence_length-1 + mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] + elif padding_side == "left": + # For left padding, valid tokens are from (T - sequence_length) to T-1 + start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] + mask = token_indices >= start_indices # [B, T] + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] + + # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) + masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) + num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) + masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) + + # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + # Normalization + normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) + normalized_hidden_states = normalized_hidden_states * scale_factor + + # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) + 
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) + normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) + return normalized_hidden_states + + +def per_token_rms_norm(text_encoder_hidden_states: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + variance = torch.mean(text_encoder_hidden_states**2, dim=2, keepdim=True) + norm_text_encoder_hidden_states = text_encoder_hidden_states * torch.rsqrt(variance + eps) + return norm_text_encoder_hidden_states + class LTX2RotaryPosEmbed1d(nn.Module): """ @@ -107,6 +181,7 @@ def __init__( activation_fn: str = "gelu-approximate", eps: float = 1e-6, rope_type: str = "interleaved", + apply_gated_attention: bool = False, ): super().__init__() @@ -116,8 +191,9 @@ def __init__( heads=num_attention_heads, kv_heads=num_attention_heads, dim_head=attention_head_dim, - processor=LTX2AudioVideoAttnProcessor(), rope_type=rope_type, + apply_gated_attention=apply_gated_attention, + processor=LTX2AudioVideoAttnProcessor(), ) self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) @@ -161,6 +237,7 @@ def __init__( eps: float = 1e-6, causal_temporal_positioning: bool = False, rope_type: str = "interleaved", + gated_attention: bool = False, ): super().__init__() self.num_attention_heads = num_attention_heads @@ -189,6 +266,7 @@ def __init__( num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, rope_type=rope_type, + apply_gated_attention=gated_attention, ) for _ in range(num_layers) ] @@ -261,24 +339,36 @@ class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin): @register_to_config def __init__( self, - caption_channels: int, - text_proj_in_factor: int, - video_connector_num_attention_heads: int, - video_connector_attention_head_dim: int, - video_connector_num_layers: int, - video_connector_num_learnable_registers: int | None, - audio_connector_num_attention_heads: int, - audio_connector_attention_head_dim: int, - audio_connector_num_layers: int, - audio_connector_num_learnable_registers: int | None, - connector_rope_base_seq_len: int, - rope_theta: float, - rope_double_precision: bool, - causal_temporal_positioning: bool, + caption_channels: int = 3840, # default Gemma-3-12B text encoder hidden_size + text_proj_in_factor: int = 49, # num_layers + 1 for embedding layer = 48 + 1 for Gemma-3-12B + video_connector_num_attention_heads: int = 30, + video_connector_attention_head_dim: int = 128, + video_connector_num_layers: int = 2, + video_connector_num_learnable_registers: int | None = 128, + video_gated_attn: bool = False, + audio_connector_num_attention_heads: int = 30, + audio_connector_attention_head_dim: int = 128, + audio_connector_num_layers: int = 2, + audio_connector_num_learnable_registers: int | None = 128, + audio_gated_attn: bool = False, + connector_rope_base_seq_len: int = 4096, + rope_theta: float = 10000.0, + rope_double_precision: bool = True, + causal_temporal_positioning: bool = False, rope_type: str = "interleaved", + per_modality_projections: bool = False, + video_hidden_dim: int = 4096, + audio_hidden_dim: int = 2048, + proj_bias: bool = False, ): super().__init__() - self.text_proj_in = nn.Linear(caption_channels * text_proj_in_factor, caption_channels, bias=False) + text_encoder_dim = caption_channels * text_proj_in_factor + if per_modality_projections: + self.video_text_proj_in = nn.Linear(text_encoder_dim, video_hidden_dim, bias=proj_bias) + self.audio_text_proj_in = nn.Linear(text_encoder_dim, audio_hidden_dim, bias=proj_bias) + else: + 
self.text_proj_in = nn.Linear(text_encoder_dim, caption_channels, bias=proj_bias) + self.video_connector = LTX2ConnectorTransformer1d( num_attention_heads=video_connector_num_attention_heads, attention_head_dim=video_connector_attention_head_dim, @@ -289,6 +379,7 @@ def __init__( rope_double_precision=rope_double_precision, causal_temporal_positioning=causal_temporal_positioning, rope_type=rope_type, + gated_attention=video_gated_attn, ) self.audio_connector = LTX2ConnectorTransformer1d( num_attention_heads=audio_connector_num_attention_heads, @@ -300,26 +391,86 @@ def __init__( rope_double_precision=rope_double_precision, causal_temporal_positioning=causal_temporal_positioning, rope_type=rope_type, + gated_attention=audio_gated_attn, ) def forward( - self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, additive_mask: bool = False - ): - # Convert to additive attention mask, if necessary - if not additive_mask: - text_dtype = text_encoder_hidden_states.dtype - attention_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) - attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max - - text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states) - - video_text_embedding, new_attn_mask = self.video_connector(text_encoder_hidden_states, attention_mask) - - attn_mask = (new_attn_mask < 1e-6).to(torch.int64) - attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1) - video_text_embedding = video_text_embedding * attn_mask - new_attn_mask = attn_mask.squeeze(-1) - - audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, attention_mask) - - return video_text_embedding, audio_text_embedding, new_attn_mask \ No newline at end of file + self, + text_encoder_hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + padding_side: str = "left", + scale_factor: int = 8, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Given per-layer text encoder hidden_states, extracts features and runs per-modality connectors to get text + embeddings for the LTX-2.X DiT models. + + Args: + text_encoder_hidden_states (`torch.Tensor`)): + Per-layer text encoder hidden_states. Can either be 4D with shape `(batch_size, seq_len, + caption_channels, text_proj_in_factor) or 3D with the last two dimensions flattened. + attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`): + Multiplicative binary attention mask where 1s indicate unmasked positions and 0s indicate masked + positions. + padding_side (`str`, *optional*, defaults to `"left"`): + The padding side used by the text encoder's text encoder (either `"left"` or `"right"`). Defaults to + `"left"` as this is what the default Gemma3-12B text encoder uses. Only used if + `per_modality_projections` is `False` (LTX-2.0 models). + scale_factor (`int`, *optional*, defaults to `8`): + Scale factor for masked mean/range normalization. Only used if `per_modality_projections` is `False` + (LTX-2.0 models). 
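+
+        Returns:
+            `tuple[torch.Tensor, torch.Tensor, torch.Tensor]`:
+                The video text embedding, the audio text embedding, and the binary attention mask produced by
+                the video connector.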
+ """ + if text_encoder_hidden_states.ndim == 3: + # Ensure shape is [batch_size, seq_len, caption_channels, text_proj_in_factor] + text_encoder_hidden_states = text_encoder_hidden_states.unflatten(2, (self.config.caption_channels, -1)) + + if self.config.per_modality_projections: + # LTX-2.3 + norm_text_encoder_hidden_states = per_token_rms_norm(text_encoder_hidden_states) + + norm_text_encoder_hidden_states = norm_text_encoder_hidden_states.flatten(2, 3) + bool_mask = attention_mask.bool().unsqueeze(-1) + norm_text_encoder_hidden_states = torch.where( + bool_mask, norm_text_encoder_hidden_states, torch.zeros_like(norm_text_encoder_hidden_states) + ) + + # Rescale norms with respect to video and audio dims for feature extractors + video_scale_factor = math.sqrt(self.config.video_hidden_dim / self.config.caption_channels) + video_norm_text_emb = norm_text_encoder_hidden_states * video_scale_factor + audio_scale_factor = math.sqrt(self.config.audio_hidden_dim / self.config.caption_channels) + audio_norm_text_emb = norm_text_encoder_hidden_states * audio_scale_factor + + # Per-Modality Feature extractors + video_text_emb_proj = self.video_text_proj_in(video_norm_text_emb) + audio_text_emb_proj = self.audio_text_proj_in(audio_norm_text_emb) + else: + # LTX-2.0 + sequence_lengths = attention_mask.sum(dim=-1) + norm_text_encoder_hidden_states = per_layer_masked_mean_norm( + text_hidden_states=text_encoder_hidden_states, + sequence_lengths=sequence_lengths, + device=text_encoder_hidden_states.device, + padding_side=padding_side, + scale_factor=scale_factor, + ) + + text_emb_proj = self.text_proj_in(norm_text_encoder_hidden_states) + video_text_emb_proj = text_emb_proj + audio_text_emb_proj = text_emb_proj + + # Convert to additive attention mask for connectors + text_dtype = video_text_emb_proj.dtype + attention_mask = (attention_mask.to(torch.int64) - 1).to(text_dtype) + attention_mask = attention_mask.reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + add_attn_mask = attention_mask * torch.finfo(text_dtype).max + + video_text_embedding, video_attn_mask = self.video_connector(video_text_emb_proj, add_attn_mask) + + # Convert video attn mask to binary (multiplicative) mask and mask video text embedding + binary_attn_mask = (video_attn_mask < 1e-6).to(torch.int64) + binary_attn_mask = binary_attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1) + video_text_embedding = video_text_embedding * binary_attn_mask + + audio_text_embedding, _ = self.audio_connector(audio_text_emb_proj, add_attn_mask) + + return video_text_embedding, audio_text_embedding, binary_attn_mask.squeeze(-1) \ No newline at end of file diff --git a/videox_fun/models/ltx2_transformer3d.py b/videox_fun/models/ltx2_transformer3d.py index dca5b58b..72bdbc74 100644 --- a/videox_fun/models/ltx2_transformer3d.py +++ b/videox_fun/models/ltx2_transformer3d.py @@ -32,6 +32,11 @@ from diffusers.utils.outputs import BaseOutput from .attention_utils import attention +from ..dist import (LTX2MultiGPUsAttnProcessor, + LTX2PerturbedMultiGPUsAttnProcessor, + get_sequence_parallel_rank, + get_sequence_parallel_world_size, get_sp_group) + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -179,6 +184,10 @@ def __call__( if encoder_hidden_states is None: encoder_hidden_states = hidden_states + if attn.to_gate_logits is not None: + # Calculate gate logits on original hidden_states + gate_logits = attn.to_gate_logits(hidden_states) + query = attn.to_q(hidden_states) key = 
attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -211,6 +220,110 @@ def __call__( hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.to(query.dtype) + if attn.to_gate_logits is not None: + hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) # [B, T, H, D] + # The factor of 2.0 is so that if the gates logits are zero-initialized the initial gates are all 1 + gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H] + hidden_states = hidden_states * gates.unsqueeze(-1) + hidden_states = hidden_states.flatten(2, 3) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class LTX2PerturbedAttnProcessor: + r""" + Processor which implements attention with perturbation masking and per-head gating for LTX-2.X models. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if is_torch_version("<", "2.0"): + raise ValueError( + "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation." + ) + + def __call__( + self, + attn: "LTX2Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + perturbation_mask: torch.Tensor | None = None, + all_perturbed: bool | None = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + if attn.to_gate_logits is not None: + # Calculate gate logits on original hidden_states + gate_logits = attn.to_gate_logits(hidden_states) + + value = attn.to_v(encoder_hidden_states) + if all_perturbed is None: + all_perturbed = torch.all(perturbation_mask == 0) if perturbation_mask is not None else False + + if all_perturbed: + # Skip attention, use the value projection value + hidden_states = value + else: + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if query_rotary_emb is not None: + if attn.rope_type == "interleaved": + query = apply_interleaved_rotary_emb(query, query_rotary_emb) + key = apply_interleaved_rotary_emb( + key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb + ) + elif attn.rope_type == "split": + query = apply_split_rotary_emb(query, query_rotary_emb) + key = apply_split_rotary_emb( + key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb + ) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + hidden_states = attention( + q=query, + k=key, + v=value, + attn_mask=attention_mask, + dropout_p=0.0, + causal=False, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if perturbation_mask is not None: + value = value.flatten(2, 3) + hidden_states = torch.lerp(value, hidden_states, perturbation_mask) + + if attn.to_gate_logits is not None: + hidden_states = hidden_states.unflatten(2, 
(attn.heads, -1)) # [B, T, H, D] + # The factor of 2.0 is so that if the gates logits are zero-initialized the initial gates are all 1 + gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H] + hidden_states = hidden_states * gates.unsqueeze(-1) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) return hidden_states @@ -223,7 +336,7 @@ class LTX2Attention(torch.nn.Module): """ _default_processor_cls = LTX2AudioVideoAttnProcessor - _available_processors = [LTX2AudioVideoAttnProcessor] + _available_processors = [LTX2AudioVideoAttnProcessor, LTX2PerturbedAttnProcessor] def __init__( self, @@ -239,6 +352,7 @@ def __init__( norm_eps: float = 1e-6, norm_elementwise_affine: bool = True, rope_type: str = "interleaved", + apply_gated_attention: bool = False, processor=None, ): super().__init__() @@ -265,10 +379,33 @@ def __init__( self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) self.to_out.append(torch.nn.Dropout(dropout)) + if apply_gated_attention: + # Per head gate values + self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True) + else: + self.to_gate_logits = None + if processor is None: processor = self._default_processor_cls() self.processor = processor + def set_processor(self, processor) -> None: + """ + Set the attention processor to use. + + Args: + processor: The attention processor to use. + """ + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + def prepare_attention_mask( self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 ) -> torch.Tensor: @@ -363,6 +500,10 @@ def __init__( audio_num_attention_heads: int, audio_attention_head_dim, audio_cross_attention_dim: int, + video_gated_attn: bool = False, + video_cross_attn_adaln: bool = False, + audio_gated_attn: bool = False, + audio_cross_attn_adaln: bool = False, qk_norm: str = "rms_norm_across_heads", activation_fn: str = "gelu-approximate", attention_bias: bool = True, @@ -370,9 +511,16 @@ def __init__( eps: float = 1e-6, elementwise_affine: bool = False, rope_type: str = "interleaved", + perturbed_attn: bool = False, ): super().__init__() + self.perturbed_attn = perturbed_attn + if perturbed_attn: + attn_processor_cls = LTX2PerturbedAttnProcessor + else: + attn_processor_cls = LTX2AudioVideoAttnProcessor + # 1. Self-Attention (video and audio) self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) self.attn1 = LTX2Attention( @@ -385,6 +533,8 @@ def __init__( out_bias=attention_out_bias, qk_norm=qk_norm, rope_type=rope_type, + apply_gated_attention=video_gated_attn, + processor=attn_processor_cls(), ) self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) @@ -398,6 +548,8 @@ def __init__( out_bias=attention_out_bias, qk_norm=qk_norm, rope_type=rope_type, + apply_gated_attention=audio_gated_attn, + processor=attn_processor_cls(), ) # 2. 
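A small sketch of the per-head output gating enabled by `apply_gated_attention`, with toy shapes: because the gates are `2 * sigmoid(gate_logits)`, zero-initialized logits give gates of exactly 1 and leave the attention output unchanged.

```python
import torch

batch, tokens, heads, head_dim = 2, 4, 3, 8
attn_out = torch.randn(batch, tokens, heads * head_dim)
gate_logits = torch.zeros(batch, tokens, heads)        # zero init -> identity gates

gated = attn_out.unflatten(2, (heads, head_dim))       # [B, T, H, D]
gates = 2.0 * torch.sigmoid(gate_logits)               # [B, T, H], all 1.0 here
gated = (gated * gates.unsqueeze(-1)).flatten(2, 3)

print(torch.allclose(gated, attn_out))                 # True: 2 * sigmoid(0) == 1
```

The perturbed processor applies the same gating, but first blends the value projection with the attention output via `torch.lerp(value, attn_out, perturbation_mask)`, so batch elements with mask 0 fall back to the value path (the STG shortcut).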
Prompt Cross-Attention @@ -412,6 +564,8 @@ def __init__( out_bias=attention_out_bias, qk_norm=qk_norm, rope_type=rope_type, + apply_gated_attention=video_gated_attn, + processor=attn_processor_cls(), ) self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) @@ -425,6 +579,8 @@ def __init__( out_bias=attention_out_bias, qk_norm=qk_norm, rope_type=rope_type, + apply_gated_attention=audio_gated_attn, + processor=attn_processor_cls(), ) # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention @@ -440,6 +596,8 @@ def __init__( out_bias=attention_out_bias, qk_norm=qk_norm, rope_type=rope_type, + apply_gated_attention=video_gated_attn, + processor=attn_processor_cls(), ) # Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video @@ -454,6 +612,8 @@ def __init__( out_bias=attention_out_bias, qk_norm=qk_norm, rope_type=rope_type, + apply_gated_attention=audio_gated_attn, + processor=attn_processor_cls(), ) # 4. Feedforward layers @@ -464,14 +624,36 @@ def __init__( self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn) # 5. Per-Layer Modulation Parameters - # Self-Attention / Feedforward AdaLayerNorm-Zero mod params - self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) - self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5) + # Self-Attention (attn1) / Feedforward AdaLayerNorm-Zero mod params + # 6 base mod params for text cross-attn K,V; if cross_attn_adaln, also has mod params for Q + self.video_cross_attn_adaln = video_cross_attn_adaln + self.audio_cross_attn_adaln = audio_cross_attn_adaln + video_mod_param_num = 9 if self.video_cross_attn_adaln else 6 + audio_mod_param_num = 9 if self.audio_cross_attn_adaln else 6 + self.scale_shift_table = nn.Parameter(torch.randn(video_mod_param_num, dim) / dim**0.5) + self.audio_scale_shift_table = nn.Parameter(torch.randn(audio_mod_param_num, audio_dim) / audio_dim**0.5) + + # Prompt cross-attn (attn2) additional modulation params + self.cross_attn_adaln = video_cross_attn_adaln or audio_cross_attn_adaln + if self.cross_attn_adaln: + self.prompt_scale_shift_table = nn.Parameter(torch.randn(2, dim)) + self.audio_prompt_scale_shift_table = nn.Parameter(torch.randn(2, audio_dim)) # Per-layer a2v, v2a Cross-Attention mod params self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim)) self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, audio_dim)) + @staticmethod + def get_mod_params( + scale_shift_table: torch.Tensor, temb: torch.Tensor, batch_size: int + ) -> tuple[torch.Tensor, ...]: + num_ada_params = scale_shift_table.shape[0] + ada_values = scale_shift_table[None, None].to(temb.device) + temb.reshape( + batch_size, temb.shape[1], num_ada_params, -1 + ) + ada_params = ada_values.unbind(dim=2) + return ada_params + def forward( self, hidden_states: torch.Tensor, @@ -484,143 +666,181 @@ def forward( temb_ca_audio_scale_shift: torch.Tensor, temb_ca_gate: torch.Tensor, temb_ca_audio_gate: torch.Tensor, + temb_prompt: torch.Tensor | None = None, + temb_prompt_audio: torch.Tensor | None = None, video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, ca_video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, ca_audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, encoder_attention_mask: torch.Tensor | None = None, audio_encoder_attention_mask: torch.Tensor | None = None, + self_attention_mask: torch.Tensor | None = 
None, + audio_self_attention_mask: torch.Tensor | None = None, a2v_cross_attention_mask: torch.Tensor | None = None, v2a_cross_attention_mask: torch.Tensor | None = None, + use_a2v_cross_attention: bool = True, + use_v2a_cross_attention: bool = True, + perturbation_mask: torch.Tensor | None = None, + all_perturbed: bool | None = None, ) -> torch.Tensor: batch_size = hidden_states.size(0) # 1. Video and Audio Self-Attention - norm_hidden_states = self.norm1(hidden_states) + # 1.1. Video Self-Attention + video_ada_params = self.get_mod_params(self.scale_shift_table, temb, batch_size) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = video_ada_params[:6] + if self.video_cross_attn_adaln: + shift_text_q, scale_text_q, gate_text_q = video_ada_params[6:9] - num_ada_params = self.scale_shift_table.shape[0] - ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape( - batch_size, temb.size(1), num_ada_params, -1 - ) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) + norm_hidden_states = self.norm1(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa - attn_hidden_states = self.attn1( - hidden_states=norm_hidden_states, - encoder_hidden_states=None, - query_rotary_emb=video_rotary_emb, - ) + video_self_attn_args = { + "hidden_states": norm_hidden_states, + "encoder_hidden_states": None, + "query_rotary_emb": video_rotary_emb, + "attention_mask": self_attention_mask, + } + if self.perturbed_attn: + video_self_attn_args["perturbation_mask"] = perturbation_mask + video_self_attn_args["all_perturbed"] = all_perturbed + + attn_hidden_states = self.attn1(**video_self_attn_args) hidden_states = hidden_states + attn_hidden_states * gate_msa - norm_audio_hidden_states = self.audio_norm1(audio_hidden_states) - - num_audio_ada_params = self.audio_scale_shift_table.shape[0] - audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape( - batch_size, temb_audio.size(1), num_audio_ada_params, -1 - ) + # 1.2. Audio Self-Attention + audio_ada_params = self.get_mod_params(self.audio_scale_shift_table, temb_audio, batch_size) audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = ( - audio_ada_values.unbind(dim=2) + audio_ada_params[:6] ) + if self.audio_cross_attn_adaln: + audio_shift_text_q, audio_scale_text_q, audio_gate_text_q = audio_ada_params[6:9] + + norm_audio_hidden_states = self.audio_norm1(audio_hidden_states) norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa - attn_audio_hidden_states = self.audio_attn1( - hidden_states=norm_audio_hidden_states, - encoder_hidden_states=None, - query_rotary_emb=audio_rotary_emb, - ) + audio_self_attn_args = { + "hidden_states": norm_audio_hidden_states, + "encoder_hidden_states": None, + "query_rotary_emb": audio_rotary_emb, + "attention_mask": audio_self_attention_mask, + } + if self.perturbed_attn: + audio_self_attn_args["perturbation_mask"] = perturbation_mask + audio_self_attn_args["all_perturbed"] = all_perturbed + + attn_audio_hidden_states = self.audio_attn1(**audio_self_attn_args) audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa - # 2. Video and Audio Cross-Attention with the text embeddings + # 2. 
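The self-attention path above follows an AdaLayerNorm-Zero pattern: a per-layer table is offset by the timestep embedding and split into shift/scale/gate triples. A minimal sketch with illustrative shapes (the attention call itself is omitted):

```python
import torch

batch, tokens, dim, num_params = 2, 4, 16, 6
scale_shift_table = torch.randn(num_params, dim) / dim**0.5   # per-layer parameters
temb = torch.randn(batch, 1, num_params * dim)                # global timestep embedding

ada = scale_shift_table[None, None] + temb.reshape(batch, temb.shape[1], num_params, -1)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada.unbind(dim=2)

x = torch.randn(batch, tokens, dim)
norm_x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)   # stand-in RMSNorm
attn_in = norm_x * (1 + scale_msa) + shift_msa                     # modulate
x = x + attn_in * gate_msa                                         # gated residual (attention omitted)
print(x.shape)                                                      # torch.Size([2, 4, 16])
```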
Video and Audio Cross-Attention with the text embeddings (Q: Video or Audio; K,V: Text) + if self.cross_attn_adaln: + video_prompt_ada_params = self.get_mod_params(self.prompt_scale_shift_table, temb_prompt, batch_size) + shift_text_kv, scale_text_kv = video_prompt_ada_params + + audio_prompt_ada_params = self.get_mod_params( + self.audio_prompt_scale_shift_table, temb_prompt_audio, batch_size + ) + audio_shift_text_kv, audio_scale_text_kv = audio_prompt_ada_params + + # 2.1. Video-Text Cross-Attention (Q: Video; K,V: Text) norm_hidden_states = self.norm2(hidden_states) + if self.video_cross_attn_adaln: + norm_hidden_states = norm_hidden_states * (1 + scale_text_q) + shift_text_q + if self.cross_attn_adaln: + encoder_hidden_states = encoder_hidden_states * (1 + scale_text_kv) + shift_text_kv + attn_hidden_states = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, query_rotary_emb=None, attention_mask=encoder_attention_mask, ) + if self.video_cross_attn_adaln: + attn_hidden_states = attn_hidden_states * gate_text_q hidden_states = hidden_states + attn_hidden_states + # 2.2. Audio-Text Cross-Attention norm_audio_hidden_states = self.audio_norm2(audio_hidden_states) + if self.audio_cross_attn_adaln: + norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_text_q) + audio_shift_text_q + if self.cross_attn_adaln: + audio_encoder_hidden_states = audio_encoder_hidden_states * (1 + audio_scale_text_kv) + audio_shift_text_kv + attn_audio_hidden_states = self.audio_attn2( norm_audio_hidden_states, encoder_hidden_states=audio_encoder_hidden_states, query_rotary_emb=None, attention_mask=audio_encoder_attention_mask, ) + if self.audio_cross_attn_adaln: + attn_audio_hidden_states = attn_audio_hidden_states * audio_gate_text_q audio_hidden_states = audio_hidden_states + attn_audio_hidden_states # 3. 
Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention - norm_hidden_states = self.audio_to_video_norm(hidden_states) - norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states) - - # Combine global and per-layer cross attention modulation parameters - # Video - video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :] - video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :] - - video_ca_scale_shift_table = ( - video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype) - + temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1) - ).unbind(dim=2) - video_ca_gate = ( - video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype) - + temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1) - ).unbind(dim=2) - - video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table - a2v_gate = video_ca_gate[0].squeeze(2) - - # Audio - audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :] - audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :] - - audio_ca_scale_shift_table = ( - audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype) - + temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1) - ).unbind(dim=2) - audio_ca_gate = ( - audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype) - + temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1) - ).unbind(dim=2) - - audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table - v2a_gate = audio_ca_gate[0].squeeze(2) - - # Audio-to-Video Cross Attention: Q: Video; K,V: Audio - mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze( - 2 - ) - mod_norm_audio_hidden_states = norm_audio_hidden_states * ( - 1 + audio_a2v_ca_scale.squeeze(2) - ) + audio_a2v_ca_shift.squeeze(2) - - a2v_attn_hidden_states = self.audio_to_video_attn( - mod_norm_hidden_states, - encoder_hidden_states=mod_norm_audio_hidden_states, - query_rotary_emb=ca_video_rotary_emb, - key_rotary_emb=ca_audio_rotary_emb, - attention_mask=a2v_cross_attention_mask, - ) + if use_a2v_cross_attention or use_v2a_cross_attention: + norm_hidden_states = self.audio_to_video_norm(hidden_states) + norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states) - hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states + # 3.1. 
Combine global and per-layer cross attention modulation parameters + # Video + video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :] + video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :] - # Video-to-Audio Cross Attention: Q: Audio; K,V: Video - mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze( - 2 - ) - mod_norm_audio_hidden_states = norm_audio_hidden_states * ( - 1 + audio_v2a_ca_scale.squeeze(2) - ) + audio_v2a_ca_shift.squeeze(2) - - v2a_attn_hidden_states = self.video_to_audio_attn( - mod_norm_audio_hidden_states, - encoder_hidden_states=mod_norm_hidden_states, - query_rotary_emb=ca_audio_rotary_emb, - key_rotary_emb=ca_video_rotary_emb, - attention_mask=v2a_cross_attention_mask, - ) + video_ca_ada_params = self.get_mod_params(video_per_layer_ca_scale_shift, temb_ca_scale_shift, batch_size) + video_ca_gate_param = self.get_mod_params(video_per_layer_ca_gate, temb_ca_gate, batch_size) + + video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_ada_params + a2v_gate = video_ca_gate_param[0].squeeze(2) + + # Audio + audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :] + audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :] + + audio_ca_ada_params = self.get_mod_params( + audio_per_layer_ca_scale_shift, temb_ca_audio_scale_shift, batch_size + ) + audio_ca_gate_param = self.get_mod_params(audio_per_layer_ca_gate, temb_ca_audio_gate, batch_size) + + audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_ada_params + v2a_gate = audio_ca_gate_param[0].squeeze(2) + + # 3.2. Audio-to-Video Cross Attention: Q: Video; K,V: Audio + if use_a2v_cross_attention: + mod_norm_hidden_states = norm_hidden_states * ( + 1 + video_a2v_ca_scale.squeeze(2) + ) + video_a2v_ca_shift.squeeze(2) + mod_norm_audio_hidden_states = norm_audio_hidden_states * ( + 1 + audio_a2v_ca_scale.squeeze(2) + ) + audio_a2v_ca_shift.squeeze(2) + + a2v_attn_hidden_states = self.audio_to_video_attn( + mod_norm_hidden_states, + encoder_hidden_states=mod_norm_audio_hidden_states, + query_rotary_emb=ca_video_rotary_emb, + key_rotary_emb=ca_audio_rotary_emb, + attention_mask=a2v_cross_attention_mask, + ) + + hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states + + # 3.3. Video-to-Audio Cross Attention: Q: Audio; K,V: Video + if use_v2a_cross_attention: + mod_norm_hidden_states = norm_hidden_states * ( + 1 + video_v2a_ca_scale.squeeze(2) + ) + video_v2a_ca_shift.squeeze(2) + mod_norm_audio_hidden_states = norm_audio_hidden_states * ( + 1 + audio_v2a_ca_scale.squeeze(2) + ) + audio_v2a_ca_shift.squeeze(2) + + v2a_attn_hidden_states = self.video_to_audio_attn( + mod_norm_audio_hidden_states, + encoder_hidden_states=mod_norm_hidden_states, + query_rotary_emb=ca_audio_rotary_emb, + key_rotary_emb=ca_video_rotary_emb, + attention_mask=v2a_cross_attention_mask, + ) - audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states + audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states # 4. 
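For the a2v/v2a branches, the per-layer `(5, dim)` table splits into four scale/shift rows plus one gate row, each combined with its own timestep-dependent embedding. A toy sketch of that split; shapes are illustrative:

```python
import torch

batch, dim = 2, 16
table = torch.randn(5, dim)                          # per-layer parameters
temb_scale_shift = torch.randn(batch, 1, 4 * dim)    # global scale/shift embedding
temb_gate = torch.randn(batch, 1, 1 * dim)           # global gate embedding

scale_shift = table[:4][None, None] + temb_scale_shift.reshape(batch, 1, 4, -1)
a2v_scale, a2v_shift, v2a_scale, v2a_shift = scale_shift.unbind(dim=2)
gate = (table[4:][None, None] + temb_gate.reshape(batch, 1, 1, -1)).unbind(dim=2)[0]
print(a2v_scale.shape, gate.shape)   # torch.Size([2, 1, 16]) torch.Size([2, 1, 16])
```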
Feedforward norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp @@ -948,6 +1168,8 @@ def __init__( pos_embed_max_pos: int = 20, base_height: int = 2048, base_width: int = 2048, + gated_attn: bool = False, + cross_attn_mod: bool = False, audio_in_channels: int = 128, # Audio Arguments audio_out_channels: int | None = 128, audio_patch_size: int = 1, @@ -959,6 +1181,8 @@ def __init__( audio_pos_embed_max_pos: int = 20, audio_sampling_rate: int = 16000, audio_hop_length: int = 160, + audio_gated_attn: bool = False, + audio_cross_attn_mod: bool = False, num_layers: int = 48, # Shared arguments activation_fn: str = "gelu-approximate", qk_norm: str = "rms_norm_across_heads", @@ -973,6 +1197,8 @@ def __init__( timestep_scale_multiplier: int = 1000, cross_attn_timestep_scale_multiplier: int = 1000, rope_type: str = "interleaved", + use_prompt_embeddings=True, + perturbed_attn: bool = False, ) -> None: super().__init__() @@ -986,17 +1212,25 @@ def __init__( self.audio_proj_in = nn.Linear(audio_in_channels, audio_inner_dim) # 2. Prompt embeddings - self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) - self.audio_caption_projection = PixArtAlphaTextProjection( - in_features=caption_channels, hidden_size=audio_inner_dim - ) + if use_prompt_embeddings: + # LTX-2.0; LTX-2.3 uses per-modality feature projections in the connector instead + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, hidden_size=audio_inner_dim + ) # 3. Timestep Modulation Params and Embedding + self.prompt_modulation = cross_attn_mod or audio_cross_attn_mod # used by LTX-2.3 + # 3.1. Global Timestep Modulation Parameters (except for cross-attention) and timestep + size embedding # time_embed and audio_time_embed calculate both the timestep embedding and (global) modulation parameters - self.time_embed = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=6, use_additional_conditions=False) + video_time_emb_mod_params = 9 if cross_attn_mod else 6 + audio_time_emb_mod_params = 9 if audio_cross_attn_mod else 6 + self.time_embed = LTX2AdaLayerNormSingle( + inner_dim, num_mod_params=video_time_emb_mod_params, use_additional_conditions=False + ) self.audio_time_embed = LTX2AdaLayerNormSingle( - audio_inner_dim, num_mod_params=6, use_additional_conditions=False + audio_inner_dim, num_mod_params=audio_time_emb_mod_params, use_additional_conditions=False ) # 3.2. Global Cross Attention Modulation Parameters @@ -1025,6 +1259,13 @@ def __init__( self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) self.audio_scale_shift_table = nn.Parameter(torch.randn(2, audio_inner_dim) / audio_inner_dim**0.5) + # 3.4. Prompt Scale/Shift Modulation parameters (LTX-2.3) + if self.prompt_modulation: + self.prompt_adaln = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=2, use_additional_conditions=False) + self.audio_prompt_adaln = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=2, use_additional_conditions=False + ) + # 4. 
Rotary Positional Embeddings (RoPE) # Self-Attention self.rope = LTX2AudioVideoRotaryPosEmbed( @@ -1101,6 +1342,10 @@ def __init__( audio_num_attention_heads=audio_num_attention_heads, audio_attention_head_dim=audio_attention_head_dim, audio_cross_attention_dim=audio_cross_attention_dim, + video_gated_attn=gated_attn, + video_cross_attn_adaln=cross_attn_mod, + audio_gated_attn=audio_gated_attn, + audio_cross_attn_adaln=audio_cross_attn_mod, qk_norm=qk_norm, activation_fn=activation_fn, attention_bias=attention_bias, @@ -1108,6 +1353,7 @@ def __init__( eps=norm_eps, elementwise_affine=norm_elementwise_affine, rope_type=rope_type, + perturbed_attn=perturbed_attn, ) for _ in range(num_layers) ] @@ -1130,6 +1376,39 @@ def _set_gradient_checkpointing(self, *args, **kwargs): else: raise ValueError("Invalid set gradient checkpointing") + def enable_multi_gpus_inference(self): + """Enable multi-GPU inference using sequence parallel.""" + self.sp_world_size = get_sequence_parallel_world_size() + self.sp_world_rank = get_sequence_parallel_rank() + self.all_gather = get_sp_group().all_gather + + # Choose processor based on perturbed_attn config + if self.config.perturbed_attn: + processor_cls = LTX2PerturbedMultiGPUsAttnProcessor + else: + processor_cls = LTX2MultiGPUsAttnProcessor + + # Set multi-GPU processor for attention layers that need it + # Note: text cross-attention (attn2, audio_attn2) does NOT need special handling + # because text embeddings are not chunked - each GPU has full text embeddings + # Audio self-attention (audio_attn1) and audio-to-video cross-attention also use standard + # processors since they don't require cross-rank communication: + # - audio_attn1: Q,K,V are all full audio (no chunking) + # - audio_to_video_attn: Q=video(chunked), K,V=audio(full) - each rank computes locally + for block in self.transformer_blocks: + # Video self-attention (chunked) -> multi-GPU processor for all-to-all communication + block.attn1.set_processor(processor_cls()) + block.attn1.sp_world_size = self.sp_world_size + block.attn1.sp_world_rank = self.sp_world_rank + block.attn1.all_gather = self.all_gather + + # Video-to-audio cross-attention: Q=audio(full), K,V=video(chunked) + # Needs to all_gather K,V from video chunks across ranks + block.video_to_audio_attn.set_processor(processor_cls()) + block.video_to_audio_attn.sp_world_size = self.sp_world_size + block.video_to_audio_attn.sp_world_rank = self.sp_world_rank + block.video_to_audio_attn.all_gather = self.all_gather + def forward( self, hidden_states: torch.Tensor, @@ -1138,6 +1417,8 @@ def forward( audio_encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, audio_timestep: torch.LongTensor | None = None, + sigma: torch.Tensor | None = None, + audio_sigma: torch.Tensor | None = None, encoder_attention_mask: torch.Tensor | None = None, audio_encoder_attention_mask: torch.Tensor | None = None, num_frames: int | None = None, @@ -1147,6 +1428,10 @@ def forward( audio_num_frames: int | None = None, video_coords: torch.Tensor | None = None, audio_coords: torch.Tensor | None = None, + isolate_modalities: bool = False, + spatio_temporal_guidance_blocks: list[int] | None = None, + perturbation_mask: torch.Tensor | None = None, + use_cross_timestep: bool = False, attention_kwargs: dict[str, Any] | None = None, return_dict: bool = True, ) -> torch.Tensor: @@ -1168,6 +1453,13 @@ def forward( audio_timestep (`torch.Tensor`, *optional*): Input timestep of shape `(batch_size,)` or `(batch_size, num_audio_tokens)` for audio modulation params. 
This is only used by certain pipelines such as the I2V pipeline. + sigma (`torch.Tensor`, *optional*): + Input scaled timestep of shape (batch_size,). Used for video prompt cross attention modulation in + models such as LTX-2.3. + audio_sigma (`torch.Tensor`, *optional*): + Input scaled timestep of shape (batch_size,). Used for audio prompt cross attention modulation in + models such as LTX-2.3. If `sigma` is supplied but `audio_sigma` is not, `audio_sigma` will be set to + the provided `sigma` value. encoder_attention_mask (`torch.Tensor`, *optional*): Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`. audio_encoder_attention_mask (`torch.Tensor`, *optional*): @@ -1189,6 +1481,21 @@ def forward( audio_coords (`torch.Tensor`, *optional*): The audio coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape `(batch_size, 1, num_audio_tokens, 2)`. If not supplied, this will be calculated inside `forward`. + isolate_modalities (`bool`, *optional*, defaults to `False`): + Whether to isolate each modality by turning off cross-modality (audio-to-video and video-to-audio) + cross attention (for all blocks). Use for modality guidance in LTX-2.3. + spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`): + The transformer block indices at which to apply spatio-temporal guidance (STG), which shortcuts the + self-attention operations by simply using the values rather than the full scaled dot-product attention + (SDPA) operation. If `None` or empty, STG will not be applied to any block. + perturbation_mask (`torch.Tensor`, *optional*): + Perturbation mask for STG of shape `(batch_size,)` or `(batch_size, 1, 1)`. Should be 0 at batch + elements where STG should be applied and 1 elsewhere. If STG is being used but `peturbation_mask` is + not supplied, will default to applying STG (perturbing) all batch elements. + use_cross_timestep (`bool` *optional*, defaults to `False`): + Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when + calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior; + `False` is the legacy LTX-2.0 behavior. attention_kwargs (`dict[str, Any]`, *optional*): Optional dict of keyword args to be passed to the attention processor. return_dict (`bool`, *optional*, defaults to `True`): @@ -1202,6 +1509,7 @@ def forward( """ # Determine timestep for audio. audio_timestep = audio_timestep if audio_timestep is not None else timestep + audio_sigma = audio_sigma if audio_sigma is not None else sigma # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: @@ -1214,7 +1522,7 @@ def forward( batch_size = hidden_states.size(0) - # 1. Prepare RoPE positional embeddings + # 1. Prepare coords for RoPE (will generate embeddings after patchification and sequence chunking) if video_coords is None: video_coords = self.rope.prepare_video_coords( batch_size, num_frames, height, width, hidden_states.device, fps=fps @@ -1224,6 +1532,30 @@ def forward( batch_size, audio_num_frames, audio_hidden_states.device ) + # 2. 
Patchify input projections + hidden_states = self.proj_in(hidden_states) + audio_hidden_states = self.audio_proj_in(audio_hidden_states) + + # Sequence parallel: chunk video and audio hidden states and coords + if hasattr(self, 'sp_world_size') and self.sp_world_size > 1: + from ..dist.fuser import sequence_parallel_chunk, get_sequence_parallel_rank, get_sequence_parallel_world_size + + sp_rank = get_sequence_parallel_rank() + sp_world_size = get_sequence_parallel_world_size() + + # Chunk video hidden states + hidden_states = sequence_parallel_chunk(hidden_states, dim=1) + + # Chunk video coords for RoPE + if video_coords is not None: + video_coords_chunks = torch.chunk(video_coords, sp_world_size, dim=2) + video_coords = video_coords_chunks[sp_rank] + + # Chunk per-patch timestep for I2V (timestep shape [B, num_patches]) + if timestep.ndim == 2: + timestep = sequence_parallel_chunk(timestep, dim=1) + + # 3. Prepare RoPE positional embeddings (after sequence chunking so RoPE matches chunked sequence length) video_rotary_emb = self.rope(video_coords, device=hidden_states.device) audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) @@ -1232,11 +1564,7 @@ def forward( audio_coords[:, 0:1, :], device=audio_hidden_states.device ) - # 2. Patchify input projections - hidden_states = self.proj_in(hidden_states) - audio_hidden_states = self.audio_proj_in(audio_hidden_states) - - # 3. Prepare timestep embeddings and modulation parameters + # 4. Prepare timestep embeddings and modulation parameters timestep_cross_attn_gate_scale_factor = ( self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier ) @@ -1260,14 +1588,28 @@ def forward( temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1)) audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1)) + if self.prompt_modulation: + # LTX-2.3 + temb_prompt, _ = self.prompt_adaln( + sigma.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + temb_prompt_audio, _ = self.audio_prompt_adaln( + audio_sigma.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype + ) + temb_prompt = temb_prompt.view(batch_size, -1, temb_prompt.size(-1)) + temb_prompt_audio = temb_prompt_audio.view(batch_size, -1, temb_prompt_audio.size(-1)) + else: + temb_prompt = temb_prompt_audio = None + # 3.2. 
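The sequence-parallel path splits the video token sequence, its RoPE coordinates, and any per-patch timesteps across ranks before the blocks run, then gathers the video output at the end. The real helpers live in `videox_fun.dist.fuser`; the function below is only an illustrative single-process stand-in for the chunking step:

```python
import torch

def sequence_parallel_chunk_sketch(x: torch.Tensor, rank: int, world_size: int, dim: int = 1) -> torch.Tensor:
    # Illustrative stand-in: each rank keeps only its contiguous slice of the
    # token dimension; the real helper also handles distributed bookkeeping.
    return torch.chunk(x, world_size, dim=dim)[rank]

hidden_states = torch.randn(1, 8, 32)                                # [batch, tokens, dim]
shards = [sequence_parallel_chunk_sketch(hidden_states, r, 4) for r in range(4)]
print([s.shape for s in shards])                                     # four [1, 2, 32] shards
print(torch.equal(torch.cat(shards, dim=1), hidden_states))          # True: concat restores the sequence
```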
Prepare global modality cross attention modulation parameters + video_ca_timestep = audio_sigma.flatten() if use_cross_timestep else timestep.flatten() video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift( - timestep.flatten(), + video_ca_timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype, ) video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate( - timestep.flatten() * timestep_cross_attn_gate_scale_factor, + video_ca_timestep * timestep_cross_attn_gate_scale_factor, batch_size=batch_size, hidden_dtype=hidden_states.dtype, ) @@ -1276,13 +1618,14 @@ def forward( ) video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1]) + audio_ca_timestep = sigma.flatten() if use_cross_timestep else audio_timestep.flatten() audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift( - audio_timestep.flatten(), + audio_ca_timestep, batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, ) audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate( - audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor, + audio_ca_timestep * timestep_cross_attn_gate_scale_factor, batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, ) @@ -1291,20 +1634,36 @@ def forward( ) audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1]) - # 4. Prepare prompt embeddings - encoder_hidden_states = self.caption_projection(encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) + # 4. Prepare prompt embeddings (LTX-2.0) + if self.config.use_prompt_embeddings: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) - audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states) - audio_encoder_hidden_states = audio_encoder_hidden_states.view(batch_size, -1, audio_hidden_states.size(-1)) + audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states) + audio_encoder_hidden_states = audio_encoder_hidden_states.view( + batch_size, -1, audio_hidden_states.size(-1) + ) # 5. Run transformer blocks - for block in self.transformer_blocks: + spatio_temporal_guidance_blocks = spatio_temporal_guidance_blocks or [] + if len(spatio_temporal_guidance_blocks) > 0 and perturbation_mask is None: + # If STG is being used and perturbation_mask is not supplied, default to perturbing all batch elements. 
+ perturbation_mask = torch.zeros((batch_size,)) + if perturbation_mask is not None and perturbation_mask.ndim == 1: + perturbation_mask = perturbation_mask[:, None, None] # unsqueeze to 3D to broadcast with hidden_states + all_perturbed = torch.all(perturbation_mask == 0) if perturbation_mask is not None else False + stg_blocks = set(spatio_temporal_guidance_blocks) + + for block_idx, block in enumerate(self.transformer_blocks): + block_perturbation_mask = perturbation_mask if block_idx in stg_blocks else None + block_all_perturbed = all_perturbed if block_idx in stg_blocks else False + if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states, audio_hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), @@ -1318,13 +1677,23 @@ def custom_forward(*inputs): audio_cross_attn_scale_shift, video_cross_attn_a2v_gate, audio_cross_attn_v2a_gate, + temb_prompt, + temb_prompt_audio, video_rotary_emb, audio_rotary_emb, video_cross_attn_rotary_emb, audio_cross_attn_rotary_emb, encoder_attention_mask, audio_encoder_attention_mask, - use_reentrant=False, + None, # self_attention_mask + None, # audio_self_attention_mask + None, # a2v_cross_attention_mask + None, # v2a_cross_attention_mask + not isolate_modalities, # use_a2v_cross_attention + not isolate_modalities, # use_v2a_cross_attention + block_perturbation_mask, + block_all_perturbed, + **ckpt_kwargs ) else: hidden_states, audio_hidden_states = block( @@ -1338,12 +1707,22 @@ def custom_forward(*inputs): temb_ca_audio_scale_shift=audio_cross_attn_scale_shift, temb_ca_gate=video_cross_attn_a2v_gate, temb_ca_audio_gate=audio_cross_attn_v2a_gate, + temb_prompt=temb_prompt, + temb_prompt_audio=temb_prompt_audio, video_rotary_emb=video_rotary_emb, audio_rotary_emb=audio_rotary_emb, ca_video_rotary_emb=video_cross_attn_rotary_emb, ca_audio_rotary_emb=audio_cross_attn_rotary_emb, encoder_attention_mask=encoder_attention_mask, audio_encoder_attention_mask=audio_encoder_attention_mask, + self_attention_mask=None, + audio_self_attention_mask=None, + a2v_cross_attention_mask=None, + v2a_cross_attention_mask=None, + use_a2v_cross_attention=not isolate_modalities, + use_v2a_cross_attention=not isolate_modalities, + perturbation_mask=block_perturbation_mask, + all_perturbed=block_all_perturbed, ) # 6. 
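A short sketch of how the STG perturbation mask is normalized before the block loop: a missing mask defaults to perturbing every batch element, a 1D mask is broadcast to 3D, and `all_perturbed` lets the perturbed processor skip attention entirely. Values are illustrative:

```python
import torch

batch_size = 3
perturbation_mask = None                  # e.g. the caller did not pass one

if perturbation_mask is None:
    # Default: perturb (apply STG to) every batch element.
    perturbation_mask = torch.zeros((batch_size,))
if perturbation_mask.ndim == 1:
    # Unsqueeze so the mask broadcasts against [B, T, D] hidden states.
    perturbation_mask = perturbation_mask[:, None, None]
all_perturbed = torch.all(perturbation_mask == 0)

print(perturbation_mask.shape, bool(all_perturbed))   # torch.Size([3, 1, 1]) True
```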
Output layers (including unpatchification) @@ -1361,6 +1740,12 @@ def custom_forward(*inputs): audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift audio_output = self.audio_proj_out(audio_hidden_states) + # Sequence parallel: gather video outputs from all ranks + # Audio output is NOT gathered since it was never chunked + if hasattr(self, 'sp_world_size') and self.sp_world_size > 1: + from ..dist.fuser import sequence_parallel_all_gather + output = sequence_parallel_all_gather(output, dim=1) + if not return_dict: return (output, audio_output) return AudioVisualModelOutput(sample=output, audio_sample=audio_output) \ No newline at end of file diff --git a/videox_fun/models/ltx2_vae.py b/videox_fun/models/ltx2_vae.py index 74ede24b..ca1f5e53 100644 --- a/videox_fun/models/ltx2_vae.py +++ b/videox_fun/models/ltx2_vae.py @@ -236,7 +236,7 @@ def forward( # Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d -class LTXVideoDownsampler3d(nn.Module): +class LTX2VideoDownsampler3d(nn.Module): def __init__( self, in_channels: int, @@ -284,10 +284,11 @@ def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Ten # Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d -class LTXVideoUpsampler3d(nn.Module): +class LTX2VideoUpsampler3d(nn.Module): def __init__( self, in_channels: int, + out_channels: int | None = None, stride: int | tuple[int, int, int] = 1, residual: bool = False, upscale_factor: int = 1, @@ -299,7 +300,8 @@ def __init__( self.residual = residual self.upscale_factor = upscale_factor - out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor + out_channels = out_channels or in_channels + out_channels = (out_channels * stride[0] * stride[1] * stride[2]) // upscale_factor self.conv = LTX2VideoCausalConv3d( in_channels=in_channels, @@ -407,7 +409,7 @@ def __init__( ) elif downsample_type == "spatial": self.downsamplers.append( - LTXVideoDownsampler3d( + LTX2VideoDownsampler3d( in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), @@ -416,7 +418,7 @@ def __init__( ) elif downsample_type == "temporal": self.downsamplers.append( - LTXVideoDownsampler3d( + LTX2VideoDownsampler3d( in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), @@ -425,7 +427,7 @@ def __init__( ) elif downsample_type == "spatiotemporal": self.downsamplers.append( - LTXVideoDownsampler3d( + LTX2VideoDownsampler3d( in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), @@ -579,6 +581,7 @@ def __init__( resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", spatio_temporal_scale: bool = True, + upsample_type: str = "spatiotemporal", inject_noise: bool = False, timestep_conditioning: bool = False, upsample_residual: bool = False, @@ -608,16 +611,23 @@ def __init__( self.upsamplers = None if spatio_temporal_scale: - self.upsamplers = nn.ModuleList( - [ - LTXVideoUpsampler3d( - out_channels * upscale_factor, - stride=(2, 2, 2), - residual=upsample_residual, - upscale_factor=upscale_factor, - spatial_padding_mode=spatial_padding_mode, - ) - ] + self.upsamplers = nn.ModuleList() + + if upsample_type == "spatial": + upsample_stride = (1, 2, 2) + elif upsample_type == "temporal": + upsample_stride = (2, 1, 1) + elif upsample_type == "spatiotemporal": + upsample_stride = (2, 2, 2) + + self.upsamplers.append( + LTX2VideoUpsampler3d( + in_channels=out_channels * upscale_factor, + stride=upsample_stride, + residual=upsample_residual, + upscale_factor=upscale_factor, + 
spatial_padding_mode=spatial_padding_mode, + ) ) resnets = [] @@ -715,7 +725,7 @@ def __init__( "LTX2VideoDownBlock3D", "LTX2VideoDownBlock3D", ), - spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True), + spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True, True), layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2), downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), patch_size: int = 4, @@ -725,6 +735,9 @@ def __init__( spatial_padding_mode: str = "zeros", ): super().__init__() + num_encoder_blocks = len(layers_per_block) + if isinstance(spatio_temporal_scaling, bool): + spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_encoder_blocks - 1) self.patch_size = patch_size self.patch_size_t = patch_size_t @@ -859,19 +872,27 @@ def __init__( in_channels: int = 128, out_channels: int = 3, block_out_channels: tuple[int, ...] = (256, 512, 1024), - spatio_temporal_scaling: tuple[bool, ...] = (True, True, True), + spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True), layers_per_block: tuple[int, ...] = (5, 5, 5, 5), + upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"), patch_size: int = 4, patch_size_t: int = 1, resnet_norm_eps: float = 1e-6, is_causal: bool = False, - inject_noise: tuple[bool, ...] = (False, False, False), + inject_noise: bool | tuple[bool, ...] = (False, False, False), timestep_conditioning: bool = False, - upsample_residual: tuple[bool, ...] = (True, True, True), + upsample_residual: bool | tuple[bool, ...] = (True, True, True), upsample_factor: tuple[bool, ...] = (2, 2, 2), spatial_padding_mode: str = "reflect", ) -> None: super().__init__() + num_decoder_blocks = len(layers_per_block) + if isinstance(spatio_temporal_scaling, bool): + spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_decoder_blocks - 1) + if isinstance(inject_noise, bool): + inject_noise = (inject_noise,) * num_decoder_blocks + if isinstance(upsample_residual, bool): + upsample_residual = (upsample_residual,) * (num_decoder_blocks - 1) self.patch_size = patch_size self.patch_size_t = patch_size_t @@ -916,6 +937,7 @@ def __init__( num_layers=layers_per_block[i + 1], resnet_eps=resnet_norm_eps, spatio_temporal_scale=spatio_temporal_scaling[i], + upsample_type=upsample_type[i], inject_noise=inject_noise[i + 1], timestep_conditioning=timestep_conditioning, upsample_residual=upsample_residual[i], @@ -1057,11 +1079,12 @@ def __init__( decoder_block_out_channels: tuple[int, ...] = (256, 512, 1024), layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2), decoder_layers_per_block: tuple[int, ...] = (5, 5, 5, 5), - spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True), - decoder_spatio_temporal_scaling: tuple[bool, ...] = (True, True, True), - decoder_inject_noise: tuple[bool, ...] = (False, False, False, False), + spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True, True), + decoder_spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True), + decoder_inject_noise: bool | tuple[bool, ...] = (False, False, False, False), downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), - upsample_residual: tuple[bool, ...] = (True, True, True), + upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"), + upsample_residual: bool | tuple[bool, ...] = (True, True, True), upsample_factor: tuple[int, ...] 
= (2, 2, 2), timestep_conditioning: bool = False, patch_size: int = 4, @@ -1076,6 +1099,16 @@ def __init__( temporal_compression_ratio: int = None, ) -> None: super().__init__() + num_encoder_blocks = len(layers_per_block) + num_decoder_blocks = len(decoder_layers_per_block) + if isinstance(spatio_temporal_scaling, bool): + spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_encoder_blocks - 1) + if isinstance(decoder_spatio_temporal_scaling, bool): + decoder_spatio_temporal_scaling = (decoder_spatio_temporal_scaling,) * (num_decoder_blocks - 1) + if isinstance(decoder_inject_noise, bool): + decoder_inject_noise = (decoder_inject_noise,) * num_decoder_blocks + if isinstance(upsample_residual, bool): + upsample_residual = (upsample_residual,) * (num_decoder_blocks - 1) self.encoder = LTX2VideoEncoder3d( in_channels=in_channels, @@ -1097,6 +1130,7 @@ def __init__( block_out_channels=decoder_block_out_channels, spatio_temporal_scaling=decoder_spatio_temporal_scaling, layers_per_block=decoder_layers_per_block, + upsample_type=upsample_type, patch_size=patch_size, patch_size_t=patch_size_t, resnet_norm_eps=resnet_norm_eps, @@ -1516,4 +1550,4 @@ def forward( dec = self.decode(z, temb, causal=decoder_causal) if not return_dict: return (dec.sample,) - return dec + return dec \ No newline at end of file diff --git a/videox_fun/models/ltx2_vocoder.py b/videox_fun/models/ltx2_vocoder.py index 7e038827..65d96626 100644 --- a/videox_fun/models/ltx2_vocoder.py +++ b/videox_fun/models/ltx2_vocoder.py @@ -4,10 +4,214 @@ import torch import torch.nn as nn import torch.nn.functional as F + from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin +def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor: + """ + Creates a Kaiser sinc kernel for low-pass filtering. + + Args: + cutoff (`float`): + Normalized frequency cutoff (relative to the sampling rate). Must be between 0 and 0.5 (the Nyquist + frequency). + half_width (`float`): + Used to determine the Kaiser window's beta parameter. + kernel_size: + Size of the Kaiser window (and ultimately the Kaiser sinc kernel). + + Returns: + `torch.Tensor` of shape `(kernel_size,)`: + The Kaiser sinc kernel. 
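Assuming the `kaiser_sinc_filter1d` helper defined here is in scope, a quick usage check of its key property: the kernel is normalized to unit DC gain, so the low-pass filter preserves the mean level of the signal. The values below correspond to the `ratio=2` resampler defaults and are otherwise illustrative:

```python
import torch

# cutoff = 0.5 / ratio and half_width = 0.6 / ratio for a 2x down/upsampler
kernel = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
print(kernel.shape)                                      # torch.Size([12])
print(torch.isclose(kernel.sum(), torch.tensor(1.0)))    # True: unit DC gain

# Applied as a grouped conv1d with stride=ratio, it antialiases before decimation.
```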
+ """ + delta_f = 4 * half_width + half_size = kernel_size // 2 + amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if amplitude > 50.0: + beta = 0.1102 * (amplitude - 8.7) + elif amplitude >= 21.0: + beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0) + else: + beta = 0.0 + + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + even = kernel_size % 2 == 0 + time = torch.arange(-half_size, half_size) + 0.5 if even else torch.arange(kernel_size) - half_size + + if cutoff == 0.0: + filter = torch.zeros_like(time) + else: + time = 2 * cutoff * time + sinc = torch.where( + time == 0, + torch.ones_like(time), + torch.sin(math.pi * time) / math.pi / time, + ) + filter = 2 * cutoff * window * sinc + filter = filter / filter.sum() + return filter + + +class DownSample1d(nn.Module): + """1D low-pass filter for antialias downsampling.""" + + def __init__( + self, + ratio: int = 2, + kernel_size: int | None = None, + use_padding: bool = True, + padding_mode: str = "replicate", + persistent: bool = True, + ): + super().__init__() + self.ratio = ratio + self.kernel_size = kernel_size or int(6 * ratio // 2) * 2 + self.pad_left = self.kernel_size // 2 + (self.kernel_size % 2) - 1 + self.pad_right = self.kernel_size // 2 + self.use_padding = use_padding + self.padding_mode = padding_mode + + cutoff = 0.5 / ratio + half_width = 0.6 / ratio + low_pass_filter = kaiser_sinc_filter1d(cutoff, half_width, self.kernel_size) + self.register_buffer("filter", low_pass_filter.view(1, 1, self.kernel_size), persistent=persistent) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x expected shape: [batch_size, num_channels, hidden_dim] + num_channels = x.shape[1] + if self.use_padding: + x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) + x_filtered = F.conv1d(x, self.filter.expand(num_channels, -1, -1), stride=self.ratio, groups=num_channels) + return x_filtered + + +class UpSample1d(nn.Module): + def __init__( + self, + ratio: int = 2, + kernel_size: int | None = None, + window_type: str = "kaiser", + padding_mode: str = "replicate", + persistent: bool = True, + ): + super().__init__() + self.ratio = ratio + self.padding_mode = padding_mode + + if window_type == "hann": + rolloff = 0.99 + lowpass_filter_width = 6 + width = math.ceil(lowpass_filter_width / rolloff) + self.kernel_size = 2 * width * ratio + 1 + self.pad = width + self.pad_left = 2 * width * ratio + self.pad_right = self.kernel_size - ratio + + time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff + time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width) + window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2 + sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1) + else: + # Kaiser sinc filter is BigVGAN default + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.ratio + (self.kernel_size - self.ratio) // 2 + self.pad_right = self.pad * self.ratio + (self.kernel_size - self.ratio + 1) // 2 + + sinc_filter = kaiser_sinc_filter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + kernel_size=self.kernel_size, + ) + + self.register_buffer("filter", sinc_filter.view(1, 1, self.kernel_size), persistent=persistent) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x expected shape: [batch_size, num_channels, hidden_dim] + num_channels = x.shape[1] + x = F.pad(x, (self.pad, 
self.pad), mode=self.padding_mode) + low_pass_filter = self.filter.to(dtype=x.dtype, device=x.device).expand(num_channels, -1, -1) + x = self.ratio * F.conv_transpose1d(x, low_pass_filter, stride=self.ratio, groups=num_channels) + return x[..., self.pad_left : -self.pad_right] + + +class AntiAliasAct1d(nn.Module): + """ + Antialiasing activation for a 1D signal: upsamples, applies an activation (usually snakebeta), and then downsamples + to avoid aliasing. + """ + + def __init__( + self, + act_fn: str | nn.Module, + ratio: int = 2, + kernel_size: int = 12, + **kwargs, + ): + super().__init__() + self.upsample = UpSample1d(ratio=ratio, kernel_size=kernel_size) + if isinstance(act_fn, str): + if act_fn == "snakebeta": + act_fn = SnakeBeta(**kwargs) + elif act_fn == "snake": + act_fn = SnakeBeta(**kwargs) + else: + act_fn = nn.LeakyReLU(**kwargs) + self.act = act_fn + self.downsample = DownSample1d(ratio=ratio, kernel_size=kernel_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + return x + + +class SnakeBeta(nn.Module): + """ + Implements the Snake and SnakeBeta activations, which help with learning periodic patterns. + """ + + def __init__( + self, + channels: int, + alpha: float = 1.0, + eps: float = 1e-9, + trainable_params: bool = True, + logscale: bool = True, + use_beta: bool = True, + ): + super().__init__() + self.eps = eps + self.logscale = logscale + self.use_beta = use_beta + + self.alpha = nn.Parameter(torch.zeros(channels) if self.logscale else torch.ones(channels) * alpha) + self.alpha.requires_grad = trainable_params + if use_beta: + self.beta = nn.Parameter(torch.zeros(channels) if self.logscale else torch.ones(channels) * alpha) + self.beta.requires_grad = trainable_params + + def forward(self, hidden_states: torch.Tensor, channel_dim: int = 1) -> torch.Tensor: + broadcast_shape = [1] * hidden_states.ndim + broadcast_shape[channel_dim] = -1 + alpha = self.alpha.view(broadcast_shape) + if self.use_beta: + beta = self.beta.view(broadcast_shape) + + if self.logscale: + alpha = torch.exp(alpha) + if self.use_beta: + beta = torch.exp(beta) + + amplitude = beta if self.use_beta else alpha + hidden_states = hidden_states + (1.0 / (amplitude + self.eps)) * torch.sin(hidden_states * alpha).pow(2) + return hidden_states + + class ResBlock(nn.Module): def __init__( self, @@ -15,12 +219,15 @@ def __init__( kernel_size: int = 3, stride: int = 1, dilations: tuple[int, ...] 
= (1, 3, 5), + act_fn: str = "leaky_relu", leaky_relu_negative_slope: float = 0.1, + antialias: bool = False, + antialias_ratio: int = 2, + antialias_kernel_size: int = 12, padding_mode: str = "same", ): super().__init__() self.dilations = dilations - self.negative_slope = leaky_relu_negative_slope self.convs1 = nn.ModuleList( [ @@ -28,6 +235,18 @@ def __init__( for dilation in dilations ] ) + self.acts1 = nn.ModuleList() + for _ in range(len(self.convs1)): + if act_fn == "snakebeta": + act = SnakeBeta(channels, use_beta=True) + elif act_fn == "snake": + act = SnakeBeta(channels, use_beta=False) + else: + act = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope) + + if antialias: + act = AntiAliasAct1d(act, ratio=antialias_ratio, kernel_size=antialias_kernel_size) + self.acts1.append(act) self.convs2 = nn.ModuleList( [ @@ -35,12 +254,24 @@ def __init__( for _ in range(len(dilations)) ] ) + self.acts2 = nn.ModuleList() + for _ in range(len(self.convs2)): + if act_fn == "snakebeta": + act = SnakeBeta(channels, use_beta=True) + elif act_fn == "snake": + act = SnakeBeta(channels, use_beta=False) + else: + act_fn = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope) + + if antialias: + act = AntiAliasAct1d(act, ratio=antialias_ratio, kernel_size=antialias_kernel_size) + self.acts2.append(act) def forward(self, x: torch.Tensor) -> torch.Tensor: - for conv1, conv2 in zip(self.convs1, self.convs2): - xt = F.leaky_relu(x, negative_slope=self.negative_slope) + for act1, conv1, act2, conv2 in zip(self.acts1, self.convs1, self.acts2, self.convs2): + xt = act1(x) xt = conv1(xt) - xt = F.leaky_relu(xt, negative_slope=self.negative_slope) + xt = act2(xt) xt = conv2(xt) x = x + xt return x @@ -61,7 +292,13 @@ def __init__( upsample_factors: list[int] = [6, 5, 2, 2, 2], resnet_kernel_sizes: list[int] = [3, 7, 11], resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + act_fn: str = "leaky_relu", leaky_relu_negative_slope: float = 0.1, + antialias: bool = False, + antialias_ratio: int = 2, + antialias_kernel_size: int = 12, + final_act_fn: str | None = "tanh", # tanh, clamp, None + final_bias: bool = True, output_sampling_rate: int = 24000, ): super().__init__() @@ -69,7 +306,9 @@ def __init__( self.resnets_per_upsample = len(resnet_kernel_sizes) self.out_channels = out_channels self.total_upsample_factor = math.prod(upsample_factors) + self.act_fn = act_fn self.negative_slope = leaky_relu_negative_slope + self.final_act_fn = final_act_fn if self.num_upsample_layers != len(upsample_factors): raise ValueError( @@ -83,6 +322,13 @@ def __init__( f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively." ) + supported_act_fns = ["snakebeta", "snake", "leaky_relu"] + if self.act_fn not in supported_act_fns: + raise ValueError( + f"Unsupported activation function: {self.act_fn}. Currently supported values of `act_fn` are " + f"{supported_act_fns}." 
+ ) + self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3) self.upsamplers = nn.ModuleList() @@ -103,15 +349,27 @@ def __init__( for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations): self.resnets.append( ResBlock( - output_channels, - kernel_size, + channels=output_channels, + kernel_size=kernel_size, dilations=dilations, + act_fn=act_fn, leaky_relu_negative_slope=leaky_relu_negative_slope, + antialias=antialias, + antialias_ratio=antialias_ratio, + antialias_kernel_size=antialias_kernel_size, ) ) input_channels = output_channels - self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3) + if act_fn == "snakebeta" or act_fn == "snake": + # Always use antialiasing + act_out = SnakeBeta(channels=output_channels, use_beta=True) + self.act_out = AntiAliasAct1d(act_out, ratio=antialias_ratio, kernel_size=antialias_kernel_size) + elif act_fn == "leaky_relu": + # NOTE: does NOT use self.negative_slope, following the original code + self.act_out = nn.LeakyReLU() + + self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3, bias=final_bias) def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor: r""" @@ -139,7 +397,9 @@ def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch hidden_states = self.conv_in(hidden_states) for i in range(self.num_upsample_layers): - hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope) + if self.act_fn == "leaky_relu": + # Other activations are inside each upsampling block + hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope) hidden_states = self.upsamplers[i](hidden_states) # Run all resnets in parallel on hidden_states @@ -149,10 +409,190 @@ def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch hidden_states = torch.mean(resnet_outputs, dim=0) - # NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of - # 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended - hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01) + hidden_states = self.act_out(hidden_states) hidden_states = self.conv_out(hidden_states) - hidden_states = torch.tanh(hidden_states) + if self.final_act_fn == "tanh": + hidden_states = torch.tanh(hidden_states) + elif self.final_act_fn == "clamp": + hidden_states = torch.clamp(hidden_states, -1, 1) + + return hidden_states + + +class CausalSTFT(nn.Module): + """ + Performs a causal short-time Fourier transform (STFT) using causal Hann windows on a waveform. The DFT bases + multiplied by the Hann windows are pre-calculated and stored as buffers. For exact parity with training, the exact + buffers should be loaded from the checkpoint in bfloat16. 
+ """ + + def __init__(self, filter_length: int = 512, hop_length: int = 80, window_length: int = 512): + super().__init__() + self.hop_length = hop_length + self.window_length = window_length + n_freqs = filter_length // 2 + 1 + + self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length), persistent=True) + self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length), persistent=True) + + def forward(self, waveform: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + if waveform.ndim == 2: + waveform = waveform.unsqueeze(1) # [B, num_channels, num_samples] + + left_pad = max(0, self.window_length - self.hop_length) # causal: left-only + waveform = F.pad(waveform, (left_pad, 0)) + + spec = F.conv1d(waveform, self.forward_basis, stride=self.hop_length, padding=0) + n_freqs = spec.shape[1] // 2 + real, imag = spec[:, :n_freqs], spec[:, n_freqs:] + magnitude = torch.sqrt(real**2 + imag**2) + phase = torch.atan2(imag.float(), real.float()).to(dtype=real.dtype) + return magnitude, phase + + +class MelSTFT(nn.Module): + """ + Calculates a causal log-mel spectrogram from a waveform. Uses a pre-calculated mel filterbank, which should be + loaded from the checkpoint in bfloat16. + """ + + def __init__( + self, + filter_length: int = 512, + hop_length: int = 80, + window_length: int = 512, + num_mel_channels: int = 64, + ): + super().__init__() + self.stft_fn = CausalSTFT(filter_length, hop_length, window_length) + + num_freqs = filter_length // 2 + 1 + self.register_buffer("mel_basis", torch.zeros(num_mel_channels, num_freqs), persistent=True) + + def forward(self, waveform: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + magnitude, phase = self.stft_fn(waveform) + energy = torch.norm(magnitude, dim=1) + mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude) + log_mel = torch.log(torch.clamp(mel, min=1e-5)) + return log_mel, magnitude, phase, energy + + +class LTX2VocoderWithBWE(ModelMixin, ConfigMixin): + """ + LTX-2.X vocoder with bandwidth extension (BWE) upsampling. The vocoder and the BWE module run in sequence, with the + BWE module upsampling the vocoder output waveform to a higher sampling rate. The BWE module itself has the same + architecture as the original vocoder. 
+ """ + + @register_to_config + def __init__( + self, + in_channels: int = 128, + hidden_channels: int = 1536, + out_channels: int = 2, + upsample_kernel_sizes: list[int] = [11, 4, 4, 4, 4, 4], + upsample_factors: list[int] = [5, 2, 2, 2, 2, 2], + resnet_kernel_sizes: list[int] = [3, 7, 11], + resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + act_fn: str = "snakebeta", + leaky_relu_negative_slope: float = 0.1, + antialias: bool = True, + antialias_ratio: int = 2, + antialias_kernel_size: int = 12, + final_act_fn: str | None = None, + final_bias: bool = False, + bwe_in_channels: int = 128, + bwe_hidden_channels: int = 512, + bwe_out_channels: int = 2, + bwe_upsample_kernel_sizes: list[int] = [12, 11, 4, 4, 4], + bwe_upsample_factors: list[int] = [6, 5, 2, 2, 2], + bwe_resnet_kernel_sizes: list[int] = [3, 7, 11], + bwe_resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + bwe_act_fn: str = "snakebeta", + bwe_leaky_relu_negative_slope: float = 0.1, + bwe_antialias: bool = True, + bwe_antialias_ratio: int = 2, + bwe_antialias_kernel_size: int = 12, + bwe_final_act_fn: str | None = None, + bwe_final_bias: bool = False, + filter_length: int = 512, + hop_length: int = 80, + window_length: int = 512, + num_mel_channels: int = 64, + input_sampling_rate: int = 16000, + output_sampling_rate: int = 48000, + ): + super().__init__() + + self.vocoder = LTX2Vocoder( + in_channels=in_channels, + hidden_channels=hidden_channels, + out_channels=out_channels, + upsample_kernel_sizes=upsample_kernel_sizes, + upsample_factors=upsample_factors, + resnet_kernel_sizes=resnet_kernel_sizes, + resnet_dilations=resnet_dilations, + act_fn=act_fn, + leaky_relu_negative_slope=leaky_relu_negative_slope, + antialias=antialias, + antialias_ratio=antialias_ratio, + antialias_kernel_size=antialias_kernel_size, + final_act_fn=final_act_fn, + final_bias=final_bias, + output_sampling_rate=input_sampling_rate, + ) + self.bwe_generator = LTX2Vocoder( + in_channels=bwe_in_channels, + hidden_channels=bwe_hidden_channels, + out_channels=bwe_out_channels, + upsample_kernel_sizes=bwe_upsample_kernel_sizes, + upsample_factors=bwe_upsample_factors, + resnet_kernel_sizes=bwe_resnet_kernel_sizes, + resnet_dilations=bwe_resnet_dilations, + act_fn=bwe_act_fn, + leaky_relu_negative_slope=bwe_leaky_relu_negative_slope, + antialias=bwe_antialias, + antialias_ratio=bwe_antialias_ratio, + antialias_kernel_size=bwe_antialias_kernel_size, + final_act_fn=bwe_final_act_fn, + final_bias=bwe_final_bias, + output_sampling_rate=output_sampling_rate, + ) + + self.mel_stft = MelSTFT( + filter_length=filter_length, + hop_length=hop_length, + window_length=window_length, + num_mel_channels=num_mel_channels, + ) + + self.resampler = UpSample1d( + ratio=output_sampling_rate // input_sampling_rate, + window_type="hann", + persistent=False, + ) + + def forward(self, mel_spec: torch.Tensor) -> torch.Tensor: + # 1. Run stage 1 vocoder to get low sampling rate waveform + x = self.vocoder(mel_spec) + batch_size, num_channels, num_samples = x.shape + + # Pad to exact multiple of hop_length for exact mel frame count + remainder = num_samples % self.config.hop_length + if remainder != 0: + x = F.pad(x, (0, self.hop_length - remainder)) + + # 2. Compute mel spectrogram on vocoder output + mel, _, _, _ = self.mel_stft(x.flatten(0, 1)) + mel = mel.unflatten(0, (-1, num_channels)) + + # 3. 
Run bandwidth extender (BWE) on new mel spectrogram + mel_for_bwe = mel.transpose(2, 3) # [B, C, num_mel_bins, num_frames] --> [B, C, num_frames, num_mel_bins] + residual = self.bwe_generator(mel_for_bwe) - return hidden_states \ No newline at end of file + # 4. Residual connection with resampler + skip = self.resampler(x) + waveform = torch.clamp(residual + skip, -1, 1) + output_samples = num_samples * self.config.output_sampling_rate // self.config.input_sampling_rate + waveform = waveform[..., :output_samples] + return waveform \ No newline at end of file diff --git a/videox_fun/pipeline/pipeline_ltx2.py b/videox_fun/pipeline/pipeline_ltx2.py index 680de954..06fd73da 100644 --- a/videox_fun/pipeline/pipeline_ltx2.py +++ b/videox_fun/pipeline/pipeline_ltx2.py @@ -30,10 +30,10 @@ from diffusers.utils.torch_utils import randn_tensor from diffusers.video_processor import VideoProcessor -from ..models import (AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, +from ..models import (AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, Gemma3Processor, Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast, LTX2TextConnectors, - LTX2VideoTransformer3DModel, LTX2Vocoder) + LTX2VideoTransformer3DModel, LTX2Vocoder, LTX2VocoderWithBWE) if is_torch_xla_available(): import torch_xla.core.xla_model as xm @@ -242,7 +242,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin): """ model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" - _optional_components = [] + _optional_components = ["processor"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( @@ -254,7 +254,8 @@ def __init__( tokenizer: GemmaTokenizer | GemmaTokenizerFast, connectors: LTX2TextConnectors, transformer: LTX2VideoTransformer3DModel, - vocoder: LTX2Vocoder, + vocoder: LTX2Vocoder | LTX2VocoderWithBWE, + processor: Gemma3Processor | None = None, ): super().__init__() @@ -267,6 +268,7 @@ def __init__( transformer=transformer, vocoder=vocoder, scheduler=scheduler, + processor=processor, ) self.vae_spatial_compression_ratio = ( @@ -301,73 +303,6 @@ def __init__( self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 ) - @staticmethod - def _pack_text_embeds( - text_hidden_states: torch.Tensor, - sequence_lengths: torch.Tensor, - device: str | torch.device, - padding_side: str = "left", - scale_factor: int = 8, - eps: float = 1e-6, - ) -> torch.Tensor: - """ - Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and - per-layer in a masked fashion (only over non-padded positions). - - Args: - text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): - Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). - sequence_lengths (`torch.Tensor of shape `(batch_size,)`): - The number of valid (non-padded) tokens for each batch instance. - device: (`str` or `torch.device`, *optional*): - torch device to place the resulting embeddings on - padding_side: (`str`, *optional*, defaults to `"left"`): - Whether the text tokenizer performs padding on the `"left"` or `"right"`. - scale_factor (`int`, *optional*, defaults to `8`): - Scaling factor to multiply the normalized hidden states by. - eps (`float`, *optional*, defaults to `1e-6`): - A small positive value for numerical stability when performing normalization. 
- - Returns: - `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: - Normed and flattened text encoder hidden states. - """ - batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape - original_dtype = text_hidden_states.dtype - - # Create padding mask - token_indices = torch.arange(seq_len, device=device).unsqueeze(0) - if padding_side == "right": - # For right padding, valid tokens are from 0 to sequence_length-1 - mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] - elif padding_side == "left": - # For left padding, valid tokens are from (T - sequence_length) to T-1 - start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] - mask = token_indices >= start_indices # [B, T] - else: - raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") - mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] - - # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) - masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) - num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) - masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) - - # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) - x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) - x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) - - # Normalization - normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) - normalized_hidden_states = normalized_hidden_states * scale_factor - - # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) - normalized_hidden_states = normalized_hidden_states.flatten(2) - mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) - normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) - normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) - return normalized_hidden_states - def _get_gemma_prompt_embeds( self, prompt: str | list[str], @@ -420,16 +355,7 @@ def _get_gemma_prompt_embeds( ) text_encoder_hidden_states = text_encoder_outputs.hidden_states text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) - sequence_lengths = prompt_attention_mask.sum(dim=-1) - - prompt_embeds = self._pack_text_embeds( - text_encoder_hidden_states, - sequence_lengths, - device=device, - padding_side=self.tokenizer.padding_side, - scale_factor=scale_factor, - ) - prompt_embeds = prompt_embeds.to(dtype=dtype) + prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) # Pack to 3D # duplicate text embeddings for each generation per prompt, using mps friendly method _, seq_len, _ = prompt_embeds.shape @@ -527,6 +453,50 @@ def encode_prompt( return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + @torch.no_grad() + def enhance_prompt( + self, + prompt: str, + system_prompt: str, + max_new_tokens: int = 512, + seed: int = 10, + generator: torch.Generator | None = None, + generation_kwargs: dict[str, Any] | None = None, + device: str | torch.device | None = None, + ): + """ + Enhances the supplied `prompt` by generating a new prompt using the current text encoder (default is a + `transformers.Gemma3ForConditionalGeneration` model) from it and a system prompt. 
+ """ + device = device or self._execution_device + if generation_kwargs is None: + # Set to default generation kwargs + generation_kwargs = {"do_sample": True, "temperature": 0.7} + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"user prompt: {prompt}"}, + ] + template = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + model_inputs = self.processor(text=template, images=None, return_tensors="pt").to(device) + self.text_encoder.to(device) + + # `transformers.GenerationMixin.generate` does not support using a `torch.Generator` to control randomness, + # so manually apply a seed for reproducible generation. + if generator is not None: + # Overwrite seed to generator's initial seed + seed = generator.initial_seed() + torch.manual_seed(seed) + generated_sequences = self.text_encoder.generate( + **model_inputs, + max_new_tokens=max_new_tokens, + **generation_kwargs, + ) # tensor of shape [batch_size, seq_len] + + generated_ids = [seq[len(model_inputs.input_ids[i]) :] for i, seq in enumerate(generated_sequences)] + enhanced_prompt = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + return enhanced_prompt + def check_inputs( self, prompt, @@ -537,6 +507,9 @@ def check_inputs( negative_prompt_embeds=None, prompt_attention_mask=None, negative_prompt_attention_mask=None, + spatio_temporal_guidance_blocks=None, + stg_scale=None, + audio_stg_scale=None, ): if height % 32 != 0 or width % 32 != 0: raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") @@ -580,6 +553,12 @@ def check_inputs( f" {negative_prompt_attention_mask.shape}." ) + if ((stg_scale > 0.0) or (audio_stg_scale > 0.0)) and not spatio_temporal_guidance_blocks: + raise ValueError( + "Spatio-Temporal Guidance (STG) is specified but no STG blocks are supplied. Please supply a list of" + "block indices at which to apply STG in `spatio_temporal_guidance_blocks`" + ) + @staticmethod def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. 
@@ -616,6 +595,7 @@ def _unpack_latents( return latents @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents def _normalize_latents( latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 ) -> torch.Tensor: @@ -766,7 +746,6 @@ def prepare_audio_latents( latents = self._create_noised_state(latents, noise_scale, generator) return latents.to(device=device, dtype=dtype) - # TODO: confirm whether this logic is correct latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) @@ -781,6 +760,24 @@ def prepare_audio_latents( latents = self._pack_audio_latents(latents) return latents + def convert_velocity_to_x0( + self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None + ) -> torch.Tensor: + if scheduler is None: + scheduler = self.scheduler + + sample_x0 = sample - denoised_output * scheduler.sigmas[step_idx] + return sample_x0 + + def convert_x0_to_velocity( + self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None + ) -> torch.Tensor: + if scheduler is None: + scheduler = self.scheduler + + sample_v = (sample - denoised_output) / scheduler.sigmas[step_idx] + return sample_v + @property def guidance_scale(self): return self._guidance_scale @@ -789,9 +786,41 @@ def guidance_scale(self): def guidance_rescale(self): return self._guidance_rescale + @property + def stg_scale(self): + return self._stg_scale + + @property + def modality_scale(self): + return self._modality_scale + + @property + def audio_guidance_scale(self): + return self._audio_guidance_scale + + @property + def audio_guidance_rescale(self): + return self._audio_guidance_rescale + + @property + def audio_stg_scale(self): + return self._audio_stg_scale + + @property + def audio_modality_scale(self): + return self._audio_modality_scale + @property def do_classifier_free_guidance(self): - return self._guidance_scale > 1.0 + return (self._guidance_scale > 1.0) or (self._audio_guidance_scale > 1.0) + + @property + def do_spatio_temporal_guidance(self): + return (self._stg_scale > 0.0) or (self._audio_stg_scale > 0.0) + + @property + def do_modality_isolation_guidance(self): + return (self._modality_scale > 1.0) or (self._audio_modality_scale > 1.0) @property def num_timesteps(self): @@ -821,9 +850,16 @@ def __call__( frame_rate: float = 24.0, num_inference_steps: int = 40, sigmas: list[float] | None = None, - timesteps: list[int] | None = None, + timesteps: list[int] = None, guidance_scale: float = 4.0, + stg_scale: float = 0.0, + modality_scale: float = 1.0, guidance_rescale: float = 0.0, + audio_guidance_scale: float | None = None, + audio_stg_scale: float | None = None, + audio_modality_scale: float | None = None, + audio_guidance_rescale: float | None = None, + spatio_temporal_guidance_blocks: list[int] | None = None, noise_scale: float = 0.0, num_videos_per_prompt: int = 1, generator: torch.Generator | list[torch.Generator] | None = None, @@ -835,6 +871,11 @@ def __call__( negative_prompt_attention_mask: torch.Tensor | None = None, decode_timestep: float | list[float] = 0.0, decode_noise_scale: float | list[float] | None = None, + use_cross_timestep: bool = False, + system_prompt: str | None = None, + prompt_max_new_tokens: int = 512, + prompt_enhancement_kwargs: dict[str, Any] | None = None, + prompt_enhancement_seed: int = 10, 
output_type: str = "pil", return_dict: bool = True, attention_kwargs: dict[str, Any] | None = None, @@ -864,7 +905,7 @@ def __call__( Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. - timesteps (`List[int]`, *optional*): + timesteps (`list[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. @@ -873,13 +914,47 @@ def __call__( Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to - the text `prompt`, usually at the expense of lower image quality. + the text `prompt`, usually at the expense of lower image quality. Used for the video modality (there is + a separate value `audio_guidance_scale` for the audio modality). + stg_scale (`float`, *optional*, defaults to `0.0`): + Video guidance scale for Spatio-Temporal Guidance (STG), proposed in [Spatiotemporal Skip Guidance for + Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664). STG uses a CFG-like estimate + where we move the sample away from a weak sample from a perturbed version of the denoising model. + Enabling STG will result in an additional denoising model forward pass; the default value of `0.0` + means that STG is disabled. + modality_scale (`float`, *optional*, defaults to `1.0`): + Video guidance scale for LTX-2.X modality isolation guidance, where we move the sample away from a + weaker sample generated by the denoising model withy cross-modality (audio-to-video and video-to-audio) + cross attention disabled using a CFG-like estimate. Enabling modality guidance will result in an + additional denoising model forward pass; the default value of `1.0` means that modality guidance is + disabled. guidance_rescale (`float`, *optional*, defaults to 0.0): Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when - using zero terminal SNR. + using zero terminal SNR. Used for the video modality. + audio_guidance_scale (`float`, *optional* defaults to `None`): + Audio guidance scale for CFG with respect to the negative prompt. The CFG update rule is the same for + video and audio, but they can use different values for the guidance scale. The LTX-2.X authors suggest + that the `audio_guidance_scale` should be higher relative to the video `guidance_scale` (e.g. for + LTX-2.3 they suggest 3.0 for video and 7.0 for audio). If `None`, defaults to the video value + `guidance_scale`. + audio_stg_scale (`float`, *optional*, defaults to `None`): + Audio guidance scale for STG. As with CFG, the STG update rule is otherwise the same for video and + audio. For LTX-2.3, a value of 1.0 is suggested for both video and audio. If `None`, defaults to the + video value `stg_scale`. 
+ audio_modality_scale (`float`, *optional*, defaults to `None`): + Audio guidance scale for LTX-2.X modality isolation guidance. As with CFG, the modality guidance rule + is otherwise the same for video and audio. For LTX-2.3, a value of 3.0 is suggested for both video and + audio. If `None`, defaults to the video value `modality_scale`. + audio_guidance_rescale (`float`, *optional*, defaults to `None`): + A separate guidance rescale factor for the audio modality. If `None`, defaults to the video value + `guidance_rescale`. + spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`): + The zero-indexed transformer block indices at which to apply STG. Must be supplied if STG is used + (`stg_scale` or `audio_stg_scale` is greater than `0`). A value of `[29]` is recommended for LTX-2.0 + and `[28]` is recommended for LTX-2.3. noise_scale (`float`, *optional*, defaults to `0.0`): The interpolation factor between random noise and denoised latents at each timestep. Applying noise to the `latents` and `audio_latents` before continue denoising. @@ -910,6 +985,24 @@ def __call__( The timestep at which generated video is decoded. decode_noise_scale (`float`, defaults to `None`): The interpolation factor between random noise and denoised latents at the decode timestep. + use_cross_timestep (`bool` *optional*, defaults to `False`): + Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when + calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior; + `False` is the legacy LTX-2.0 behavior. + system_prompt (`str`, *optional*, defaults to `None`): + Optional system prompt to use for prompt enhancement. The system prompt will be used by the current + text encoder (by default, a `Gemma3ForConditionalGeneration` model) to generate an enhanced prompt from + the original `prompt` to condition generation. If not supplied, prompt enhancement will not be + performed. + prompt_max_new_tokens (`int`, *optional*, defaults to `512`): + The maximum number of new tokens to generate when performing prompt enhancement. + prompt_enhancement_kwargs (`dict[str, Any]`, *optional*, defaults to `None`): + Keyword arguments for `self.text_encoder.generate`. If not supplied, default arguments of + `do_sample=True` and `temperature=0.7` will be used. See + https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate + for more details. + prompt_enhancement_seed (`int`, *optional*, default to `10`): + Random seed for any random operations during prompt enhancement. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -924,7 +1017,7 @@ def __call__( with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`List`, *optional*): + callback_on_step_end_tensor_inputs (`List`, *optional*, defaults to `["latents"]`): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. 
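Taken together, these arguments let classifier-free, spatio-temporal, and modality-isolation guidance be combined in a single call. The snippet below is a usage sketch only, filled in with the LTX-2.3 values suggested in the docstrings above; the checkpoint id `Lightricks/LTX-2` and the `videox_fun` import path are assumptions and may need adjusting for your local setup.

```python
import torch
from videox_fun.pipeline.pipeline_ltx2 import LTX2Pipeline  # assumed import path

# Assumed checkpoint id; substitute the LTX-2.3 weights you actually use.
pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

video, audio = pipe(
    prompt="A drummer plays a quick fill on a small stage, cymbals shimmering under warm light.",
    negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
    width=768,
    height=512,
    num_frames=121,
    frame_rate=24.0,
    num_inference_steps=40,
    guidance_scale=3.0,                    # video CFG (LTX-2.3 suggestion)
    audio_guidance_scale=7.0,              # audio CFG is kept higher than video
    stg_scale=1.0,                         # spatio-temporal guidance for video
    audio_stg_scale=1.0,                   # and for audio
    spatio_temporal_guidance_blocks=[28],  # block index suggested for LTX-2.3
    modality_scale=3.0,                    # modality-isolation guidance for video
    audio_modality_scale=3.0,              # and for audio
    use_cross_timestep=True,               # LTX-2.3 cross-modality timestep behavior
    output_type="np",
    return_dict=False,
)
```

As the docstrings state, enabling STG and modality-isolation guidance each adds an extra denoising-model forward pass per step, so combining all three guidance terms is noticeably slower than plain CFG.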
@@ -942,6 +1035,11 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + audio_guidance_scale = audio_guidance_scale or guidance_scale + audio_stg_scale = audio_stg_scale or stg_scale + audio_modality_scale = audio_modality_scale or modality_scale + audio_guidance_rescale = audio_guidance_rescale or guidance_rescale + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt=prompt, @@ -952,10 +1050,21 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + stg_scale=stg_scale, + audio_stg_scale=audio_stg_scale, ) + # Per-modality guidance scales (video, audio) self._guidance_scale = guidance_scale + self._stg_scale = stg_scale + self._modality_scale = modality_scale self._guidance_rescale = guidance_rescale + self._audio_guidance_scale = audio_guidance_scale + self._audio_stg_scale = audio_stg_scale + self._audio_modality_scale = audio_modality_scale + self._audio_guidance_rescale = audio_guidance_rescale + self._attention_kwargs = attention_kwargs self._interrupt = False self._current_timestep = None @@ -971,6 +1080,17 @@ def __call__( device = self._execution_device # 3. Prepare text embeddings + if system_prompt is not None and prompt is not None: + prompt = self.enhance_prompt( + prompt=prompt, + system_prompt=system_prompt, + max_new_tokens=prompt_max_new_tokens, + seed=prompt_enhancement_seed, + generator=generator, + generation_kwargs=prompt_enhancement_kwargs, + device=device, + ) + ( prompt_embeds, prompt_attention_mask, @@ -992,9 +1112,11 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) - additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + tokenizer_padding_side = "left" # Padding side for default Gemma3-12B text encoder + if getattr(self, "tokenizer", None) is not None: + tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left") connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( - prompt_embeds, additive_attention_mask, additive_mask=True + prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side ) # 4. Prepare latent variables @@ -1016,7 +1138,6 @@ def __call__( raise ValueError( f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]." ) - video_sequence_length = latent_num_frames * latent_height * latent_width num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( @@ -1073,7 +1194,7 @@ def __call__( # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas mu = calculate_shift( - video_sequence_length, + self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("base_image_seq_len", 1024), self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("base_shift", 0.95), @@ -1102,11 +1223,6 @@ def __call__( self._num_timesteps = len(timesteps) # 6. 
Prepare micro-conditions - rope_interpolation_scale = ( - self.vae_temporal_compression_ratio / frame_rate, - self.vae_spatial_compression_ratio, - self.vae_spatial_compression_ratio, - ) # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop video_coords = self.transformer.rope.prepare_video_coords( latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate @@ -1143,6 +1259,7 @@ def __call__( encoder_hidden_states=connector_prompt_embeds, audio_encoder_hidden_states=connector_audio_prompt_embeds, timestep=timestep, + sigma=timestep, # Used by LTX-2.3 encoder_attention_mask=connector_attention_mask, audio_encoder_attention_mask=connector_attention_mask, num_frames=latent_num_frames, @@ -1152,7 +1269,10 @@ def __call__( audio_num_frames=audio_num_frames, video_coords=video_coords, audio_coords=audio_coords, - # rope_interpolation_scale=rope_interpolation_scale, + isolate_modalities=False, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, attention_kwargs=attention_kwargs, return_dict=False, ) @@ -1160,24 +1280,153 @@ def __call__( noise_pred_audio = noise_pred_audio.float() if self.do_classifier_free_guidance: - noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) - noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( - noise_pred_video_text - noise_pred_video_uncond + noise_pred_video_uncond_text, noise_pred_video = noise_pred_video.chunk(2) + noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler) + noise_pred_video_uncond_text = self.convert_velocity_to_x0( + latents, noise_pred_video_uncond_text, i, self.scheduler ) + # Use delta formulation as it works more nicely with multiple guidance terms + video_cfg_delta = (self.guidance_scale - 1) * (noise_pred_video - noise_pred_video_uncond_text) - noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) - noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( - noise_pred_audio_text - noise_pred_audio_uncond + noise_pred_audio_uncond_text, noise_pred_audio = noise_pred_audio.chunk(2) + noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler) + noise_pred_audio_uncond_text = self.convert_velocity_to_x0( + audio_latents, noise_pred_audio_uncond_text, i, audio_scheduler + ) + audio_cfg_delta = (self.audio_guidance_scale - 1) * ( + noise_pred_audio - noise_pred_audio_uncond_text ) - if self.guidance_rescale > 0: - # Based on 3.4. 
in https://huggingface.co/papers/2305.08891 - noise_pred_video = rescale_noise_cfg( - noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale - ) - noise_pred_audio = rescale_noise_cfg( - noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale - ) + # Get positive values from merged CFG inputs in case we need to do other DiT forward passes + if self.do_spatio_temporal_guidance or self.do_modality_isolation_guidance: + if i == 0: + # Only split values that remain constant throughout the loop once + video_prompt_embeds = connector_prompt_embeds.chunk(2, dim=0)[1] + audio_prompt_embeds = connector_audio_prompt_embeds.chunk(2, dim=0)[1] + prompt_attn_mask = connector_attention_mask.chunk(2, dim=0)[1] + + video_pos_ids = video_coords.chunk(2, dim=0)[0] + audio_pos_ids = audio_coords.chunk(2, dim=0)[0] + + # Split values that vary each denoising loop iteration + timestep = timestep.chunk(2, dim=0)[0] + else: + video_cfg_delta = audio_cfg_delta = 0 + + video_prompt_embeds = connector_prompt_embeds + audio_prompt_embeds = connector_audio_prompt_embeds + prompt_attn_mask = connector_attention_mask + + video_pos_ids = video_coords + audio_pos_ids = audio_coords + + noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler) + noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler) + + if self.do_spatio_temporal_guidance: + noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer( + hidden_states=latents.to(dtype=prompt_embeds.dtype), + audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype), + encoder_hidden_states=video_prompt_embeds, + audio_encoder_hidden_states=audio_prompt_embeds, + timestep=timestep, + sigma=timestep, # Used by LTX-2.3 + encoder_attention_mask=prompt_attn_mask, + audio_encoder_attention_mask=prompt_attn_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + isolate_modalities=False, + # Use STG at given blocks to perturb model + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float() + noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float() + noise_pred_video_uncond_stg = self.convert_velocity_to_x0( + latents, noise_pred_video_uncond_stg, i, self.scheduler + ) + noise_pred_audio_uncond_stg = self.convert_velocity_to_x0( + audio_latents, noise_pred_audio_uncond_stg, i, audio_scheduler + ) + + video_stg_delta = self.stg_scale * (noise_pred_video - noise_pred_video_uncond_stg) + audio_stg_delta = self.audio_stg_scale * (noise_pred_audio - noise_pred_audio_uncond_stg) + else: + video_stg_delta = audio_stg_delta = 0 + + if self.do_modality_isolation_guidance: + noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self.transformer( + hidden_states=latents.to(dtype=prompt_embeds.dtype), + audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype), + encoder_hidden_states=video_prompt_embeds, + audio_encoder_hidden_states=audio_prompt_embeds, + timestep=timestep, + sigma=timestep, # Used by LTX-2.3 + encoder_attention_mask=prompt_attn_mask, + audio_encoder_attention_mask=prompt_attn_mask, + num_frames=latent_num_frames, + height=latent_height, + 
width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + # Turn off A2V and V2A cross attn to isolate video and audio modalities + isolate_modalities=True, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + noise_pred_video_uncond_modality = noise_pred_video_uncond_modality.float() + noise_pred_audio_uncond_modality = noise_pred_audio_uncond_modality.float() + noise_pred_video_uncond_modality = self.convert_velocity_to_x0( + latents, noise_pred_video_uncond_modality, i, self.scheduler + ) + noise_pred_audio_uncond_modality = self.convert_velocity_to_x0( + audio_latents, noise_pred_audio_uncond_modality, i, audio_scheduler + ) + + video_modality_delta = (self.modality_scale - 1) * ( + noise_pred_video - noise_pred_video_uncond_modality + ) + audio_modality_delta = (self.audio_modality_scale - 1) * ( + noise_pred_audio - noise_pred_audio_uncond_modality + ) + else: + video_modality_delta = audio_modality_delta = 0 + + # Now apply all guidance terms + noise_pred_video_g = noise_pred_video + video_cfg_delta + video_stg_delta + video_modality_delta + noise_pred_audio_g = noise_pred_audio + audio_cfg_delta + audio_stg_delta + audio_modality_delta + + # Apply LTX-2.X guidance rescaling + if self.guidance_rescale > 0: + noise_pred_video = rescale_noise_cfg( + noise_pred_video_g, noise_pred_video, guidance_rescale=self.guidance_rescale + ) + else: + noise_pred_video = noise_pred_video_g + + if self.audio_guidance_rescale > 0: + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio_g, noise_pred_audio, guidance_rescale=self.audio_guidance_rescale + ) + else: + noise_pred_audio = noise_pred_audio_g + + # Convert back to velocity for scheduler + noise_pred_video = self.convert_x0_to_velocity(latents, noise_pred_video, i, self.scheduler) + noise_pred_audio = self.convert_x0_to_velocity(audio_latents, noise_pred_audio, i, audio_scheduler) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0] @@ -1209,9 +1458,6 @@ def __call__( self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, ) - latents = self._denormalize_latents( - latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor - ) audio_latents = self._denormalize_audio_latents( audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std @@ -1219,6 +1465,9 @@ def __call__( audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) if output_type == "latent": + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) video = latents audio = audio_latents else: @@ -1241,6 +1490,10 @@ def __call__( ] latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(self.vae.dtype) video = self.vae.decode(latents, timestep, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type="pt").cpu().float().permute(0, 2, 1, 3, 4) diff --git a/videox_fun/pipeline/pipeline_ltx2_i2v.py b/videox_fun/pipeline/pipeline_ltx2_i2v.py index b1083355..d39b1065 100644 --- a/videox_fun/pipeline/pipeline_ltx2_i2v.py +++ 
b/videox_fun/pipeline/pipeline_ltx2_i2v.py @@ -30,10 +30,10 @@ from diffusers.utils.torch_utils import randn_tensor from diffusers.video_processor import VideoProcessor -from ..models import (AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, +from ..models import (AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, Gemma3Processor, Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast, LTX2TextConnectors, - LTX2VideoTransformer3DModel, LTX2Vocoder) + LTX2VideoTransformer3DModel, LTX2Vocoder, LTX2VocoderWithBWE) if is_torch_xla_available(): import torch_xla.core.xla_model as xm @@ -47,6 +47,42 @@ EXAMPLE_DOC_STRING = """ Examples: ```py + >>> import torch + >>> from diffusers import LTX2ImageToVideoPipeline + >>> from diffusers.pipelines.ltx2.export_utils import encode_video + >>> from diffusers.utils import load_image + + >>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) + >>> pipe.enable_model_cpu_offload() + + >>> image = load_image( + ... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png" + ... ) + >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background." + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> frame_rate = 24.0 + >>> video, audio = pipe( + ... image=image, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=768, + ... height=512, + ... num_frames=121, + ... frame_rate=frame_rate, + ... num_inference_steps=40, + ... guidance_scale=4.0, + ... output_type="np", + ... return_dict=False, + ... ) + + >>> encode_video( + ... video[0], + ... fps=frame_rate, + ... audio=audio[0].float().cpu(), + ... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000 + ... output_path="video.mp4", + ... ) ``` """ @@ -194,7 +230,7 @@ class LTX2I2VPipeline(DiffusionPipeline, FromSingleFileMixin): """ model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" - _optional_components = [] + _optional_components = ["processor"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( @@ -206,7 +242,8 @@ def __init__( tokenizer: GemmaTokenizer | GemmaTokenizerFast, connectors: LTX2TextConnectors, transformer: LTX2VideoTransformer3DModel, - vocoder: LTX2Vocoder, + vocoder: LTX2Vocoder | LTX2VocoderWithBWE, + processor: Gemma3Processor | None = None, ): super().__init__() @@ -219,6 +256,7 @@ def __init__( transformer=transformer, vocoder=vocoder, scheduler=scheduler, + processor=processor, ) self.vae_spatial_compression_ratio = ( @@ -253,74 +291,6 @@ def __init__( self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 ) - @staticmethod - # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds - def _pack_text_embeds( - text_hidden_states: torch.Tensor, - sequence_lengths: torch.Tensor, - device: str | torch.device, - padding_side: str = "left", - scale_factor: int = 8, - eps: float = 1e-6, - ) -> torch.Tensor: - """ - Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and - per-layer in a masked fashion (only over non-padded positions). - - Args: - text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): - Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). 
- sequence_lengths (`torch.Tensor of shape `(batch_size,)`): - The number of valid (non-padded) tokens for each batch instance. - device: (`str` or `torch.device`, *optional*): - torch device to place the resulting embeddings on - padding_side: (`str`, *optional*, defaults to `"left"`): - Whether the text tokenizer performs padding on the `"left"` or `"right"`. - scale_factor (`int`, *optional*, defaults to `8`): - Scaling factor to multiply the normalized hidden states by. - eps (`float`, *optional*, defaults to `1e-6`): - A small positive value for numerical stability when performing normalization. - - Returns: - `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: - Normed and flattened text encoder hidden states. - """ - batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape - original_dtype = text_hidden_states.dtype - - # Create padding mask - token_indices = torch.arange(seq_len, device=device).unsqueeze(0) - if padding_side == "right": - # For right padding, valid tokens are from 0 to sequence_length-1 - mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] - elif padding_side == "left": - # For left padding, valid tokens are from (T - sequence_length) to T-1 - start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] - mask = token_indices >= start_indices # [B, T] - else: - raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") - mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] - - # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) - masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) - num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) - masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) - - # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) - x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) - x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) - - # Normalization - normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) - normalized_hidden_states = normalized_hidden_states * scale_factor - - # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) - normalized_hidden_states = normalized_hidden_states.flatten(2) - mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) - normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) - normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) - return normalized_hidden_states - # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds def _get_gemma_prompt_embeds( self, @@ -349,21 +319,39 @@ def _get_gemma_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) - if getattr(self, "tokenizer", None) is not None: - # Gemma expects left padding for chat-style prompts - self.tokenizer.padding_side = "left" - if self.tokenizer.pad_token is None: - self.tokenizer.pad_token = self.tokenizer.eos_token - prompt = [p.strip() for p in prompt] - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) + + # Use processor if available (LTX-2.3), otherwise use tokenizer directly + if 
getattr(self, "processor", None) is not None: + # LTX-2.3: Use processor with chat template + if getattr(self.processor, "tokenizer", None) is not None: + self.processor.tokenizer.padding_side = "left" + if self.processor.tokenizer.pad_token is None: + self.processor.tokenizer.pad_token = self.processor.tokenizer.eos_token + + text_inputs = self.processor( + text=prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + else: + # Legacy: Use tokenizer directly + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "left" + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids prompt_attention_mask = text_inputs.attention_mask text_input_ids = text_input_ids.to(device) @@ -374,16 +362,7 @@ def _get_gemma_prompt_embeds( ) text_encoder_hidden_states = text_encoder_outputs.hidden_states text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) - sequence_lengths = prompt_attention_mask.sum(dim=-1) - - prompt_embeds = self._pack_text_embeds( - text_encoder_hidden_states, - sequence_lengths, - device=device, - padding_side=self.tokenizer.padding_side, - scale_factor=scale_factor, - ) - prompt_embeds = prompt_embeds.to(dtype=dtype) + prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) # Pack to 3D # duplicate text embeddings for each generation per prompt, using mps friendly method _, seq_len, _ = prompt_embeds.shape @@ -482,6 +461,57 @@ def encode_prompt( return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + @torch.no_grad() + def enhance_prompt( + self, + image: PipelineImageInput, + prompt: str, + system_prompt: str, + max_new_tokens: int = 512, + seed: int = 10, + generator: torch.Generator | None = None, + generation_kwargs: dict[str, Any] | None = None, + device: str | torch.device | None = None, + ): + """ + Enhances the supplied `prompt` by generating a new prompt using the current text encoder (default is a + `transformers.Gemma3ForConditionalGeneration` model) from it and a system prompt. + """ + device = device or self._execution_device + if generation_kwargs is None: + # Set to default generation kwargs + generation_kwargs = {"do_sample": True, "temperature": 0.7} + + messages = [ + {"role": "system", "content": system_prompt}, + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": f"User Raw Input Prompt: {prompt}."}, + ], + }, + ] + template = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + model_inputs = self.processor(text=template, images=image, return_tensors="pt").to(device) + self.text_encoder.to(device) + + # `transformers.GenerationMixin.generate` does not support using a `torch.Generator` to control randomness, + # so manually apply a seed for reproducible generation. 
+ if generator is not None: + # Overwrite seed to generator's initial seed + seed = generator.initial_seed() + torch.manual_seed(seed) + generated_sequences = self.text_encoder.generate( + **model_inputs, + max_new_tokens=max_new_tokens, + **generation_kwargs, + ) # tensor of shape [batch_size, seq_len] + + generated_ids = [seq[len(model_inputs.input_ids[i]) :] for i, seq in enumerate(generated_sequences)] + enhanced_prompt = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + return enhanced_prompt + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.check_inputs def check_inputs( self, @@ -493,6 +523,9 @@ def check_inputs( negative_prompt_embeds=None, prompt_attention_mask=None, negative_prompt_attention_mask=None, + spatio_temporal_guidance_blocks=None, + stg_scale=None, + audio_stg_scale=None, ): if height % 32 != 0 or width % 32 != 0: raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") @@ -536,6 +569,12 @@ def check_inputs( f" {negative_prompt_attention_mask.shape}." ) + if ((stg_scale > 0.0) or (audio_stg_scale > 0.0)) and not spatio_temporal_guidance_blocks: + raise ValueError( + "Spatio-Temporal Guidance (STG) is specified but no STG blocks are supplied. Please supply a list of" + "block indices at which to apply STG in `spatio_temporal_guidance_blocks`" + ) + @staticmethod # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: @@ -770,7 +809,6 @@ def prepare_audio_latents( latents = self._create_noised_state(latents, noise_scale, generator) return latents.to(device=device, dtype=dtype) - # TODO: confirm whether this logic is correct latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) @@ -785,6 +823,24 @@ def prepare_audio_latents( latents = self._pack_audio_latents(latents) return latents + def convert_velocity_to_x0( + self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None + ) -> torch.Tensor: + if scheduler is None: + scheduler = self.scheduler + + sample_x0 = sample - denoised_output * scheduler.sigmas[step_idx] + return sample_x0 + + def convert_x0_to_velocity( + self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None + ) -> torch.Tensor: + if scheduler is None: + scheduler = self.scheduler + + sample_v = (sample - denoised_output) / scheduler.sigmas[step_idx] + return sample_v + @property def guidance_scale(self): return self._guidance_scale @@ -793,9 +849,41 @@ def guidance_scale(self): def guidance_rescale(self): return self._guidance_rescale + @property + def stg_scale(self): + return self._stg_scale + + @property + def modality_scale(self): + return self._modality_scale + + @property + def audio_guidance_scale(self): + return self._audio_guidance_scale + + @property + def audio_guidance_rescale(self): + return self._audio_guidance_rescale + + @property + def audio_stg_scale(self): + return self._audio_stg_scale + + @property + def audio_modality_scale(self): + return self._audio_modality_scale + @property def do_classifier_free_guidance(self): - return self._guidance_scale > 1.0 + return (self._guidance_scale > 1.0) or (self._audio_guidance_scale > 1.0) + + @property + def do_spatio_temporal_guidance(self): + return (self._stg_scale > 0.0) or (self._audio_stg_scale > 
0.0) + + @property + def do_modality_isolation_guidance(self): + return (self._modality_scale > 1.0) or (self._audio_modality_scale > 1.0) @property def num_timesteps(self): @@ -828,7 +916,14 @@ def __call__( sigmas: list[float] | None = None, timesteps: list[int] | None = None, guidance_scale: float = 4.0, + stg_scale: float = 0.0, + modality_scale: float = 1.0, guidance_rescale: float = 0.0, + audio_guidance_scale: float | None = None, + audio_stg_scale: float | None = None, + audio_modality_scale: float | None = None, + audio_guidance_rescale: float | None = None, + spatio_temporal_guidance_blocks: list[int] | None = None, noise_scale: float = 0.0, num_videos_per_prompt: int = 1, generator: torch.Generator | list[torch.Generator] | None = None, @@ -840,6 +935,11 @@ def __call__( negative_prompt_attention_mask: torch.Tensor | None = None, decode_timestep: float | list[float] = 0.0, decode_noise_scale: float | list[float] | None = None, + use_cross_timestep: bool = False, + system_prompt: str | None = None, + prompt_max_new_tokens: int = 512, + prompt_enhancement_kwargs: dict[str, Any] | None = None, + prompt_enhancement_seed: int = 10, output_type: str = "pil", return_dict: bool = True, attention_kwargs: dict[str, Any] | None = None, @@ -880,13 +980,47 @@ def __call__( Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to - the text `prompt`, usually at the expense of lower image quality. + the text `prompt`, usually at the expense of lower image quality. Used for the video modality (there is + a separate value `audio_guidance_scale` for the audio modality). + stg_scale (`float`, *optional*, defaults to `0.0`): + Video guidance scale for Spatio-Temporal Guidance (STG), proposed in [Spatiotemporal Skip Guidance for + Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664). STG uses a CFG-like estimate + where we move the sample away from a weaker sample produced by a perturbed version of the denoising model. + Enabling STG will result in an additional denoising model forward pass; the default value of `0.0` + means that STG is disabled. + modality_scale (`float`, *optional*, defaults to `1.0`): + Video guidance scale for LTX-2.X modality isolation guidance, where we move the sample away from a + weaker sample generated by the denoising model with cross-modality (audio-to-video and video-to-audio) + cross attention disabled, using a CFG-like estimate. Enabling modality guidance will result in an + additional denoising model forward pass; the default value of `1.0` means that modality guidance is + disabled. guidance_rescale (`float`, *optional*, defaults to 0.0): Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when - using zero terminal SNR. + using zero terminal SNR. Used for the video modality. + audio_guidance_scale (`float`, *optional*, defaults to `None`): + Audio guidance scale for CFG with respect to the negative prompt. The CFG update rule is the same for + video and audio, but they can use different values for the guidance scale. 
The LTX-2.X authors suggest + that the `audio_guidance_scale` should be higher relative to the video `guidance_scale` (e.g. for + LTX-2.3 they suggest 3.0 for video and 7.0 for audio). If `None`, defaults to the video value + `guidance_scale`. + audio_stg_scale (`float`, *optional*, defaults to `None`): + Audio guidance scale for STG. As with CFG, the STG update rule is otherwise the same for video and + audio. For LTX-2.3, a value of 1.0 is suggested for both video and audio. If `None`, defaults to the + video value `stg_scale`. + audio_modality_scale (`float`, *optional*, defaults to `None`): + Audio guidance scale for LTX-2.X modality isolation guidance. As with CFG, the modality guidance rule + is otherwise the same for video and audio. For LTX-2.3, a value of 3.0 is suggested for both video and + audio. If `None`, defaults to the video value `modality_scale`. + audio_guidance_rescale (`float`, *optional*, defaults to `None`): + A separate guidance rescale factor for the audio modality. If `None`, defaults to the video value + `guidance_rescale`. + spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`): + The zero-indexed transformer block indices at which to apply STG. Must be supplied if STG is used + (`stg_scale` or `audio_stg_scale` is greater than `0`). A value of `[29]` is recommended for LTX-2.0 + and `[28]` is recommended for LTX-2.3. noise_scale (`float`, *optional*, defaults to `0.0`): The interpolation factor between random noise and denoised latents at each timestep. Applying noise to the `latents` and `audio_latents` before continue denoising. @@ -917,6 +1051,24 @@ def __call__( The timestep at which generated video is decoded. decode_noise_scale (`float`, defaults to `None`): The interpolation factor between random noise and denoised latents at the decode timestep. + use_cross_timestep (`bool`, *optional*, defaults to `False`): + Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when + calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior; + `False` is the legacy LTX-2.0 behavior. + system_prompt (`str`, *optional*, defaults to `None`): + Optional system prompt to use for prompt enhancement. The system prompt will be used by the current + text encoder (by default, a `Gemma3ForConditionalGeneration` model) to generate an enhanced prompt from + the original `prompt` to condition generation. If not supplied, prompt enhancement will not be + performed. + prompt_max_new_tokens (`int`, *optional*, defaults to `512`): + The maximum number of new tokens to generate when performing prompt enhancement. + prompt_enhancement_kwargs (`dict[str, Any]`, *optional*, defaults to `None`): + Keyword arguments for `self.text_encoder.generate`. If not supplied, default arguments of + `do_sample=True` and `temperature=0.7` will be used. See + https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate + for more details. + prompt_enhancement_seed (`int`, *optional*, defaults to `10`): + Random seed for any random operations during prompt enhancement. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
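Before the denoising-loop hunks below, it may help to see the guidance arithmetic in one place. The following is a reading aid only, not part of the patch: a minimal, self-contained sketch of how the loop combines CFG, STG, and modality-isolation guidance for a single modality. The standalone function and tensor names are hypothetical; the formulas mirror `convert_velocity_to_x0`, the delta terms, and `convert_x0_to_velocity` from this diff.

```python
import torch


def combine_guidance_x0(
    sample: torch.Tensor,       # current noisy latents x_t
    v_cond: torch.Tensor,       # velocity prediction with the text prompt
    v_uncond: torch.Tensor,     # velocity prediction with the negative prompt
    v_perturbed: torch.Tensor,  # velocity prediction with STG-perturbed blocks
    v_isolated: torch.Tensor,   # velocity prediction with cross-modal attention disabled
    sigma: float,
    guidance_scale: float = 3.0,
    stg_scale: float = 1.0,
    modality_scale: float = 3.0,
) -> torch.Tensor:
    """Illustrative only: the delta formulation used per modality in the loop below."""
    # velocity -> x0, as in `convert_velocity_to_x0`: x0 = x_t - v * sigma
    x0_cond = sample - v_cond * sigma
    x0_uncond = sample - v_uncond * sigma
    x0_perturbed = sample - v_perturbed * sigma
    x0_isolated = sample - v_isolated * sigma

    # each guidance term is a delta added to the conditional x0 estimate
    cfg_delta = (guidance_scale - 1) * (x0_cond - x0_uncond)
    stg_delta = stg_scale * (x0_cond - x0_perturbed)
    modality_delta = (modality_scale - 1) * (x0_cond - x0_isolated)
    x0_guided = x0_cond + cfg_delta + stg_delta + modality_delta

    # x0 -> velocity, as in `convert_x0_to_velocity`, so the scheduler step is unchanged
    return (sample - x0_guided) / sigma
```

In the actual loop, video and audio each use their own scales and scheduler sigmas, the STG and modality-isolation terms each cost one extra transformer forward pass, and `rescale_noise_cfg` is applied to the combined x0 estimate before converting back to velocity for the scheduler step.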
@@ -949,6 +1101,11 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + audio_guidance_scale = audio_guidance_scale or guidance_scale + audio_stg_scale = audio_stg_scale or stg_scale + audio_modality_scale = audio_modality_scale or modality_scale + audio_guidance_rescale = audio_guidance_rescale or guidance_rescale + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt=prompt, @@ -959,10 +1116,21 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + stg_scale=stg_scale, + audio_stg_scale=audio_stg_scale, ) + # Per-modality guidance scales (video, audio) self._guidance_scale = guidance_scale + self._stg_scale = stg_scale + self._modality_scale = modality_scale self._guidance_rescale = guidance_rescale + self._audio_guidance_scale = audio_guidance_scale + self._audio_stg_scale = audio_stg_scale + self._audio_modality_scale = audio_modality_scale + self._audio_guidance_rescale = audio_guidance_rescale + self._attention_kwargs = attention_kwargs self._interrupt = False self._current_timestep = None @@ -978,6 +1146,18 @@ def __call__( device = self._execution_device # 3. Prepare text embeddings + if system_prompt is not None and prompt is not None: + prompt = self.enhance_prompt( + image=image, + prompt=prompt, + system_prompt=system_prompt, + max_new_tokens=prompt_max_new_tokens, + seed=prompt_enhancement_seed, + generator=generator, + generation_kwargs=prompt_enhancement_kwargs, + device=device, + ) + ( prompt_embeds, prompt_attention_mask, @@ -999,9 +1179,11 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) - additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + tokenizer_padding_side = "left" # Padding side for default Gemma3-12B text encoder + if getattr(self, "tokenizer", None) is not None: + tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left") connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( - prompt_embeds, additive_attention_mask, additive_mask=True + prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side ) # 4. Prepare latent variables @@ -1023,7 +1205,6 @@ def __call__( raise ValueError( f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]." ) - video_sequence_length = latent_num_frames * latent_height * latent_width if latents is None: image = self.video_processor.preprocess(image, height=height, width=width) @@ -1087,7 +1268,7 @@ def __call__( # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas mu = calculate_shift( - video_sequence_length, + self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("base_image_seq_len", 1024), self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("base_shift", 0.95), @@ -1116,11 +1297,6 @@ def __call__( self._num_timesteps = len(timesteps) # 6. 
Prepare micro-conditions - rope_interpolation_scale = ( - self.vae_temporal_compression_ratio / frame_rate, - self.vae_spatial_compression_ratio, - self.vae_spatial_compression_ratio, - ) # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop video_coords = self.transformer.rope.prepare_video_coords( latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate @@ -1158,6 +1334,7 @@ def __call__( audio_encoder_hidden_states=connector_audio_prompt_embeds, timestep=video_timestep, audio_timestep=timestep, + sigma=timestep, # Used by LTX-2.3 encoder_attention_mask=connector_attention_mask, audio_encoder_attention_mask=connector_attention_mask, num_frames=latent_num_frames, @@ -1167,7 +1344,10 @@ def __call__( audio_num_frames=audio_num_frames, video_coords=video_coords, audio_coords=audio_coords, - # rope_interpolation_scale=rope_interpolation_scale, + isolate_modalities=False, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, attention_kwargs=attention_kwargs, return_dict=False, ) @@ -1175,24 +1355,152 @@ def __call__( noise_pred_audio = noise_pred_audio.float() if self.do_classifier_free_guidance: - noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) - noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( - noise_pred_video_text - noise_pred_video_uncond + noise_pred_video_uncond_text, noise_pred_video = noise_pred_video.chunk(2) + noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler) + noise_pred_video_uncond_text = self.convert_velocity_to_x0( + latents, noise_pred_video_uncond_text, i, self.scheduler + ) + # Use delta formulation as it works more nicely with multiple guidance terms + video_cfg_delta = (self.guidance_scale - 1) * (noise_pred_video - noise_pred_video_uncond_text) + + noise_pred_audio_uncond_text, noise_pred_audio = noise_pred_audio.chunk(2) + noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler) + noise_pred_audio_uncond_text = self.convert_velocity_to_x0( + audio_latents, noise_pred_audio_uncond_text, i, audio_scheduler + ) + audio_cfg_delta = (self.audio_guidance_scale - 1) * ( + noise_pred_audio - noise_pred_audio_uncond_text + ) + + # Get positive values from merged CFG inputs in case we need to do other DiT forward passes + if self.do_spatio_temporal_guidance or self.do_modality_isolation_guidance: + if i == 0: + # Only split values that remain constant throughout the loop once + video_prompt_embeds = connector_prompt_embeds.chunk(2, dim=0)[1] + audio_prompt_embeds = connector_audio_prompt_embeds.chunk(2, dim=0)[1] + prompt_attn_mask = connector_attention_mask.chunk(2, dim=0)[1] + + video_pos_ids = video_coords.chunk(2, dim=0)[0] + audio_pos_ids = audio_coords.chunk(2, dim=0)[0] + + # Split values that vary each denoising loop iteration + timestep = timestep.chunk(2, dim=0)[0] + video_timestep = video_timestep.chunk(2, dim=0)[0] + else: + video_cfg_delta = audio_cfg_delta = 0 + + video_prompt_embeds = connector_prompt_embeds + audio_prompt_embeds = connector_audio_prompt_embeds + prompt_attn_mask = connector_attention_mask + + video_pos_ids = video_coords + audio_pos_ids = audio_coords + + noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler) + noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler) + + if 
self.do_spatio_temporal_guidance: + noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer( + hidden_states=latents.to(dtype=prompt_embeds.dtype), + audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype), + encoder_hidden_states=video_prompt_embeds, + audio_encoder_hidden_states=audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=timestep, + sigma=timestep, # Used by LTX-2.3 + encoder_attention_mask=prompt_attn_mask, + audio_encoder_attention_mask=prompt_attn_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + isolate_modalities=False, + # Use STG at given blocks to perturb model + spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float() + noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float() + noise_pred_video_uncond_stg = self.convert_velocity_to_x0( + latents, noise_pred_video_uncond_stg, i, self.scheduler + ) + noise_pred_audio_uncond_stg = self.convert_velocity_to_x0( + audio_latents, noise_pred_audio_uncond_stg, i, audio_scheduler + ) + + video_stg_delta = self.stg_scale * (noise_pred_video - noise_pred_video_uncond_stg) + audio_stg_delta = self.audio_stg_scale * (noise_pred_audio - noise_pred_audio_uncond_stg) + else: + video_stg_delta = audio_stg_delta = 0 + + if self.do_modality_isolation_guidance: + noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self.transformer( + hidden_states=latents.to(dtype=prompt_embeds.dtype), + audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype), + encoder_hidden_states=video_prompt_embeds, + audio_encoder_hidden_states=audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=timestep, + sigma=timestep, # Used by LTX-2.3 + encoder_attention_mask=prompt_attn_mask, + audio_encoder_attention_mask=prompt_attn_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_pos_ids, + audio_coords=audio_pos_ids, + # Turn off A2V and V2A cross attn to isolate video and audio modalities + isolate_modalities=True, + spatio_temporal_guidance_blocks=None, + perturbation_mask=None, + use_cross_timestep=use_cross_timestep, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + noise_pred_video_uncond_modality = noise_pred_video_uncond_modality.float() + noise_pred_audio_uncond_modality = noise_pred_audio_uncond_modality.float() + noise_pred_video_uncond_modality = self.convert_velocity_to_x0( + latents, noise_pred_video_uncond_modality, i, self.scheduler ) + noise_pred_audio_uncond_modality = self.convert_velocity_to_x0( + audio_latents, noise_pred_audio_uncond_modality, i, audio_scheduler + ) + + video_modality_delta = (self.modality_scale - 1) * ( + noise_pred_video - noise_pred_video_uncond_modality + ) + audio_modality_delta = (self.audio_modality_scale - 1) * ( + noise_pred_audio - noise_pred_audio_uncond_modality + ) + else: + video_modality_delta = audio_modality_delta = 0 + + # Now apply all guidance terms + noise_pred_video_g = noise_pred_video + video_cfg_delta + video_stg_delta + video_modality_delta + noise_pred_audio_g = noise_pred_audio + audio_cfg_delta + audio_stg_delta + audio_modality_delta - 
noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) - noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( - noise_pred_audio_text - noise_pred_audio_uncond + # Apply LTX-2.X guidance rescaling + if self.guidance_rescale > 0: + noise_pred_video = rescale_noise_cfg( + noise_pred_video_g, noise_pred_video, guidance_rescale=self.guidance_rescale ) + else: + noise_pred_video = noise_pred_video_g - if self.guidance_rescale > 0: - # Based on 3.4. in https://huggingface.co/papers/2305.08891 - noise_pred_video = rescale_noise_cfg( - noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale - ) - noise_pred_audio = rescale_noise_cfg( - noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale - ) + if self.audio_guidance_rescale > 0: + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio_g, noise_pred_audio, guidance_rescale=self.audio_guidance_rescale + ) + else: + noise_pred_audio = noise_pred_audio_g # compute the previous noisy sample x_t -> x_t-1 noise_pred_video = self._unpack_latents( @@ -1212,6 +1520,10 @@ def __call__( self.transformer_temporal_patch_size, ) + # Convert back to velocity for scheduler + noise_pred_video = self.convert_x0_to_velocity(latents, noise_pred_video, i, self.scheduler) + noise_pred_audio = self.convert_x0_to_velocity(audio_latents, noise_pred_audio, i, audio_scheduler) + noise_pred_video = noise_pred_video[:, :, 1:] noise_latents = latents[:, :, 1:] pred_latents = self.scheduler.step(noise_pred_video, t, noise_latents, return_dict=False)[0] @@ -1249,9 +1561,6 @@ def __call__( self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, ) - latents = self._denormalize_latents( - latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor - ) audio_latents = self._denormalize_audio_latents( audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std @@ -1259,6 +1568,9 @@ def __call__( audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) if output_type == "latent": + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) video = latents audio = audio_latents else: @@ -1281,6 +1593,10 @@ def __call__( ] latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(self.vae.dtype) video = self.vae.decode(latents, timestep, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type="pt").cpu().float().permute(0, 2, 1, 3, 4)
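To see the new arguments end to end, here is a hedged usage sketch. The pipeline class, checkpoint id, input image path, and output handling are assumptions (the patch does not show them); the keyword arguments and the suggested LTX-2.3 values come from the docstrings added above.

```python
import torch
from diffusers import LTX2Pipeline  # assumed class name, per the "Copied from ... LTX2Pipeline" references
from diffusers.utils import load_image

# Hypothetical checkpoint id; substitute the real LTX-2.3 audio-video checkpoint.
pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16).to("cuda")

output = pipe(
    image=load_image("first_frame.png"),
    prompt="A woman plays an acoustic guitar by a campfire at night",
    negative_prompt="worst quality, blurry, distorted",
    guidance_scale=3.0,                    # video CFG (LTX-2.3 suggestion)
    audio_guidance_scale=7.0,              # audio CFG, higher than video per the docstring
    stg_scale=1.0,                         # enable STG for video
    audio_stg_scale=1.0,                   # and for audio
    spatio_temporal_guidance_blocks=[28],  # [28] for LTX-2.3, [29] for LTX-2.0
    modality_scale=3.0,                    # modality-isolation guidance (video)
    audio_modality_scale=3.0,              # modality-isolation guidance (audio)
    use_cross_timestep=True,               # newer LTX-2.3 cross-attention timestep behavior
    system_prompt="Expand the prompt into a detailed, cinematic shot description.",  # enables prompt enhancement
    prompt_enhancement_kwargs={"do_sample": True, "temperature": 0.7},
)
# The returned object carries both modalities; attribute names depend on the pipeline's output class.
```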