
Bug regarding gradient_step_per_epoch #85

@Xilluill

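Hi, your open-source framework has been a great help to my research. While using NFT recently, I ran into something that puzzles me, and it may be a bug.
During hyperparameter initialization: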

sample_num_per_iteration = world_size * self.per_device_batch_size
step = (sample_num_per_iteration * self.gradient_step_per_epoch) // math.gcd(self.group_size, sample_num_per_iteration)
new_m = (self.unique_sample_num_per_epoch + step - 1) // step * step
if new_m != self.unique_sample_num_per_epoch:
    logger.warning(
        f"Adjusted `unique_sample_num` from {self.unique_sample_num_per_epoch} to {new_m} "
        f"to make sure `unique_sample_num`*`group_size` is multiple of `batch_size`*`num_replicas`*`gradient_step_per_epoch` for even distribution."
    )
    self.unique_sample_num_per_epoch = new_m
self.num_batches_per_epoch = (self.unique_sample_num_per_epoch * self.group_size) // sample_num_per_iteration
self.gradient_accumulation_steps = max(1, self.num_batches_per_epoch // self.gradient_step_per_epoch)

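My understanding is that num_batches_per_epoch is the computed number of batches each GPU processes per epoch, where one batch may contain multiple images. gradient_step_per_epoch is the number of optimizer updates we want per epoch. The default is 1, meaning a single update after all batches have been processed.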
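To make the arithmetic concrete, here is a minimal sketch that mirrors the initialization snippet above as a standalone function. The function name `resolve_epoch_hparams` and the example values are hypothetical, chosen only for illustration; they are not taken from the framework's configs.

```python
import math


def resolve_epoch_hparams(world_size, per_device_batch_size, group_size,
                          unique_sample_num_per_epoch, gradient_step_per_epoch):
    """Standalone copy of the arithmetic in the snippet above (hypothetical helper)."""
    sample_num_per_iteration = world_size * per_device_batch_size
    step = (sample_num_per_iteration * gradient_step_per_epoch) // math.gcd(
        group_size, sample_num_per_iteration
    )
    # Round the unique sample count up to the next multiple of `step`.
    new_m = (unique_sample_num_per_epoch + step - 1) // step * step
    num_batches_per_epoch = (new_m * group_size) // sample_num_per_iteration
    gradient_accumulation_steps = max(1, num_batches_per_epoch // gradient_step_per_epoch)
    return new_m, num_batches_per_epoch, gradient_accumulation_steps


# Assumed example: 4 GPUs, per-device batch size 2, group_size 16,
# 6 unique samples per epoch, 1 gradient step per epoch.
print(resolve_epoch_hparams(4, 2, 16, 6, 1))  # (6, 12, 12)
```

With these assumed values, each GPU processes 12 batches per epoch and gradient_accumulation_steps is also 12, so one update per epoch is expected if the accumulation context is entered once per batch.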
However, the NFT training loop has one more parameter: each batch is also re-noised and evaluated num_train_timesteps times.
with self.autocast():
    for batch in tqdm(
        sample_batches,
        total=len(sample_batches),
        desc=f'Epoch {self.epoch} Training',
        position=0,
        disable=not self.show_progress_bar,
    ):
        # Retrieve pre-computed data
        batch_size = batch['all_latents'].shape[0]
        clean_latents = batch['all_latents'][:, -1]
        all_timesteps = batch['_all_timesteps']
        all_random_noise = batch['_all_random_noise']
        old_v_pred_list = batch['_old_v_pred_list']
        # Iterate through timesteps
        for t_idx in tqdm(
            range(self.num_train_timesteps),
            desc=f'Epoch {self.epoch} Timestep',
            position=1,
            leave=False,
            disable=not self.show_progress_bar,
        ):
            with self.accelerator.accumulate(*self.adapter.trainable_components):
                # 1. Prepare inputs
                t_flat = all_timesteps[t_idx]  # (B,)

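The accelerate gradient-accumulation context manager sits inside this nested loop, not at the batch level. So with num_train_timesteps=2, for example, the epoch performs 2 updates, which no longer matches the hyperparameter setting. Moreover, those 2 updates split the originally planned batches into a first half and a second half, each updated separately.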
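The counting behind this concern can be sketched in plain Python, without accelerate. This is a simplified model that assumes an optimizer update happens on every gradient_accumulation_steps-th entry into the accumulate context and ignores any end-of-dataloader synchronization accelerate may apply. The function name `simulate_updates` and the example values are hypothetical.

```python
def simulate_updates(num_batches, num_train_timesteps, gradient_accumulation_steps):
    """Group (batch_idx, t_idx) pairs by the optimizer update they contribute to.

    Simplified model (an assumption, not accelerate's actual code): one update
    is triggered on every `gradient_accumulation_steps`-th context entry.
    """
    updates, pending, calls = [], [], 0
    for batch_idx in range(num_batches):
        for t_idx in range(num_train_timesteps):
            # One entry into the accumulate context per (batch, timestep) pair.
            pending.append((batch_idx, t_idx))
            calls += 1
            if calls % gradient_accumulation_steps == 0:
                updates.append(pending)
                pending = []
    return updates


# Assumed example: 12 batches per epoch, gradient_accumulation_steps = 12,
# num_train_timesteps = 2.
updates = simulate_updates(12, 2, 12)
print(len(updates))                             # 2
print(sorted({b for b, _ in updates[0]}))       # [0, 1, 2, 3, 4, 5]
print(sorted({b for b, _ in updates[1]}))       # [6, 7, 8, 9, 10, 11]
```

Under this model, the first update only sees batches 0 to 5 and the second only sees batches 6 to 11. In the same model, scaling the accumulation steps by num_train_timesteps (24 here) would give a single update per epoch; that is only an illustration of the counting, not a claim about how the framework should be changed.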
Is my understanding wrong, is this really a bug, or is it an intentional design choice?
Looking forward to your reply!
