Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 186 additions & 1 deletion library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3552,6 +3552,39 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
# default=None,
# help="enable perlin noise and set the octaves / perlin noiseを有効にしてoctavesをこの値に設定する",
# )
parser.add_argument(
"--jit_loss_type",
type=str,
default=None,
choices=['v', 'x0', 'eps'],
help="Loss space to use according to JIT"
)
parser.add_argument(
"--jit_pred_type",
type=str,
default=None,
choices=['v', 'x0', 'eps'],
help="Model prediction space to use according to JIT"
)
parser.add_argument(
"--jit_loss_weights",
type=float,
nargs=3,
metavar=('x0_mult', 'v_mult', 'eps_mult'),
default=None,
help="Perform weighted average across loss types"
)
parser.add_argument(
"--jit_t_eps",
type=float,
default=1e-2,
help="t_eps to avoid division errors when converting loss types, ignored if implicit scale is used"
)
parser.add_argument(
"--jit_implicit_scale",
action="store_true",
help="Calculate loss in a scaled space where possible to avoid division instability"
)
parser.add_argument(
"--multires_noise_discount",
type=float,
Expand Down Expand Up @@ -5514,6 +5547,157 @@ def _scipy_assignment(cost: torch.Tensor):
return cost, (row, col)


def apply_jit_pred_implicit(args, noise_scheduler, model_output, latents, noise, zt, timesteps):
    """Build a (target, loss_space) pair for JIT-style training without division.

    Instead of dividing by ``t`` or ``(1 - t)`` to convert between prediction
    spaces, both the target and the prediction are multiplied by the same
    factor: MSE(X * a, X * b) is minimized at the same point as MSE(a, b), so
    no clamping epsilon is needed (``args.jit_t_eps`` is intentionally unused
    on this path).  Downside/quirk: the implicit scale changes the
    per-timestep loss weighting, which may be undesirable.

    Args:
        args: parsed CLI namespace; reads ``max_timestep``, ``jit_pred_type``,
            ``jit_loss_type`` and ``jit_loss_weights``.
        noise_scheduler: scheduler exposing ``config.num_train_timesteps``.
        model_output: raw model prediction, in ``args.jit_pred_type`` space.
        latents: clean latents (x0).
        noise: pure noise sample; unused on this path, kept so the signature
            matches ``apply_jit_pred``.
        zt: noisy latents; JIT convention zt = t * latents + (1 - t) * noise.
        timesteps: integer timesteps, shape (B,).

    Returns:
        Tuple ``(target, loss_space)``.  When ``args.jit_loss_weights`` is
        set, both tensors are the three spaces (x0, v, eps) concatenated
        along the batch dimension for ``apply_jit_weighting`` to consume.

    Raises:
        NotImplementedError: for pred types without an implicit formulation.
        ValueError: for an unrecognized pred/loss configuration.
    """
    timestep_max = noise_scheduler.config.num_train_timesteps - 1 if args.max_timestep is None else args.max_timestep - 1
    t = timesteps.view(-1, 1, 1, 1) / timestep_max

    # sd3: t * noise + (1 - t) * latents
    # jit: t * latents + (1 - t) * noise
    # we'll need to flip the t
    t = 1 - t

    if args.jit_pred_type == 'x0':
        # stable as is
        x0_target = latents
        x0_loss_space = model_output

        # sd3: target = noise - latents
        # jit: target = latents - noise
        #
        # we'll need to flip the target to keep the formulas compatible with sd3 schedule
        #
        # v_target = latents - noise
        # v_loss_space = (model_output - zt) / (1 - t)
        #
        # implicitly: let's multiply both sides by (1 - t)
        #
        # (1 - t) * v_target = (1 - t) * (latents - noise)
        # (1 - t) * v_loss_space = model_output - zt
        #
        # consider flow eq: zt = t * latents + [(1 - t) * noise] -> [(1 - t) * noise] = zt - t * latents
        #
        # consider: (1 - t) * (latents - noise)
        # = (1 - t) * latents - [(1 - t) * noise]
        #
        # substitute:
        # = (1 - t) * latents - (zt - t * latents)
        # = latents - zt
        #
        v_target = latents - zt
        v_loss_space = model_output - zt

        # similarly, for eps:
        #
        # eps_target = noise
        # eps_loss_space = (zt - t * model_output) / (1 - t)
        #
        # (1 - t) * eps_target = (1 - t) * noise = zt - t * latents
        # (1 - t) * eps_loss_space = (1 - t) * (zt - t * model_output) / (1 - t)
        #
        eps_target = zt - t * latents
        eps_loss_space = zt - t * model_output

    elif args.jit_pred_type == 'v':
        raise NotImplementedError("implicit scaling for jit_pred_type='v' is not implemented")

    elif args.jit_pred_type == 'eps':
        raise NotImplementedError("implicit scaling for jit_pred_type='eps' is not implemented")

    else:
        raise ValueError(f"unsupported jit_pred_type: {args.jit_pred_type!r}")

    if args.jit_loss_weights is not None:
        # we can concatenate the targets and loss spaces and let the loss function handle that
        loss_space = torch.cat((x0_loss_space, v_loss_space, eps_loss_space))
        target = torch.cat((x0_target, v_target, eps_target))

    elif args.jit_loss_type == 'x0':
        target = x0_target
        loss_space = x0_loss_space

    elif args.jit_loss_type == 'v':
        target = v_target
        loss_space = v_loss_space

    elif args.jit_loss_type == 'eps':
        target = eps_target
        loss_space = eps_loss_space

    else:
        # previously this fell through and crashed with UnboundLocalError at return
        raise ValueError(f"unsupported jit_loss_type: {args.jit_loss_type!r}")

    return target, loss_space


def apply_jit_pred(args, noise_scheduler, model_output, latents, noise, zt, timesteps):
    """Build a (target, loss_space) pair for JIT-style training.

    Converts the model output from its prediction space
    (``args.jit_pred_type``) into the requested loss space
    (``args.jit_loss_type``), clamping divisors by ``args.jit_t_eps`` to
    avoid division blow-ups near the endpoints.  When
    ``args.jit_implicit_scale`` is set, delegates to
    ``apply_jit_pred_implicit`` which avoids division entirely.

    Args:
        args: parsed CLI namespace; reads ``jit_implicit_scale``,
            ``jit_t_eps``, ``max_timestep``, ``jit_pred_type``,
            ``jit_loss_type`` and ``jit_loss_weights``.
        noise_scheduler: scheduler exposing ``config.num_train_timesteps``.
        model_output: raw model prediction, in ``args.jit_pred_type`` space.
        latents: clean latents (x0).
        noise: pure noise sample.
        zt: noisy latents.
        timesteps: integer timesteps, shape (B,).

    Returns:
        Tuple ``(target, loss_space)``.  When ``args.jit_loss_weights`` is
        set, both tensors are the three spaces (x0, v, eps) concatenated
        along the batch dimension for ``apply_jit_weighting`` to consume.

    Raises:
        ValueError: for an unrecognized pred/loss configuration.
    """
    if args.jit_implicit_scale:
        return apply_jit_pred_implicit(args, noise_scheduler, model_output, latents, noise, zt, timesteps)

    t_eps = args.jit_t_eps
    timestep_max = noise_scheduler.config.num_train_timesteps - 1 if args.max_timestep is None else args.max_timestep - 1
    t = timesteps.view(-1, 1, 1, 1) / timestep_max

    # sd3: t * noise + (1 - t) * latents
    # jit: t * latents + (1 - t) * noise
    # we'll need to flip the t
    t = 1 - t

    x0_target = latents
    # sd3: target = noise - latents
    # jit: target = latents - noise
    #
    # we'll need to flip the target and model output to keep the formulas
    # as they are in the paper compatible with sd3 schedule
    v_target = latents - noise
    eps_target = noise

    if args.jit_pred_type == 'x0':
        x0_loss_space = model_output
        v_loss_space = (model_output - zt) / (1 - t).clamp_min(t_eps)
        eps_loss_space = (zt - t * model_output) / (1 - t).clamp_min(t_eps)

    elif args.jit_pred_type == 'v':
        # see the comment above, jit's v target is flipped
        model_output = -model_output

        x0_loss_space = (1 - t) * model_output + zt
        v_loss_space = model_output
        eps_loss_space = zt - t * model_output

    elif args.jit_pred_type == 'eps':
        x0_loss_space = (zt - (1 - t) * model_output) / t.clamp_min(t_eps)
        v_loss_space = (zt - model_output) / t.clamp_min(t_eps)
        eps_loss_space = model_output

    else:
        raise ValueError(f"unsupported jit_pred_type: {args.jit_pred_type!r}")

    if args.jit_loss_weights is not None:
        # we can concatenate the targets and loss spaces and let the loss function handle that
        loss_space = torch.cat((x0_loss_space, v_loss_space, eps_loss_space))
        target = torch.cat((x0_target, v_target, eps_target))

    elif args.jit_loss_type == 'x0':
        target = x0_target
        loss_space = x0_loss_space

    elif args.jit_loss_type == 'v':
        target = v_target
        loss_space = v_loss_space

    elif args.jit_loss_type == 'eps':
        target = eps_target
        loss_space = eps_loss_space

    else:
        # previously this fell through and crashed with UnboundLocalError at return
        raise ValueError(f"unsupported jit_loss_type: {args.jit_loss_type!r}")

    return target, loss_space

def apply_jit_weighting(args, loss):
    """Collapse a concatenated per-space loss into one weighted-average loss.

    ``loss`` holds the x0, v and eps losses stacked along the batch
    dimension (shape ``(3 * B, C, H, W)``, as produced when
    ``args.jit_loss_weights`` is set); each chunk is scaled by its weight
    and the normalized sum of shape ``(B, C, H, W)`` is returned.
    """
    weights = torch.tensor(args.jit_loss_weights, device=loss.device)

    # split back into the per-space chunks (x0, v, eps) along dim 0
    per_space = loss.chunk(weights.shape[0], dim=0)

    weighted = sum(w * chunk for w, chunk in zip(weights, per_space))
    return weighted / weights.sum()

def get_noise_noisy_latents_and_timesteps(
args,
noise_scheduler,
Expand Down Expand Up @@ -5574,7 +5758,7 @@ def get_noise_noisy_latents_and_timesteps(
t_ref = sigmas
sigmas = ratios * t_ref / (1 + (ratios - 1) * t_ref)

timesteps = torch.clamp((sigmas * timestep_max).long(), 0, timestep_max)
timesteps = torch.clamp((sigmas * timestep_max).round().long(), 1, timestep_max)
_, huber_c = get_timesteps_and_huber_c(
args,
0,
Expand All @@ -5584,6 +5768,7 @@ def get_noise_noisy_latents_and_timesteps(
latents.device,
timesteps_override=timesteps,
)
sigmas = timesteps / timestep_max
else:
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
Expand Down
10 changes: 9 additions & 1 deletion sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,7 +930,10 @@ def optimizer_hook(parameter: torch.Tensor):
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)

if args.flow_model:
target = noise - latents
if args.jit_pred_type is not None and args.jit_loss_type is not None:
target, noise_pred = train_util.apply_jit_pred(args, noise_scheduler, noise_pred, latents, noise, noisy_latents, timesteps)
else:
target = noise - latents
elif args.v_parameterization:
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
Expand All @@ -942,10 +945,15 @@ def optimizer_hook(parameter: torch.Tensor):
or args.v_pred_like_loss
or args.debiased_estimation_loss
or args.masked_loss
or args.jit_loss_weights
):
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)

if args.jit_loss_weights is not None:
train_util.apply_jit_weighting(args, loss)

if args.contrastive_flow_matching and latents.size(0) > 1:
negative_latents = latents.roll(1, 0)
negative_noise = noise.roll(1, 0)
Expand Down