From c988e0ddab275159bd8863717aa26dad741b9c34 Mon Sep 17 00:00:00 2001 From: adsfssd <240785495+adsfssd@users.noreply.github.com> Date: Fri, 9 Jan 2026 18:23:23 +0000 Subject: [PATCH 01/10] do not sample at t=0 (no noise) for flow, recalculate sigmas from timesteps --- library/train_util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 5214cc192..5b094c88d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5574,7 +5574,7 @@ def get_noise_noisy_latents_and_timesteps( t_ref = sigmas sigmas = ratios * t_ref / (1 + (ratios - 1) * t_ref) - timesteps = torch.clamp((sigmas * timestep_max).long(), 0, timestep_max) + timesteps = torch.clamp((sigmas * timestep_max).long(), 1, timestep_max) _, huber_c = get_timesteps_and_huber_c( args, 0, @@ -5584,6 +5584,7 @@ def get_noise_noisy_latents_and_timesteps( latents.device, timesteps_override=timesteps, ) + sigmas = timesteps / timestep_max else: min_timestep = 0 if args.min_timestep is None else args.min_timestep max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep From a61e42460ec82eec1445ea069b31bc2e378237b2 Mon Sep 17 00:00:00 2001 From: adsfssd <240785495+adsfssd@users.noreply.github.com> Date: Fri, 9 Jan 2026 18:50:15 +0000 Subject: [PATCH 02/10] convert timesteps to long --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 5b094c88d..059b2559c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5574,7 +5574,7 @@ def get_noise_noisy_latents_and_timesteps( t_ref = sigmas sigmas = ratios * t_ref / (1 + (ratios - 1) * t_ref) - timesteps = torch.clamp((sigmas * timestep_max).long(), 1, timestep_max) + timesteps = torch.clamp((sigmas * timestep_max).round().long(), 1, timestep_max) _, huber_c = get_timesteps_and_huber_c( args, 0, From 3ca0c6ed0cf120a52d20080e46e9676ea6fe77f5 Mon Sep 17 00:00:00 2001 From: adsfssd <240785495+adsfssd@users.noreply.github.com> Date: Fri, 9 Jan 2026 19:46:06 +0000 Subject: [PATCH 03/10] add apply_jit_pred --- library/train_util.py | 54 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/library/train_util.py b/library/train_util.py index 059b2559c..d34129de8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5514,6 +5514,60 @@ def _scipy_assignment(cost: torch.Tensor): return cost, (row, col) +def apply_jit_pred(args, model_output, latents, noise, zt, timesteps): + t_eps = args.jit_t_eps + timestep_max = noise_scheduler.config.num_train_timesteps - 1 if args.max_timestep is None else args.max_timestep - 1 + t = timesteps.view(-1, 1, 1, 1) / timestep_max + + if args.jit_loss_type == 'x0': + target = latents + elif args.jit_loss_type == 'v': + # sd3: target = noise - latents + # jit: target = latents - noise + # + # we'll need to flip the target and model output to keep the formulas + # as they are in the paper compatible with sd3 schedule + target = latents - noise + elif args.jit_loss_type == 'eps': + target = noise + + if args.jit_pred_type == 'v': + # see the comment above + model_output = -model_output + + # sd3: t * eps + (1 - t) * latents + # jit: t * latents + (1 - t) * eps + # we'll need to flip the t + t = 1 - t + + if args.jit_pred_type == 'x0': + x0_loss_space = model_output + v_loss_space = (model_output - zt) / (1 - t).clamp_min(t_eps) + eps_loss_space = (zt - t * model_output) / (1 - t).clamp_min(t_eps) + + elif args.jit_pred_type == 'v': + x0_loss_space = (1 - t) * model_output + zt + v_loss_space = model_output + eps_loss_space = zt - t * model_output + + elif args.jit_pred_type == 'eps': + x0_loss_space = (zt - (1 - t) * model_output) / t.clamp_min(t_eps) + v_loss_space = (zt - model_output) / t.clamp_min(t_eps) + eps_loss_space = model_output + + if args.jit_loss_weights is not None: + x0_w, v_w, eps_w = args.jit_loss_weights + loss_space = (x0_w * x0_loss_space + v_w * v_loss_space + eps_w * eps_loss_space) / sum(args.jit_loss_weights) + elif args.jit_loss_type == 'x0': + loss_space = x0_loss_space + elif args.jit_loss_type == 'v': + loss_space = v_loss_space + elif args.jit_loss_type == 'eps': + loss_space = eps_loss_space + + return target, loss_space + + def get_noise_noisy_latents_and_timesteps( args, noise_scheduler, From ef4e6ce010e690bf9f773c2659fb95ec200f2884 Mon Sep 17 00:00:00 2001 From: adsfssd <240785495+adsfssd@users.noreply.github.com> Date: Fri, 9 Jan 2026 19:59:55 +0000 Subject: [PATCH 04/10] add jit command line arguments to train_util, use it in sdxl_train.py --- library/train_util.py | 28 ++++++++++++++++++++++++++++ sdxl_train.py | 33 ++++++++++++++++++++++++++++++++- 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index d34129de8..c62203191 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3552,6 +3552,34 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: # default=None, # help="enable perlin noise and set the octaves / perlin noiseを有効にしてoctavesをこの値に設定する", # ) + parser.add_argument( + "--jit_loss_type", + type=str, + default=None, + choices=['v', 'x0', 'eps'], + help="Loss space to use according to JIT" + ) + parser.add_argument( + "--jit_pred_type", + type=str, + default=None, + choices=['v', 'x0', 'eps'], + help="Model prediction space to use according to JIT" + ) + parser.add_argument( + "--jit_loss_weights", + type=float, + nargs=3, + metavar=('x0_mult', 'v_mult', 'eps_mult'), + default=None, + help="Perform weighted average across loss types" + ) + parser.add_argument( + "--jit_t_eps", + type=float, + default=1e-2, + help="t_eps to avoid division errors when converting loss types" + ) parser.add_argument( "--multires_noise_discount", type=float, diff --git a/sdxl_train.py b/sdxl_train.py index 7c3fd525d..d8f3015c4 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -930,7 +930,10 @@ def optimizer_hook(parameter: torch.Tensor): noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) if args.flow_model: - target = noise - latents + if args.jit_pred_type is not None and args.jit_loss_type is not None: + target, noise_pred = train_util.apply_jit_pred(args, noise_pred, latents, noise, noisy_latents, timesteps) + else: + target = noise - latents elif args.v_parameterization: target = noise_scheduler.get_velocity(latents, noise, timesteps) else: @@ -1296,6 +1299,34 @@ def setup_parser() -> argparse.ArgumentParser: default="auto", help="Select the device to move text encoder to. Only effective if text encoder is not trained.", ) + parser.add_argument( + "--jit_loss_type", + type=str, + default=None, + choices=['v', 'x0', 'eps'], + help="Loss space to use according to JIT" + ) + parser.add_argument( + "--jit_pred_type", + type=str, + default=None, + choices=['v', 'x0', 'eps'], + help="Model prediction space to use according to JIT" + ) + parser.add_argument( + "--jit_loss_weights", + type=float, + nargs=3, + metavar=('x0_mult', 'v_mult', 'eps_mult'), + default=None, + help="Sum all three loss types according to weighting" + ) + parser.add_argument( + "--jit_t_eps", + type=float, + default=1e-2, + help="t_eps to avoid division errors when converting loss types" + ) return parser From b0976358bed76431422eceafe80f892c38422479 Mon Sep 17 00:00:00 2001 From: adsfssd <240785495+adsfssd@users.noreply.github.com> Date: Fri, 9 Jan 2026 20:00:17 +0000 Subject: [PATCH 05/10] a --- sdxl_train.py | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/sdxl_train.py b/sdxl_train.py index d8f3015c4..25948b3e0 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -1299,34 +1299,6 @@ def setup_parser() -> argparse.ArgumentParser: default="auto", help="Select the device to move text encoder to. Only effective if text encoder is not trained.", ) - parser.add_argument( - "--jit_loss_type", - type=str, - default=None, - choices=['v', 'x0', 'eps'], - help="Loss space to use according to JIT" - ) - parser.add_argument( - "--jit_pred_type", - type=str, - default=None, - choices=['v', 'x0', 'eps'], - help="Model prediction space to use according to JIT" - ) - parser.add_argument( - "--jit_loss_weights", - type=float, - nargs=3, - metavar=('x0_mult', 'v_mult', 'eps_mult'), - default=None, - help="Sum all three loss types according to weighting" - ) - parser.add_argument( - "--jit_t_eps", - type=float, - default=1e-2, - help="t_eps to avoid division errors when converting loss types" - ) return parser From 443e568baa026060aaf0701ab1c2ef0cd682d367 Mon Sep 17 00:00:00 2001 From: adsfssd <240785495+adsfssd@users.noreply.github.com> Date: Fri, 9 Jan 2026 20:19:59 +0000 Subject: [PATCH 06/10] fix noise scheduler thing --- library/train_util.py | 2 +- sdxl_train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index c62203191..2658be236 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5542,7 +5542,7 @@ def _scipy_assignment(cost: torch.Tensor): return cost, (row, col) -def apply_jit_pred(args, model_output, latents, noise, zt, timesteps): +def apply_jit_pred(args, noise_scheduler, model_output, latents, noise, zt, timesteps): t_eps = args.jit_t_eps timestep_max = noise_scheduler.config.num_train_timesteps - 1 if args.max_timestep is None else args.max_timestep - 1 t = timesteps.view(-1, 1, 1, 1) / timestep_max diff --git a/sdxl_train.py b/sdxl_train.py index 25948b3e0..69f0a5ab8 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -931,7 +931,7 @@ def optimizer_hook(parameter: torch.Tensor): if args.flow_model: if args.jit_pred_type is not None and args.jit_loss_type is not None: - target, noise_pred = train_util.apply_jit_pred(args, noise_pred, latents, noise, noisy_latents, timesteps) + target, noise_pred = train_util.apply_jit_pred(args, noise_scheduler, noise_pred, latents, noise, noisy_latents, timesteps) else: target = noise - latents elif args.v_parameterization: From 3da74093c31668f3ff39532f8d833505acf44cde Mon Sep 17 00:00:00 2001 From: adsfssd <240785495+adsfssd@users.noreply.github.com> Date: Sat, 10 Jan 2026 23:09:16 +0300 Subject: [PATCH 07/10] refactor & make sure all loss spaces are completely separate --- library/train_util.py | 44 +++++++++++++++++++++++++++++++------------ sdxl_train.py | 5 +++++ 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 2658be236..12cf5fb4e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5542,38 +5542,42 @@ def _scipy_assignment(cost: torch.Tensor): return cost, (row, col) -def apply_jit_pred(args, noise_scheduler, model_output, latents, noise, zt, timesteps): - t_eps = args.jit_t_eps - timestep_max = noise_scheduler.config.num_train_timesteps - 1 if args.max_timestep is None else args.max_timestep - 1 - t = timesteps.view(-1, 1, 1, 1) / timestep_max - - if args.jit_loss_type == 'x0': +def jit_get_target(target_type, latents, noise): + if target_type == 'x0': target = latents - elif args.jit_loss_type == 'v': + elif target_type == 'v': # sd3: target = noise - latents # jit: target = latents - noise # # we'll need to flip the target and model output to keep the formulas # as they are in the paper compatible with sd3 schedule target = latents - noise - elif args.jit_loss_type == 'eps': + elif target_type == 'eps': target = noise - if args.jit_pred_type == 'v': - # see the comment above - model_output = -model_output + return target + + +def apply_jit_pred(args, noise_scheduler, model_output, latents, noise, zt, timesteps): + t_eps = args.jit_t_eps + timestep_max = noise_scheduler.config.num_train_timesteps - 1 if args.max_timestep is None else args.max_timestep - 1 + t = timesteps.view(-1, 1, 1, 1) / timestep_max # sd3: t * eps + (1 - t) * latents # jit: t * latents + (1 - t) * eps # we'll need to flip the t t = 1 - t + if args.jit_pred_type == 'x0': x0_loss_space = model_output v_loss_space = (model_output - zt) / (1 - t).clamp_min(t_eps) eps_loss_space = (zt - t * model_output) / (1 - t).clamp_min(t_eps) elif args.jit_pred_type == 'v': + # see the comment for jit_get_target, jit's v target is flipped + model_output = -model_output + x0_loss_space = (1 - t) * model_output + zt v_loss_space = model_output eps_loss_space = zt - t * model_output @@ -5583,18 +5587,34 @@ def apply_jit_pred(args, noise_scheduler, model_output, latents, noise, zt, time v_loss_space = (zt - model_output) / t.clamp_min(t_eps) eps_loss_space = model_output + if args.jit_loss_weights is not None: + # we can concatenate the targets and loss spaces and let the loss function handle that x0_w, v_w, eps_w = args.jit_loss_weights - loss_space = (x0_w * x0_loss_space + v_w * v_loss_space + eps_w * eps_loss_space) / sum(args.jit_loss_weights) + loss_space = torch.cat((x0_loss_space, v_loss_space, eps_loss_space)) + target = torch.cat([jit_get_target(t, latents, noise) for t in ['x0', 'v', 'eps']]) + elif args.jit_loss_type == 'x0': + target = jit_get_target('x0', latents, noise) loss_space = x0_loss_space + elif args.jit_loss_type == 'v': + target = jit_get_target('v', latents, noise) loss_space = v_loss_space + elif args.jit_loss_type == 'eps': + target = jit_get_target('eps', latents, noise) loss_space = eps_loss_space return target, loss_space +def apply_jit_weighting(args, loss): + weights = torch.tensor(args.jit_loss_weights, device=loss.device) + n = weights.shape[0] # 3 for x0, v, eps + b, c, h, w = loss.shape + b //= n + + return (loss.view(n, b, c, h, w) * weights.view(n, 1, 1, 1, 1)).sum(dim=0) / weights.sum() def get_noise_noisy_latents_and_timesteps( args, diff --git a/sdxl_train.py b/sdxl_train.py index 69f0a5ab8..2505d3ab9 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -945,10 +945,15 @@ def optimizer_hook(parameter: torch.Tensor): or args.v_pred_like_loss or args.debiased_estimation_loss or args.masked_loss + or args.jit_loss_weights ): loss = train_util.conditional_loss( noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) + + if args.jit_loss_weights is not None: + train_util.apply_jit_weighting(args, loss) + if args.contrastive_flow_matching and latents.size(0) > 1: negative_latents = latents.roll(1, 0) negative_noise = noise.roll(1, 0) From d0ac5b719f3f63cebcaafd67375954c0a1b99a51 Mon Sep 17 00:00:00 2001 From: adsfssd <240785495+adsfssd@users.noreply.github.com> Date: Sun, 11 Jan 2026 03:59:29 +0300 Subject: [PATCH 08/10] refactor --- library/train_util.py | 35 ++++++++++++++--------------------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 12cf5fb4e..270deb8e6 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5542,22 +5542,6 @@ def _scipy_assignment(cost: torch.Tensor): return cost, (row, col) -def jit_get_target(target_type, latents, noise): - if target_type == 'x0': - target = latents - elif target_type == 'v': - # sd3: target = noise - latents - # jit: target = latents - noise - # - # we'll need to flip the target and model output to keep the formulas - # as they are in the paper compatible with sd3 schedule - target = latents - noise - elif target_type == 'eps': - target = noise - - return target - - def apply_jit_pred(args, noise_scheduler, model_output, latents, noise, zt, timesteps): t_eps = args.jit_t_eps timestep_max = noise_scheduler.config.num_train_timesteps - 1 if args.max_timestep is None else args.max_timestep - 1 @@ -5568,6 +5552,15 @@ def apply_jit_pred(args, noise_scheduler, model_output, latents, noise, zt, time # we'll need to flip the t t = 1 - t + x0_target = latents + # sd3: target = noise - latents + # jit: target = latents - noise + # + # we'll need to flip the target and model output to keep the formulas + # as they are in the paper compatible with sd3 schedule + v_target = latents - noise + eps_target = noise + if args.jit_pred_type == 'x0': x0_loss_space = model_output @@ -5575,7 +5568,7 @@ def apply_jit_pred(args, noise_scheduler, model_output, latents, noise, zt, time eps_loss_space = (zt - t * model_output) / (1 - t).clamp_min(t_eps) elif args.jit_pred_type == 'v': - # see the comment for jit_get_target, jit's v target is flipped + # see the comment above, jit's v target is flipped model_output = -model_output x0_loss_space = (1 - t) * model_output + zt @@ -5592,18 +5585,18 @@ def apply_jit_pred(args, noise_scheduler, model_output, latents, noise, zt, time # we can concatenate the targets and loss spaces and let the loss function handle that x0_w, v_w, eps_w = args.jit_loss_weights loss_space = torch.cat((x0_loss_space, v_loss_space, eps_loss_space)) - target = torch.cat([jit_get_target(t, latents, noise) for t in ['x0', 'v', 'eps']]) + target = torch.cat((x0_target, v_target, eps_target)) elif args.jit_loss_type == 'x0': - target = jit_get_target('x0', latents, noise) + target = x0_target loss_space = x0_loss_space elif args.jit_loss_type == 'v': - target = jit_get_target('v', latents, noise) + target = v_target loss_space = v_loss_space elif args.jit_loss_type == 'eps': - target = jit_get_target('eps', latents, noise) + target = eps_target loss_space = eps_loss_space return target, loss_space From 5c72936e9ce29e9c5c695ce5c03c4d914a6ee12a Mon Sep 17 00:00:00 2001 From: adsfssd <240785495+adsfssd@users.noreply.github.com> Date: Sun, 11 Jan 2026 06:56:16 +0300 Subject: [PATCH 09/10] implement implicit scale for x0 pred --- library/train_util.py | 97 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 94 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 270deb8e6..078b26685 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3578,7 +3578,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--jit_t_eps", type=float, default=1e-2, - help="t_eps to avoid division errors when converting loss types" + help="t_eps to avoid division errors when converting loss types, ignored if implicit scale is used" + ) + parser.add_argument( + "--jit_implicit_scale", + action="store_true", + help="Calculate loss in a scaled space where possible to avoid division instability" ) parser.add_argument( "--multires_noise_discount", @@ -5542,13 +5547,99 @@ def _scipy_assignment(cost: torch.Tensor): return cost, (row, col) +def apply_jit_pred_implicit(args, noise_scheduler, model_output, latents, noise, zt, timesteps): + # here, we attempt to convert loss/pred types without involving division + # MSE(X * v_target, X * v_pred) ~= MSE(v_target, v_pred) + # downside/quirk: it changes the loss weighting which may be undesireable + + t_eps = args.jit_t_eps + timestep_max = noise_scheduler.config.num_train_timesteps - 1 if args.max_timestep is None else args.max_timestep - 1 + t = timesteps.view(-1, 1, 1, 1) / timestep_max + + # sd3: t * noise + (1 - t) * latents + # jit: t * latents + (1 - t) * noise + # we'll need to flip the t + t = 1 - t + + if args.jit_pred_type == 'x0': + # stable as is + x0_target = latents + x0_loss_space = model_output + + # sd3: target = noise - latents + # jit: target = latents - noise + # + # we'll need to flip the target to keep the formulas compatible with sd3 schedule + # + # v_target = latents - noise + # v_loss_space = (model_output - zt) / (1 - t) + # + # implicity: let's multiply both sides by (1 - t) + # + # (1 - t) * v_target = (1 - t) * (latents - noise) + # (1 - t) * v_loss_space = model_output - zt + # + # consider flow eq: zt = t * latents + [(1 - t) * noise] -> [(1 - t) * noise] = zt - t * latents + # + # consider: (1 - t) * (latents - noise) + # = (1 - t) * latents - [(1 - t) * noise] + # + # substitute: + # = (1 - t) * latents - (zt - t * latents) + # = latents - zt + # + v_target = latents - zt + v_loss_space = model_output - zt + + # similarly, for eps: + # + # eps_target = noise + # eps_loss_space = (zt - t * model_output) / (1 - t) + # + # (1 - t) * eps_target = (1 - t) * noise = zt - t * latents + # (1 - t) * eps_loss_space = (1 - t) * (zt - t * model_output) / (1 - t) + # + eps_target = zt - t * latents + eps_loss_space = zt - t * model_output + + elif args.jit_pred_type == 'v': + raise NotImplementedError + + elif args.jit_pred_type == 'eps': + raise NotImplementedError + + + if args.jit_loss_weights is not None: + # we can concatenate the targets and loss spaces and let the loss function handle that + x0_w, v_w, eps_w = args.jit_loss_weights + loss_space = torch.cat((x0_loss_space, v_loss_space, eps_loss_space)) + target = torch.cat((x0_target, v_target, eps_target)) + + elif args.jit_loss_type == 'x0': + target = x0_target + loss_space = x0_loss_space + + elif args.jit_loss_type == 'v': + target = v_target + loss_space = v_loss_space + + elif args.jit_loss_type == 'eps': + target = eps_target + loss_space = eps_loss_space + + return target, loss_space + + def apply_jit_pred(args, noise_scheduler, model_output, latents, noise, zt, timesteps): + if args.jit_implicit_scale: + return apply_jit_pred_implicit(args, noise_scheduler, model_output, latents, noise, zt, timesteps) + t_eps = args.jit_t_eps timestep_max = noise_scheduler.config.num_train_timesteps - 1 if args.max_timestep is None else args.max_timestep - 1 t = timesteps.view(-1, 1, 1, 1) / timestep_max - # sd3: t * eps + (1 - t) * latents - # jit: t * latents + (1 - t) * eps + # sd3: t * noise + (1 - t) * latents + # jit: t * latents + (1 - t) * noise # we'll need to flip the t t = 1 - t From 51dc18d8eab6f80b89a6efae1e5d74558a528ab3 Mon Sep 17 00:00:00 2001 From: adsfssd <240785495+adsfssd@users.noreply.github.com> Date: Sun, 11 Jan 2026 07:05:51 +0300 Subject: [PATCH 10/10] cleanup --- library/train_util.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 078b26685..11abdce70 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5611,7 +5611,6 @@ def apply_jit_pred_implicit(args, noise_scheduler, model_output, latents, noise, if args.jit_loss_weights is not None: # we can concatenate the targets and loss spaces and let the loss function handle that - x0_w, v_w, eps_w = args.jit_loss_weights loss_space = torch.cat((x0_loss_space, v_loss_space, eps_loss_space)) target = torch.cat((x0_target, v_target, eps_target)) @@ -5674,7 +5673,6 @@ def apply_jit_pred(args, noise_scheduler, model_output, latents, noise, zt, time if args.jit_loss_weights is not None: # we can concatenate the targets and loss spaces and let the loss function handle that - x0_w, v_w, eps_w = args.jit_loss_weights loss_space = torch.cat((x0_loss_space, v_loss_space, eps_loss_space)) target = torch.cat((x0_target, v_target, eps_target))