本项目是基于TongTong313/rectified-flow项目的二次开发,旨在将原本复杂的文件结构和命令行操作简化为单个Jupyter Notebook,实现开箱即用的教学和实验环境。
Rectified Flow是一种新兴的生成模型方法,通过学习从噪声到数据的直线轨迹,实现了高质量的样本生成。本项目提供了完整的1-Rectified Flow和2-Rectified Flow训练与推理流程,特别适合学习和研究生成模型的相关原理。
- 📓 单文件实现:所有功能集成在一个Jupyter Notebook中,便于理解和修改
- 🚀 开箱即用:无需复杂的命令行操作,直接线性运行notebook的各单元格即可
- 📚 教学友好:代码注释详细,配合中文说明,适合学习使用
- 🔧 简化配置:去除复杂的工程优化,专注于算法核心实现
- 💻 低硬件要求:MNIST数据集小巧易用,支持消费级显卡训练
torch 2.7.1
torchdiffeq 0.2.5
torchmetrics 1.8.2
torchvision 0.22.1
numpy 2.0.1
matplotlib 3.10.0
opencv-python 4.12.0.88
tqdm 4.67.1
scipy 1.15.3
# 创建并激活conda环境
conda create -n env_name python=3.10
conda activate env_name
# 安装PyTorch(根据你的CUDA版本调整)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
# 安装其他依赖
pip install torchdiffeq torchmetrics matplotlib opencv-python tqdm scipyRectified-flow-notebook/
├── rectified_flow.ipynb # 主要的notebook文件
├── checkpoints/ # 训练好的模型权重
│ ├── 1rf/ # 1-Rectified Flow模型
│ └── 2rf/ # 2-Rectified Flow模型
├── data/ # MNIST数据集
├── results/ # 生成的结果图片
│ ├── 1rf/ # 1-RF生成结果
│ ├── 2rf/ # 2-RF生成结果
│ └── comparison.png # 对比图
└── README.md # 本文件
打开rectified_flow.ipynb,使用env_name中的解释器作为内核,按顺序执行所有单元格即可完成训练和推理。
- 参数设置:配置训练超参数和模型参数
- 数据加载:自动下载和加载MNIST数据集
- 模型定义:MiniUNet网络结构
- 1-Rectified Flow训练:基础Rectified Flow模型训练
- 2-Rectified Flow训练:Reflow过程,进一步优化生成轨迹
- 推理和评估:生成样本并计算IS/FID指标
- 结果可视化:展示生成效果
本项目使用torchdiffeq库进行ODE求解,这是PyTorch生态中专门用于神经微分方程的库。
- 使用欧拉法进行数值求解
- 固定步数采样(1-RF: 50步,2-RF: 5步)
开发者可以根据需要修改ODE求解器的配置:
from torchdiffeq import odeint
# 当前实现(欧拉法)
t_span = torch.linspace(0, 1, steps).to(device)
trajectory = odeint(rf.vector_field, x_0, t_span, method='euler')
x_t = trajectory[-1]
# 可选择高阶自适应步数求解器
t = torch.tensor([0.0, 1.0]).to(device)
trajectory = odeint(rf.vector_field, x_0, t, method='dopri5') # Runge-Kutta 4(5)
trajectory = odeint(rf.vector_field, x_0, t, method='dopri8') # 8阶Runge-Kutta
trajectory = odeintiffeq(rf.vector_field, x_0, t, method='bosh3') # Bosh 3阶
x_t = trajectory[-1]- 欧拉法:速度快但精度较低,适合快速原型
- 高阶自适应步数方法:精度高但计算量大,适合高质量生成
理论上,通过Reflow过程,2-Rectified Flow能够在更少的步数内达到与1-Rectified Flow相当的生成质量,显著提升推理效率。
但由于本人对reflow过程准备的配对数据量似乎较小,且没有精细调优超参数,故2-Rectified Flow在5步下的欧拉法生成结果质量一般。开发者可自行进一步调整超参数并测试性能。
- 从标准正态分布采样噪声
$$z_0 \sim \mathcal{N}(0, I)$$ - 从真实数据采样
$$z_1 \sim p_{data}(x)$$ - 线性插值:
$$z_t = (1-t) \cdot z_0 + t \cdot z_1$$ - 学习向量场:
$$v_\theta(z_t, t) = z_1 - z_0$$
- 使用训练好的1-RF模型生成配对数据
- 重新训练模型,学习更直的轨迹
- 实现更少步数的推理
- 本项目基于TongTong313/rectified-flow进行二次开发,感谢原作者童佬提供的详细代码和视频教程
- 核心算法参考论文:
本项目采用 MIT 许可证,详情如下:
MIT License
Copyright (c) 2025 Rectified Flow Notebook Contributors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
欢迎提交Issue和Pull Request来改进本项目。
注意:本项目仅供学习和研究使用,如需用于生产环境,请根据实际需求进行优化。
