From 64beca70f90fe66c76d1203f01255aae9b680219 Mon Sep 17 00:00:00 2001 From: Eric Sung Date: Sun, 25 Jan 2026 01:34:13 -0800 Subject: [PATCH] Fix: shape-safe TeaCache residuals + auto-reset + updated LTXBaseModel for latest ComfyUI --- nodes.py | 460 ++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 302 insertions(+), 158 deletions(-) diff --git a/nodes.py b/nodes.py index 04bff4a..bbc2d20 100644 --- a/nodes.py +++ b/nodes.py @@ -9,7 +9,7 @@ from unittest.mock import patch from comfy.ldm.flux.layers import timestep_embedding, apply_mod -from comfy.ldm.lightricks.model import precompute_freqs_cis +from comfy.ldm.lightricks.model import LTXBaseModel from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords from comfy.ldm.wan.model import sinusoidal_embedding_1d @@ -33,12 +33,66 @@ "wan2.1_i2v_720p_14B_ret_mode": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02], } + def poly1d(coefficients, x): result = torch.zeros_like(x) for i, coeff in enumerate(coefficients): result += coeff * (x ** (len(coefficients) - 1 - i)) return result + +# ----------------------------- +# TeaCache safety helpers +# ----------------------------- +def _tc_shape_sig(t: torch.Tensor): + # Cheap signature to detect shape/dtype changes across runs + return (tuple(t.shape), str(t.dtype)) + + +def _tc_safe_add_residual(target: torch.Tensor, residual: Optional[torch.Tensor]): + """ + Adds cached residual only if it exists and matches target shape. + Returns True if applied, False if skipped. 
+ """ + if residual is None: + return False + if tuple(residual.shape) != tuple(target.shape): + return False + target += residual.to(device=target.device, dtype=target.dtype) + return True + + +def _tc_ensure_state(self_obj, keys=(0, 1)): + if not hasattr(self_obj, 'teacache_state'): + self_obj.teacache_state = {k: { + 'should_calc': True, + 'accumulated_rel_l1_distance': 0, + 'previous_modulated_input': None, + 'previous_residual': None, + 'shape_sig': None, # NEW + } for k in keys} + else: + for k in keys: + if k not in self_obj.teacache_state: + self_obj.teacache_state[k] = { + 'should_calc': True, + 'accumulated_rel_l1_distance': 0, + 'previous_modulated_input': None, + 'previous_residual': None, + 'shape_sig': None, + } + else: + self_obj.teacache_state[k].setdefault('shape_sig', None) + + +def _tc_reset_cache_entry(cache: dict): + cache['should_calc'] = True + cache['accumulated_rel_l1_distance'] = 0 + cache['previous_modulated_input'] = None + cache['previous_residual'] = None + cache['shape_sig'] = None + + def teacache_flux_forward( self, img: Tensor, @@ -48,7 +102,7 @@ def teacache_flux_forward( timesteps: Tensor, y: Tensor, guidance: Tensor = None, - control = None, + control=None, transformer_options={}, attn_mask: Tensor = None, ) -> Tensor: @@ -60,7 +114,7 @@ def teacache_flux_forward( if y is None: y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) - + if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -71,7 +125,7 @@ def teacache_flux_forward( if guidance is not None: vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) - vec = vec + self.vector_in(y[:,:self.params.vec_in_dim]) + vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) txt = self.txt_in(txt) if img_ids is not None: @@ -93,13 +147,17 @@ def teacache_flux_forward( self.accumulated_rel_l1_distance = 0 else: try: - self.accumulated_rel_l1_distance += 
poly1d(coefficients, ((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean())).abs() + self.accumulated_rel_l1_distance += poly1d( + coefficients, + ((modulated_inp - self.previous_modulated_input).abs().mean() + / self.previous_modulated_input.abs().mean()) + ).abs() if self.accumulated_rel_l1_distance < rel_l1_thresh: should_calc = False else: should_calc = True self.accumulated_rel_l1_distance = 0 - except: + except Exception: should_calc = True self.accumulated_rel_l1_distance = 0 @@ -109,7 +167,8 @@ def teacache_flux_forward( should_calc = True if not should_calc: - img += self.previous_residual.to(img.device) + # Safety: only add if shapes match + _tc_safe_add_residual(img, getattr(self, 'previous_residual', None)) else: ori_img = img.to(cache_device) for i, block in enumerate(self.double_blocks): @@ -117,28 +176,28 @@ def teacache_flux_forward( def block_wrap(args): out = {} out["img"], out["txt"] = block(img=args["img"], - txt=args["txt"], - vec=args["vec"], - pe=args["pe"], - attn_mask=args.get("attn_mask")) + txt=args["txt"], + vec=args["vec"], + pe=args["pe"], + attn_mask=args.get("attn_mask")) return out out = blocks_replace[("double_block", i)]({"img": img, - "txt": txt, - "vec": vec, - "pe": pe, - "attn_mask": attn_mask}, - {"original_block": block_wrap}) + "txt": txt, + "vec": vec, + "pe": pe, + "attn_mask": attn_mask}, + {"original_block": block_wrap}) txt = out["txt"] img = out["img"] else: img, txt = block(img=img, - txt=txt, - vec=vec, - pe=pe, - attn_mask=attn_mask) + txt=txt, + vec=vec, + pe=pe, + attn_mask=attn_mask) - if control is not None: # Controlnet + if control is not None: # Controlnet control_i = control.get("input") if i < len(control_i): add = control_i[i] @@ -151,7 +210,7 @@ def block_wrap(args): # Will calculate influence of all pulid nodes at once for _, node_data in self.pulid_data.items(): if torch.any((node_data['sigma_start'] >= timesteps) - & (timesteps >= 
node_data['sigma_end'])): + & (timesteps >= node_data['sigma_end'])): img = img + node_data['weight'] * self.pulid_ca[ca_idx](node_data['embedding'], img) ca_idx += 1 @@ -165,26 +224,26 @@ def block_wrap(args): def block_wrap(args): out = {} out["img"] = block(args["img"], - vec=args["vec"], - pe=args["pe"], - attn_mask=args.get("attn_mask")) + vec=args["vec"], + pe=args["pe"], + attn_mask=args.get("attn_mask")) return out out = blocks_replace[("single_block", i)]({"img": img, - "vec": vec, - "pe": pe, - "attn_mask": attn_mask}, - {"original_block": block_wrap}) + "vec": vec, + "pe": pe, + "attn_mask": attn_mask}, + {"original_block": block_wrap}) img = out["img"] else: img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) - if control is not None: # Controlnet + if control is not None: # Controlnet control_o = control.get("output") if i < len(control_o): add = control_o[i] if add is not None: - img[:, txt.shape[1] :, ...] += add + img[:, txt.shape[1]:, ...] += add # PuLID attention if getattr(self, "pulid_data", {}): @@ -193,18 +252,20 @@ def block_wrap(args): # Will calculate influence of all nodes at once for _, node_data in self.pulid_data.items(): if torch.any((node_data['sigma_start'] >= timesteps) - & (timesteps >= node_data['sigma_end'])): - real_img = real_img + node_data['weight'] * self.pulid_ca[ca_idx](node_data['embedding'], real_img) + & (timesteps >= node_data['sigma_end'])): + real_img = real_img + node_data['weight'] * self.pulid_ca[ca_idx](node_data['embedding'], + real_img) ca_idx += 1 img = torch.cat((txt, real_img), 1) - img = img[:, txt.shape[1] :, ...] + img = img[:, txt.shape[1]:, ...] 
self.previous_residual = img.to(cache_device) - ori_img img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) - + return img + def teacache_hidream_forward( self, x: torch.Tensor, @@ -213,8 +274,8 @@ def teacache_hidream_forward( context: Optional[torch.Tensor] = None, encoder_hidden_states_llama3=None, image_cond=None, - control = None, - transformer_options = {}, + control=None, + transformer_options={}, ) -> torch.Tensor: rel_l1_thresh = transformer_options.get("rel_l1_thresh") coefficients = transformer_options.get("coefficients") @@ -252,7 +313,6 @@ def teacache_hidream_forward( img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size) hidden_states = self.x_embedder(hidden_states) - # T5_encoder_hidden_states = encoder_hidden_states[0] encoder_hidden_states = encoder_hidden_states_llama3.movedim(1, 0) encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers] @@ -278,29 +338,34 @@ def teacache_hidream_forward( # enable teacache modulated_inp = timesteps.to(cache_device) if "full" in model_type else hidden_states.to(cache_device) - if not hasattr(self, 'teacache_state'): - self.teacache_state = { - 0: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None}, - 1: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None} - } + _tc_ensure_state(self, keys=(0, 1)) def update_cache_state(cache, modulated_inp): if cache['previous_modulated_input'] is not None: try: - cache['accumulated_rel_l1_distance'] += poly1d(coefficients, ((modulated_inp-cache['previous_modulated_input']).abs().mean() / cache['previous_modulated_input'].abs().mean())) + cache['accumulated_rel_l1_distance'] += poly1d( + coefficients, + ((modulated_inp - cache['previous_modulated_input']).abs().mean() + / cache['previous_modulated_input'].abs().mean()) + ) if cache['accumulated_rel_l1_distance'] < rel_l1_thresh: cache['should_calc'] = 
False else: cache['should_calc'] = True cache['accumulated_rel_l1_distance'] = 0 - except: + except Exception: cache['should_calc'] = True cache['accumulated_rel_l1_distance'] = 0 cache['previous_modulated_input'] = modulated_inp - + b = int(len(hidden_states) / len(cond_or_uncond)) for i, k in enumerate(cond_or_uncond): + # Reset cache entry if chunk shape changed + chunk_sig = _tc_shape_sig(hidden_states[i*b:(i+1)*b]) + if self.teacache_state[k].get('shape_sig', None) is not None and self.teacache_state[k]['shape_sig'] != chunk_sig: + _tc_reset_cache_entry(self.teacache_state[k]) + self.teacache_state[k]['shape_sig'] = chunk_sig update_cache_state(self.teacache_state[k], modulated_inp[i*b:(i+1)*b]) if enable_teacache: @@ -311,9 +376,16 @@ def update_cache_state(cache, modulated_inp): should_calc = True if not should_calc: + applied_all = True for i, k in enumerate(cond_or_uncond): - hidden_states[i*b:(i+1)*b] += self.teacache_state[k]['previous_residual'].to(hidden_states.device) - else: + applied_all = applied_all and _tc_safe_add_residual( + hidden_states[i*b:(i+1)*b], + self.teacache_state[k].get('previous_residual', None) + ) + if not applied_all: + should_calc = True + + if should_calc: # 2. 
Blocks ori_hidden_states = hidden_states.to(cache_device) block_id = 0 @@ -321,13 +393,14 @@ def update_cache_state(cache, modulated_inp): initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1] for bid, block in enumerate(self.double_stream_blocks): cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] - cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1) + cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], + dim=1) hidden_states, initial_encoder_hidden_states = block( - image_tokens = hidden_states, - image_tokens_masks = image_tokens_masks, - text_tokens = cur_encoder_hidden_states, - adaln_input = adaln_input, - rope = rope, + image_tokens=hidden_states, + image_tokens_masks=image_tokens_masks, + text_tokens=cur_encoder_hidden_states, + adaln_input=adaln_input, + rope=rope, ) initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len] block_id += 1 @@ -356,13 +429,15 @@ def update_cache_state(cache, modulated_inp): block_id += 1 hidden_states = hidden_states[:, :image_tokens_seq_len, ...] 
+ resid_all = (hidden_states.to(cache_device) - ori_hidden_states) for i, k in enumerate(cond_or_uncond): - self.teacache_state[k]['previous_residual'] = (hidden_states.to(cache_device) - ori_hidden_states)[i*b:(i+1)*b] + self.teacache_state[k]['previous_residual'] = resid_all[i*b:(i+1)*b] output = self.final_layer(hidden_states, adaln_input) output = self.unpatchify(output, img_sizes) return -output[:, :, :h, :w] + def teacache_lumina_forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs): rel_l1_thresh = transformer_options.get("rel_l1_thresh") coefficients = transformer_options.get("coefficients") @@ -375,11 +450,11 @@ def teacache_lumina_forward(self, x, timesteps, context, num_tokens, attention_m cap_mask = attention_mask bs, c, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) - + t = self.t_embedder(t, dtype=x.dtype) # (N, D) adaln_input = t - cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. 
redundant compute + cap_feats = self.cap_embedder(cap_feats) # (N, L, D) x_is_tensor = isinstance(x, torch.Tensor) x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens) @@ -387,22 +462,22 @@ def teacache_lumina_forward(self, x, timesteps, context, num_tokens, attention_m # enable teacache modulated_inp = t.to(cache_device) - if not hasattr(self, 'teacache_state'): - self.teacache_state = { - 0: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None}, - 1: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None} - } + _tc_ensure_state(self, keys=(0, 1)) def update_cache_state(cache, modulated_inp): if cache['previous_modulated_input'] is not None: try: - cache['accumulated_rel_l1_distance'] += poly1d(coefficients, ((modulated_inp-cache['previous_modulated_input']).abs().mean() / cache['previous_modulated_input'].abs().mean())) + cache['accumulated_rel_l1_distance'] += poly1d( + coefficients, + ((modulated_inp - cache['previous_modulated_input']).abs().mean() + / cache['previous_modulated_input'].abs().mean()) + ) if cache['accumulated_rel_l1_distance'] < rel_l1_thresh: cache['should_calc'] = False else: cache['should_calc'] = True cache['accumulated_rel_l1_distance'] = 0 - except: + except Exception: cache['should_calc'] = True cache['accumulated_rel_l1_distance'] = 0 cache['previous_modulated_input'] = modulated_inp @@ -410,6 +485,10 @@ def update_cache_state(cache, modulated_inp): b = int(len(x) / len(cond_or_uncond)) for i, k in enumerate(cond_or_uncond): + chunk_sig = _tc_shape_sig(x[i*b:(i+1)*b]) + if self.teacache_state[k].get('shape_sig', None) is not None and self.teacache_state[k]['shape_sig'] != chunk_sig: + _tc_reset_cache_entry(self.teacache_state[k]) + self.teacache_state[k]['shape_sig'] = chunk_sig update_cache_state(self.teacache_state[k], modulated_inp[i*b:(i+1)*b]) if 
enable_teacache: @@ -420,20 +499,29 @@ def update_cache_state(cache, modulated_inp): should_calc = True if not should_calc: + applied_all = True for i, k in enumerate(cond_or_uncond): - x[i*b:(i+1)*b] += self.teacache_state[k]['previous_residual'].to(x.device) - else: + applied_all = applied_all and _tc_safe_add_residual( + x[i*b:(i+1)*b], + self.teacache_state[k].get('previous_residual', None) + ) + if not applied_all: + should_calc = True + + if should_calc: ori_x = x.to(cache_device) # 2. Blocks for layer in self.layers: x = layer(x, mask, freqs_cis, adaln_input) + resid_all = (x.to(cache_device) - ori_x) for i, k in enumerate(cond_or_uncond): - self.teacache_state[k]['previous_residual'] = (x.to(cache_device) - ori_x)[i*b:(i+1)*b] - + self.teacache_state[k]['previous_residual'] = resid_all[i*b:(i+1)*b] + x = self.final_layer(x, adaln_input) - x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w] + x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w] + + return -x - return -x def teacache_hunyuanvideo_forward( self, @@ -513,13 +601,17 @@ def teacache_hunyuanvideo_forward( self.accumulated_rel_l1_distance = 0 else: try: - self.accumulated_rel_l1_distance += poly1d(coefficients, ((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean())) + self.accumulated_rel_l1_distance += poly1d( + coefficients, + ((modulated_inp - self.previous_modulated_input).abs().mean() + / self.previous_modulated_input.abs().mean()) + ) if self.accumulated_rel_l1_distance < rel_l1_thresh: should_calc = False else: should_calc = True self.accumulated_rel_l1_distance = 0 - except: + except Exception: should_calc = True self.accumulated_rel_l1_distance = 0 @@ -529,23 +621,30 @@ def teacache_hunyuanvideo_forward( should_calc = True if not should_calc: - img += self.previous_residual.to(img.device) + _tc_safe_add_residual(img, getattr(self, 'previous_residual', None)) else: ori_img = 
img.to(cache_device) for i, block in enumerate(self.double_blocks): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"]) + out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], + attn_mask=args["attention_mask"], + modulation_dims_img=args["modulation_dims_img"], + modulation_dims_txt=args["modulation_dims_txt"]) return out - out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]( + {"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, + 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, + {"original_block": block_wrap}) txt = out["txt"] img = out["img"] else: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt) + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, + modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt) - if control is not None: # Controlnet + if control is not None: # Controlnet control_i = control.get("input") if i < len(control_i): add = control_i[i] @@ -558,15 +657,18 @@ def block_wrap(args): if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"]) + out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], + 
modulation_dims=args["modulation_dims"]) return out - out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap}) + out = blocks_replace[("single_block", i)]( + {"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, + {"original_block": block_wrap}) img = out["img"] else: img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims) - if control is not None: # Controlnet + if control is not None: # Controlnet control_o = control.get("output") if i < len(control_o): add = control_o[i] @@ -578,7 +680,7 @@ def block_wrap(args): if ref_latent is not None: img = img[:, ref_latent.shape[1]:] - + img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels) shape = initial_shape[-3:] @@ -589,6 +691,7 @@ def block_wrap(args): img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4]) return img + def teacache_ltxvmodel_forward( self, x, @@ -626,9 +729,11 @@ def teacache_ltxvmodel_forward( timestep = timestep * 1000.0 if attention_mask is not None and not torch.is_floating_point(attention_mask): - attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max + attention_mask = (attention_mask - 1).to(x.dtype).reshape( + (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max - pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype) + # Call precompute_freqs_cis through LTXBaseModel, matching the updated import + pe = LTXBaseModel.precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype) batch_size = x.shape[0] timestep, embedded_timestep = self.adaln_single( @@ -637,19 +742,13 @@ def teacache_ltxvmodel_forward( batch_size=batch_size, hidden_dtype=x.dtype, ) - # 
Second dimension is 1 or number of tokens (if timestep_per_token) timestep = timestep.view(batch_size, -1, timestep.shape[-1]) - embedded_timestep = embedded_timestep.view( - batch_size, -1, embedded_timestep.shape[-1] - ) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1]) - # 2. Blocks if self.caption_projection is not None: batch_size = x.shape[0] context = self.caption_projection(context) - context = context.view( - batch_size, -1, x.shape[-1] - ) + context = context.view(batch_size, -1, x.shape[-1]) blocks_replace = patches_replace.get("dit", {}) @@ -657,34 +756,39 @@ def teacache_ltxvmodel_forward( inp = x.to(cache_device) timestep_ = timestep.to(cache_device) num_ada_params = self.transformer_blocks[0].scale_shift_table.shape[0] - ada_values = self.transformer_blocks[0].scale_shift_table[None, None].to(timestep_.device) + timestep_.reshape(batch_size, timestep_.size(1), num_ada_params, -1) + ada_values = self.transformer_blocks[0].scale_shift_table[None, None].to(timestep_.device) + timestep_.reshape( + batch_size, timestep_.size(1), num_ada_params, -1) shift_msa, scale_msa, _, _, _, _ = ada_values.unbind(dim=2) modulated_inp = comfy.ldm.common_dit.rms_norm(inp) modulated_inp = modulated_inp * (1 + scale_msa) + shift_msa - if not hasattr(self, 'teacache_state'): - self.teacache_state = { - 0: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None}, - 1: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None} - } + _tc_ensure_state(self, keys=(0, 1)) def update_cache_state(cache, modulated_inp): if cache['previous_modulated_input'] is not None: try: - cache['accumulated_rel_l1_distance'] += poly1d(coefficients, ((modulated_inp-cache['previous_modulated_input']).abs().mean() / cache['previous_modulated_input'].abs().mean())) + cache['accumulated_rel_l1_distance'] += poly1d( + coefficients, + 
((modulated_inp - cache['previous_modulated_input']).abs().mean() + / cache['previous_modulated_input'].abs().mean()) + ) if cache['accumulated_rel_l1_distance'] < rel_l1_thresh: cache['should_calc'] = False else: cache['should_calc'] = True cache['accumulated_rel_l1_distance'] = 0 - except: + except Exception: cache['should_calc'] = True cache['accumulated_rel_l1_distance'] = 0 cache['previous_modulated_input'] = modulated_inp b = int(len(x) / len(cond_or_uncond)) - + for i, k in enumerate(cond_or_uncond): + chunk_sig = _tc_shape_sig(x[i*b:(i+1)*b]) + if self.teacache_state[k].get('shape_sig', None) is not None and self.teacache_state[k]['shape_sig'] != chunk_sig: + _tc_reset_cache_entry(self.teacache_state[k]) + self.teacache_state[k]['shape_sig'] = chunk_sig update_cache_state(self.teacache_state[k], modulated_inp[i*b:(i+1)*b]) if enable_teacache: @@ -693,20 +797,30 @@ def update_cache_state(cache, modulated_inp): should_calc = (should_calc or self.teacache_state[k]['should_calc']) else: should_calc = True - + if not should_calc: + applied_all = True for i, k in enumerate(cond_or_uncond): - x[i*b:(i+1)*b] += self.teacache_state[k]['previous_residual'].to(x.device) - else: + applied_all = applied_all and _tc_safe_add_residual( + x[i*b:(i+1)*b], + self.teacache_state[k].get('previous_residual', None) + ) + if not applied_all: + should_calc = True + + if should_calc: ori_x = x.to(cache_device) for i, block in enumerate(self.transformer_blocks): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"]) + out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], + timestep=args["vec"], pe=args["pe"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap}) + out = 
blocks_replace[("double_block", i)]( + {"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, + {"original_block": block_wrap}) x = out["img"] else: x = block( @@ -717,16 +831,16 @@ def block_wrap(args): pe=pe ) - # 3. Output scale_shift_values = ( self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] ) shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] x = self.norm_out(x) - # Modulation x = x * (1 + scale) + shift + + resid_all = (x.to(cache_device) - ori_x) for i, k in enumerate(cond_or_uncond): - self.teacache_state[k]['previous_residual'] = (x.to(cache_device) - ori_x)[i*b:(i+1)*b] + self.teacache_state[k]['previous_residual'] = resid_all[i*b:(i+1)*b] x = self.proj_out(x) @@ -740,6 +854,7 @@ def block_wrap(args): return x + def teacache_wanmodel_forward( self, x, @@ -782,30 +897,43 @@ def teacache_wanmodel_forward( # enable teacache modulated_inp = e0.to(cache_device) if "ret_mode" in model_type else e.to(cache_device) - if not hasattr(self, 'teacache_state'): - self.teacache_state = { - 0: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None}, - 1: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None} - } - def update_cache_state(cache, modulated_inp): + # NEW: init safe state + _tc_ensure_state(self, keys=(0, 1)) + + def update_cache_state(cache, mod_inp, current_shape_sig): + # Reset this cache entry if shape changed since last time + if cache.get('shape_sig', None) is not None and cache['shape_sig'] != current_shape_sig: + _tc_reset_cache_entry(cache) + + cache['shape_sig'] = current_shape_sig + if cache['previous_modulated_input'] is not None: try: - cache['accumulated_rel_l1_distance'] += poly1d(coefficients, ((modulated_inp-cache['previous_modulated_input']).abs().mean() / cache['previous_modulated_input'].abs().mean())) + denom 
= cache['previous_modulated_input'].abs().mean() + if denom == 0: + raise ZeroDivisionError + cache['accumulated_rel_l1_distance'] += poly1d( + coefficients, + ((mod_inp - cache['previous_modulated_input']).abs().mean() / denom) + ) if cache['accumulated_rel_l1_distance'] < rel_l1_thresh: cache['should_calc'] = False else: cache['should_calc'] = True cache['accumulated_rel_l1_distance'] = 0 - except: + except Exception: cache['should_calc'] = True cache['accumulated_rel_l1_distance'] = 0 - cache['previous_modulated_input'] = modulated_inp - + + cache['previous_modulated_input'] = mod_inp + b = int(len(x) / len(cond_or_uncond)) for i, k in enumerate(cond_or_uncond): - update_cache_state(self.teacache_state[k], modulated_inp[i*b:(i+1)*b]) + chunk = x[i*b:(i+1)*b] + shape_sig = _tc_shape_sig(chunk) + update_cache_state(self.teacache_state[k], modulated_inp[i*b:(i+1)*b], shape_sig) if enable_teacache: should_calc = False @@ -815,22 +943,35 @@ def update_cache_state(cache, modulated_inp): should_calc = True if not should_calc: + applied_all = True for i, k in enumerate(cond_or_uncond): - x[i*b:(i+1)*b] += self.teacache_state[k]['previous_residual'].to(x.device) - else: + chunk = x[i*b:(i+1)*b] + ok = _tc_safe_add_residual(chunk, self.teacache_state[k].get('previous_residual', None)) + applied_all = applied_all and ok + if not applied_all: + should_calc = True + + if should_calc: ori_x = x.to(cache_device) for i, block in enumerate(self.blocks): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], + context_img_len=context_img_len) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap, "transformer_options": transformer_options}) + + out = 
blocks_replace[("double_block", i)]( + {"img": x, "txt": context, "vec": e0, "pe": freqs}, + {"original_block": block_wrap, "transformer_options": transformer_options} + ) x = out["img"] else: x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + + resid_all = (x.to(cache_device) - ori_x) for i, k in enumerate(cond_or_uncond): - self.teacache_state[k]['previous_residual'] = (x.to(cache_device) - ori_x)[i*b:(i+1)*b] + self.teacache_state[k]['previous_residual'] = resid_all[i*b:(i+1)*b] # head x = self.head(x, e) @@ -839,6 +980,7 @@ def block_wrap(args): x = self.unpatchify(x, grid_sizes) return x + class TeaCache: @classmethod def INPUT_TYPES(s): @@ -852,13 +994,13 @@ def INPUT_TYPES(s): "cache_device": (["cuda", "cpu"], {"default": "cuda", "tooltip": "Device where the cache will reside."}), } } - + RETURN_TYPES = ("MODEL",) RETURN_NAMES = ("model",) FUNCTION = "apply_teacache" CATEGORY = "TeaCache" TITLE = "TeaCache" - + def apply_teacache(self, model, model_type: str, rel_l1_thresh: float, start_percent: float, end_percent: float, cache_device: str): if rel_l1_thresh == 0: return (model,) @@ -870,7 +1012,7 @@ def apply_teacache(self, model, model_type: str, rel_l1_thresh: float, start_per new_model.model_options["transformer_options"]["coefficients"] = SUPPORTED_MODELS_COEFFICIENTS[model_type] new_model.model_options["transformer_options"]["model_type"] = model_type new_model.model_options["transformer_options"]["cache_device"] = mm.get_torch_device() if cache_device == "cuda" else torch.device("cpu") - + diffusion_model = new_model.get_model_object("diffusion_model") if "flux" in model_type: @@ -911,7 +1053,7 @@ def apply_teacache(self, model, model_type: str, rel_l1_thresh: float, start_per ) else: raise ValueError(f"Unknown type {model_type}") - + def unet_wrapper_function(model_function, kwargs): input = kwargs["input"] timestep = kwargs["timestep"] @@ -928,34 +1070,34 @@ def unet_wrapper_function(model_function, kwargs): if 
(sigmas[i] - timestep[0]) * (sigmas[i + 1] - timestep[0]) <= 0: current_step_index = i break - + + # NEW: Always reset cache at the start of a sampling run. + # This prevents stale cache when resolution/frames/batch changes between prompts. if current_step_index == 0: - if is_cfg: - # uncond -> 1, cond -> 0 - if hasattr(diffusion_model, 'teacache_state') and \ - diffusion_model.teacache_state[0]['previous_modulated_input'] is not None and \ - diffusion_model.teacache_state[1]['previous_modulated_input'] is not None: - delattr(diffusion_model, 'teacache_state') - else: - if hasattr(diffusion_model, 'teacache_state'): - delattr(diffusion_model, 'teacache_state') - if hasattr(diffusion_model, 'accumulated_rel_l1_distance'): - delattr(diffusion_model, 'accumulated_rel_l1_distance') - + if hasattr(diffusion_model, 'teacache_state'): + delattr(diffusion_model, 'teacache_state') + if hasattr(diffusion_model, 'accumulated_rel_l1_distance'): + delattr(diffusion_model, 'accumulated_rel_l1_distance') + if hasattr(diffusion_model, 'previous_modulated_input'): + delattr(diffusion_model, 'previous_modulated_input') + if hasattr(diffusion_model, 'previous_residual'): + delattr(diffusion_model, 'previous_residual') + current_percent = current_step_index / (len(sigmas) - 1) c["transformer_options"]["current_percent"] = current_percent if start_percent <= current_percent <= end_percent: c["transformer_options"]["enable_teacache"] = True else: c["transformer_options"]["enable_teacache"] = False - + with context: return model_function(input, timestep, **c) new_model.set_model_unet_function_wrapper(unet_wrapper_function) return (new_model,) - + + def patch_optimized_module(): try: from torch._dynamo.eval_frame import OptimizedModule @@ -999,6 +1141,7 @@ def __instancecheck__(cls, instance): OptimizedModule.__instancecheck__ = __instancecheck__ OptimizedModule._patched = True + def patch_same_meta(): try: from torch._inductor.fx_passes import post_grad @@ -1021,6 +1164,7 @@ def 
new_same_meta(a, b): post_grad.same_meta = new_same_meta new_same_meta._patched = True + class CompileModel: @classmethod def INPUT_TYPES(s): @@ -1028,37 +1172,37 @@ def INPUT_TYPES(s): "required": { "model": ("MODEL", {"tooltip": "The diffusion model the torch.compile will be applied to."}), "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), - "backend": (["inductor","cudagraphs", "eager", "aot_eager"], {"default": "inductor"}), + "backend": (["inductor", "cudagraphs", "eager", "aot_eager"], {"default": "inductor"}), "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), } } - + RETURN_TYPES = ("MODEL",) RETURN_NAMES = ("model",) FUNCTION = "apply_compile" CATEGORY = "TeaCache" TITLE = "Compile Model" - + def apply_compile(self, model, mode: str, backend: str, fullgraph: bool, dynamic: bool): patch_optimized_module() patch_same_meta() torch._dynamo.config.suppress_errors = True - + new_model = model.clone() new_model.add_object_patch( - "diffusion_model", - torch.compile( - new_model.get_model_object("diffusion_model"), - mode=mode, - backend=backend, - fullgraph=fullgraph, - dynamic=dynamic - ) - ) - + "diffusion_model", + torch.compile( + new_model.get_model_object("diffusion_model"), + mode=mode, + backend=backend, + fullgraph=fullgraph, + dynamic=dynamic + ) + ) + return (new_model,) - + NODE_CLASS_MAPPINGS = { "TeaCache": TeaCache,