diff --git a/library/train_util.py b/library/train_util.py index 5214cc192..11abdce70 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3552,6 +3552,39 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: # default=None, # help="enable perlin noise and set the octaves / perlin noiseを有効にしてoctavesをこの値に設定する", # ) + parser.add_argument( + "--jit_loss_type", + type=str, + default=None, + choices=['v', 'x0', 'eps'], + help="Loss space to use according to JIT" + ) + parser.add_argument( + "--jit_pred_type", + type=str, + default=None, + choices=['v', 'x0', 'eps'], + help="Model prediction space to use according to JIT" + ) + parser.add_argument( + "--jit_loss_weights", + type=float, + nargs=3, + metavar=('x0_mult', 'v_mult', 'eps_mult'), + default=None, + help="Perform weighted average across loss types" + ) + parser.add_argument( + "--jit_t_eps", + type=float, + default=1e-2, + help="t_eps to avoid division errors when converting loss types, ignored if implicit scale is used" + ) + parser.add_argument( + "--jit_implicit_scale", + action="store_true", + help="Calculate loss in a scaled space where possible to avoid division instability" + ) parser.add_argument( "--multires_noise_discount", type=float, @@ -5514,6 +5547,157 @@ def _scipy_assignment(cost: torch.Tensor): return cost, (row, col) +def apply_jit_pred_implicit(args, noise_scheduler, model_output, latents, noise, zt, timesteps): + # here, we attempt to convert loss/pred types without involving division + # MSE(X * v_target, X * v_pred) ~= MSE(v_target, v_pred) + # downside/quirk: it changes the loss weighting which may be undesireable + + t_eps = args.jit_t_eps + timestep_max = noise_scheduler.config.num_train_timesteps - 1 if args.max_timestep is None else args.max_timestep - 1 + t = timesteps.view(-1, 1, 1, 1) / timestep_max + + # sd3: t * noise + (1 - t) * latents + # jit: t * latents + (1 - t) * noise + # we'll need to flip the t + t = 1 - t + + if 
args.jit_pred_type == 'x0': + # stable as is + x0_target = latents + x0_loss_space = model_output + + # sd3: target = noise - latents + # jit: target = latents - noise + # + # we'll need to flip the target to keep the formulas compatible with sd3 schedule + # + # v_target = latents - noise + # v_loss_space = (model_output - zt) / (1 - t) + # + # implicity: let's multiply both sides by (1 - t) + # + # (1 - t) * v_target = (1 - t) * (latents - noise) + # (1 - t) * v_loss_space = model_output - zt + # + # consider flow eq: zt = t * latents + [(1 - t) * noise] -> [(1 - t) * noise] = zt - t * latents + # + # consider: (1 - t) * (latents - noise) + # = (1 - t) * latents - [(1 - t) * noise] + # + # substitute: + # = (1 - t) * latents - (zt - t * latents) + # = latents - zt + # + v_target = latents - zt + v_loss_space = model_output - zt + + # similarly, for eps: + # + # eps_target = noise + # eps_loss_space = (zt - t * model_output) / (1 - t) + # + # (1 - t) * eps_target = (1 - t) * noise = zt - t * latents + # (1 - t) * eps_loss_space = (1 - t) * (zt - t * model_output) / (1 - t) + # + eps_target = zt - t * latents + eps_loss_space = zt - t * model_output + + elif args.jit_pred_type == 'v': + raise NotImplementedError + + elif args.jit_pred_type == 'eps': + raise NotImplementedError + + + if args.jit_loss_weights is not None: + # we can concatenate the targets and loss spaces and let the loss function handle that + loss_space = torch.cat((x0_loss_space, v_loss_space, eps_loss_space)) + target = torch.cat((x0_target, v_target, eps_target)) + + elif args.jit_loss_type == 'x0': + target = x0_target + loss_space = x0_loss_space + + elif args.jit_loss_type == 'v': + target = v_target + loss_space = v_loss_space + + elif args.jit_loss_type == 'eps': + target = eps_target + loss_space = eps_loss_space + + return target, loss_space + + +def apply_jit_pred(args, noise_scheduler, model_output, latents, noise, zt, timesteps): + if args.jit_implicit_scale: + return 
apply_jit_pred_implicit(args, noise_scheduler, model_output, latents, noise, zt, timesteps) + + t_eps = args.jit_t_eps + timestep_max = noise_scheduler.config.num_train_timesteps - 1 if args.max_timestep is None else args.max_timestep - 1 + t = timesteps.view(-1, 1, 1, 1) / timestep_max + + # sd3: t * noise + (1 - t) * latents + # jit: t * latents + (1 - t) * noise + # we'll need to flip the t + t = 1 - t + + x0_target = latents + # sd3: target = noise - latents + # jit: target = latents - noise + # + # we'll need to flip the target and model output to keep the formulas + # as they are in the paper compatible with sd3 schedule + v_target = latents - noise + eps_target = noise + + + if args.jit_pred_type == 'x0': + x0_loss_space = model_output + v_loss_space = (model_output - zt) / (1 - t).clamp_min(t_eps) + eps_loss_space = (zt - t * model_output) / (1 - t).clamp_min(t_eps) + + elif args.jit_pred_type == 'v': + # see the comment above, jit's v target is flipped + model_output = -model_output + + x0_loss_space = (1 - t) * model_output + zt + v_loss_space = model_output + eps_loss_space = zt - t * model_output + + elif args.jit_pred_type == 'eps': + x0_loss_space = (zt - (1 - t) * model_output) / t.clamp_min(t_eps) + v_loss_space = (zt - model_output) / t.clamp_min(t_eps) + eps_loss_space = model_output + + + if args.jit_loss_weights is not None: + # we can concatenate the targets and loss spaces and let the loss function handle that + loss_space = torch.cat((x0_loss_space, v_loss_space, eps_loss_space)) + target = torch.cat((x0_target, v_target, eps_target)) + + elif args.jit_loss_type == 'x0': + target = x0_target + loss_space = x0_loss_space + + elif args.jit_loss_type == 'v': + target = v_target + loss_space = v_loss_space + + elif args.jit_loss_type == 'eps': + target = eps_target + loss_space = eps_loss_space + + return target, loss_space + +def apply_jit_weighting(args, loss): + weights = torch.tensor(args.jit_loss_weights, device=loss.device) + n = 
weights.shape[0] # 3 for x0, v, eps + b, c, h, w = loss.shape + b //= n + + return (loss.view(n, b, c, h, w) * weights.view(n, 1, 1, 1, 1)).sum(dim=0) / weights.sum() + def get_noise_noisy_latents_and_timesteps( args, noise_scheduler, @@ -5574,7 +5758,7 @@ def get_noise_noisy_latents_and_timesteps( t_ref = sigmas sigmas = ratios * t_ref / (1 + (ratios - 1) * t_ref) - timesteps = torch.clamp((sigmas * timestep_max).long(), 0, timestep_max) + timesteps = torch.clamp((sigmas * timestep_max).round().long(), 1, timestep_max) _, huber_c = get_timesteps_and_huber_c( args, 0, @@ -5584,6 +5768,7 @@ def get_noise_noisy_latents_and_timesteps( latents.device, timesteps_override=timesteps, ) + sigmas = timesteps / timestep_max else: min_timestep = 0 if args.min_timestep is None else args.min_timestep max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep diff --git a/sdxl_train.py b/sdxl_train.py index 7c3fd525d..2505d3ab9 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -930,7 +930,10 @@ def optimizer_hook(parameter: torch.Tensor): noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) if args.flow_model: - target = noise - latents + if args.jit_pred_type is not None and args.jit_loss_type is not None: + target, noise_pred = train_util.apply_jit_pred(args, noise_scheduler, noise_pred, latents, noise, noisy_latents, timesteps) + else: + target = noise - latents elif args.v_parameterization: target = noise_scheduler.get_velocity(latents, noise, timesteps) else: @@ -942,10 +945,15 @@ def optimizer_hook(parameter: torch.Tensor): or args.v_pred_like_loss or args.debiased_estimation_loss or args.masked_loss + or args.jit_loss_weights ): loss = train_util.conditional_loss( noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) + + if args.jit_loss_weights is not None: + loss = train_util.apply_jit_weighting(args, loss) + if args.contrastive_flow_matching and 
latents.size(0) > 1: negative_latents = latents.roll(1, 0) negative_noise = noise.roll(1, 0)