From a9929fb14cb45116656436d6b99285beab7706f1 Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Wed, 15 Apr 2026 19:40:05 +0800 Subject: [PATCH 1/5] Support padded FA and add TensorBoard logging Signed-off-by: Yu Feng --- tests/test_flex_attention.py | 177 +++++++++- torchspec/controller/loop.py | 8 +- torchspec/models/draft/llama3_eagle.py | 425 ++++++++++++++++++++++--- torchspec/utils/logging.py | 23 ++ 4 files changed, 585 insertions(+), 48 deletions(-) diff --git a/tests/test_flex_attention.py b/tests/test_flex_attention.py index b7c70bea..247838a5 100644 --- a/tests/test_flex_attention.py +++ b/tests/test_flex_attention.py @@ -4,9 +4,14 @@ import torch._dynamo as dynamo from transformers import LlamaConfig +import torchspec.models.draft.llama3_eagle as llama_mod from tests.utils import norm_tensor from torchspec.models.draft.base import prepare_decoder_attention_mask -from torchspec.models.draft.llama3_eagle import LlamaAttention, LlamaFlexAttention +from torchspec.models.draft.llama3_eagle import ( + LlamaAttention, + LlamaFlashAttention, + LlamaFlexAttention, +) from torchspec.models.ops.flex_attention import ( compile_friendly_create_block_mask, compile_friendly_flex_attention, @@ -18,6 +23,11 @@ TTT_LENGTH = 7 torch.manual_seed(0) +try: + from flash_attn import flash_attn_func as standard_flash_attn_func +except ImportError: + standard_flash_attn_func = None + class TestFlexAttention(unittest.TestCase): def setUp(self): @@ -271,5 +281,170 @@ def test_eagle3_flex_mask(self): compile_friendly_flex_attention(query, key_cache, value_cache, block_mask=block_mask) +@unittest.skipUnless(standard_flash_attn_func is not None, "flash_attn not installed") +class TestFlashAttentionCachedPath(unittest.TestCase): + def test_cached_path_gradients_match_flex_attention(self): + cfg = LlamaConfig( + hidden_size=128, + num_attention_heads=8, + num_key_value_heads=2, + max_position_embeddings=4096, + rms_norm_eps=1e-05, + vocab_size=32000, + intermediate_size=688, + hidden_act="silu", + num_hidden_layers=1, + torch_dtype="bfloat16", + ) + dtype = torch.bfloat16 + seq_len = 128 + batch_size = 2 + hidden_size = cfg.hidden_size * 2 + position_ids = torch.arange(seq_len, device="cuda").unsqueeze(0).repeat(batch_size, 1) + attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device="cuda") + attention_mask[1, 96:] = False + + old_std_fwd = llama_mod._std_flash_attn_forward + old_std_bwd = llama_mod._std_flash_attn_backward + old_std_mod = llama_mod._std_flash_attn_mod + llama_mod._std_flash_attn_mod = None + + def _standard_flash_attn_forward_wrapper(*args, **kwargs): + kwargs.pop("window_size_left", None) + kwargs.pop("window_size_right", None) + kwargs.pop("return_softmax", None) + out, lse, _ = standard_flash_attn_func( + *args, + return_attn_probs=True, + **kwargs, + ) + return out, lse, None, None + + def _standard_flash_attn_backward_wrapper( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + dropout_p, + softmax_scale, + causal, + window_size_left, + window_size_right, + softcap, + alibi_slopes, + deterministic, + rng_state, + ): + del dropout_p, window_size_left, window_size_right + del softcap, alibi_slopes, deterministic, rng_state + qh = q.permute(0, 2, 1, 3).float() + kh = k.permute(0, 2, 1, 3).float() + vh = v.permute(0, 2, 1, 3).float() + if kh.shape[1] != qh.shape[1]: + repeat = qh.shape[1] // kh.shape[1] + kh = kh.repeat_interleave(repeat, dim=1) + vh = vh.repeat_interleave(repeat, dim=1) + do = dout.permute(0, 2, 1, 3).float() + oh = out.permute(0, 2, 1, 3).float() + lse = softmax_lse.float() + + scores = torch.matmul(qh, kh.transpose(-1, -2)) * softmax_scale + if causal: + q_len = scores.shape[-2] + k_len = scores.shape[-1] + mask = torch.triu( + torch.ones(q_len, k_len, device=scores.device, dtype=torch.bool), + diagonal=1, + ) + scores = scores.masked_fill(mask, float("-inf")) + + probs = torch.exp(scores - lse.unsqueeze(-1)) + d_v = torch.matmul(probs.transpose(-1, -2), do) + d_p = torch.matmul(do, vh.transpose(-1, -2)) + row_dot = (do * oh).sum(dim=-1, keepdim=True) + d_s = probs * (d_p - row_dot) + d_q = torch.matmul(d_s, kh) * softmax_scale + d_k = torch.matmul(d_s.transpose(-1, -2), qh) * softmax_scale + if dk.shape[2] != d_k.shape[1]: + kv_heads = dk.shape[2] + repeat = d_k.shape[1] // kv_heads + d_k = d_k.view(d_k.shape[0], kv_heads, repeat, d_k.shape[2], d_k.shape[3]).sum( + dim=2 + ) + d_v = d_v.view(d_v.shape[0], kv_heads, repeat, d_v.shape[2], d_v.shape[3]).sum( + dim=2 + ) + + dq.copy_(d_q.permute(0, 2, 1, 3).to(dq.dtype)) + dk.copy_(d_k.permute(0, 2, 1, 3).to(dk.dtype)) + dv.copy_(d_v.permute(0, 2, 1, 3).to(dv.dtype)) + + llama_mod._std_flash_attn_forward = _standard_flash_attn_forward_wrapper + llama_mod._std_flash_attn_backward = _standard_flash_attn_backward_wrapper + + try: + flex_attention = LlamaFlexAttention(cfg).to("cuda").to(dtype) + flash_attention = LlamaFlashAttention(cfg).to("cuda").to(dtype) + with torch.no_grad(): + for proj_name in ("q_proj", "k_proj", "v_proj", "o_proj"): + getattr(flash_attention, proj_name).weight.copy_( + getattr(flex_attention, proj_name).weight + ) + + loss_mask = attention_mask.to(dtype) + flex_cache_keys = flex_cache_values = None + flash_cache_keys = flash_cache_values = None + hidden_states_list = [ + norm_tensor((batch_size, seq_len, hidden_size), device="cuda", dtype=dtype) + for _ in range(2) + ] + flex_losses = [] + flash_losses = [] + + for idx in range(2): + hidden_states = hidden_states_list[idx] + flex_out, flex_cache_keys, flex_cache_values = flex_attention( + hidden_states=hidden_states.clone(), + attention_mask=attention_mask, + position_ids=position_ids, + cache_keys=flex_cache_keys, + cache_values=flex_cache_values, + use_cache=True, + ) + flash_out, flash_cache_keys, flash_cache_values = flash_attention( + hidden_states=hidden_states.clone(), + attention_mask=attention_mask, + position_ids=position_ids, + cache_keys=flash_cache_keys, + cache_values=flash_cache_values, + use_cache=True, + ) + flex_losses.append((flex_out * loss_mask[..., None]).sum().mean()) + flash_losses.append((flash_out * loss_mask[..., None]).sum().mean()) + if idx == 0: + loss_mask = torch.nn.functional.pad(loss_mask[:, 1:], (0, 1)) + + (sum(flex_losses) / 2).backward() + (sum(flash_losses) / 2).backward() + + for proj_name in ("q_proj", "k_proj", "v_proj", "o_proj"): + torch.testing.assert_close( + getattr(flash_attention, proj_name).weight.grad, + getattr(flex_attention, proj_name).weight.grad, + atol=5e-2, + rtol=1e-2, + msg=f"{proj_name} grad mismatch on cached path", + ) + finally: + llama_mod._std_flash_attn_forward = old_std_fwd + llama_mod._std_flash_attn_backward = old_std_bwd + llama_mod._std_flash_attn_mod = old_std_mod + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/torchspec/controller/loop.py b/torchspec/controller/loop.py index b4278b02..aefb6b0f 100644 --- a/torchspec/controller/loop.py +++ b/torchspec/controller/loop.py @@ -36,7 +36,7 @@ setup_eval, update_checkpoint_eval_meta, ) -from torchspec.utils.logging import logger +from torchspec.utils.logging import get_tb_writer, logger def _maybe_sync_draft_weights(args, completed_steps, train_group, inference_engines): @@ -349,6 +349,12 @@ def training_loop( if getattr(wandb, "run", None) is not None: wandb.log(metrics) + tb_writer = get_tb_writer() + if tb_writer is not None: + for key, value in metrics.items(): + if isinstance(value, (int, float)): + tb_writer.add_scalar(key, value, completed_steps) + # ── Eval at explicit interval (if configured) ───────── # Skip if a checkpoint save is about to run (it will eval anyway) save_due = _is_save_interval_step(completed_steps, args.save_interval) diff --git a/torchspec/models/draft/llama3_eagle.py b/torchspec/models/draft/llama3_eagle.py index 0e3e56dc..e4d77d0d 100644 --- a/torchspec/models/draft/llama3_eagle.py +++ b/torchspec/models/draft/llama3_eagle.py @@ -68,6 +68,48 @@ _flash_attn_fwd = None _flash_attn_bwd = None + +def _import_standard_flash_attn(): + try: + import flash_attn as mod + from flash_attn.bert_padding import pad_input, unpad_input + from flash_attn.flash_attn_interface import ( + _flash_attn_backward as backward, + ) + from flash_attn.flash_attn_interface import ( + _flash_attn_forward as forward, + ) + from flash_attn.flash_attn_interface import ( + _flash_attn_varlen_backward as varlen_backward, + ) + from flash_attn.flash_attn_interface import ( + _flash_attn_varlen_forward as varlen_forward, + ) + except ImportError as exc: + return None, None, None, None, None, None, None, exc + + return mod, pad_input, unpad_input, forward, backward, varlen_forward, varlen_backward, None + + +( + _std_flash_attn_mod, + _std_flash_pad_input, + _std_flash_unpad_input, + _std_flash_attn_forward, + _std_flash_attn_backward, + _std_flash_attn_varlen_forward, + _std_flash_attn_varlen_backward, + _std_flash_attn_import_error, +) = _import_standard_flash_attn() + + +def _raise_standard_flash_attn_unavailable() -> None: + raise RuntimeError( + "LlamaFlashAttention requires the standard flash-attn interface " + f"(import error: {_std_flash_attn_import_error!r})" + ) + + try: import cutlass import cutlass.cute as cute @@ -1400,6 +1442,17 @@ class LlamaFlashAttention(LlamaAttention): - cache_keys/cache_values: tensor caches for storing past key and value states """ + def __init__(self, config): + super().__init__(config) + if ( + _std_flash_attn_forward is None + or _std_flash_attn_backward is None + or _std_flash_unpad_input is None + or _std_flash_pad_input is None + or _std_flash_attn_varlen_backward is None + ): + _raise_standard_flash_attn_unavailable() + def forward( self, hidden_states: torch.Tensor, @@ -1415,12 +1468,10 @@ def forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - # FA uses [bsz, seq_len, heads, head_dim] layout query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - # cache_keys shape: [bsz, num_cached, seq_len, num_kv_heads, head_dim] lck = 0 if cache_keys is None else cache_keys.shape[1] if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding): cos, sin = self.rotary_emb(query_states, position_ids + lck) @@ -1440,7 +1491,6 @@ def forward( query_states, key_states, cos, sin, position_ids + lck, unsqueeze_dim=2 ) - # Append to tensor cache: [bsz, num_cached, seq_len, num_kv_heads, head_dim] if cache_keys is not None: cache_keys = torch.cat([cache_keys, key_states.unsqueeze(1)], dim=1) cache_values = torch.cat([cache_values, value_states.unsqueeze(1)], dim=1) @@ -1448,54 +1498,20 @@ def forward( cache_keys = key_states.unsqueeze(1) cache_values = value_states.unsqueeze(1) - k0 = cache_keys[:, 0] - v0 = cache_values[:, 0] + assert attention_mask is not None, "LlamaFlashAttention cached path requires attention_mask" + valid_lengths = attention_mask.sum(dim=-1, dtype=torch.long) - lck + valid_lengths = valid_lengths.clamp_(0, q_len) - assert _flash_attn_func is not None, ( - f"flash_attn.cute is unavailable. ImportError: {_flash_attn_import_error!r}" - ) - attn_output, lse = _flash_attn_func( + attn_output = _FlashCachedMergeFunc.apply( query_states, - k0, - v0, - softmax_scale=1.0 / math.sqrt(self.head_dim), - causal=True, + cache_keys, + cache_values, + valid_lengths, + 1.0 / math.sqrt(self.head_dim), ) - # Accumulate O in FP32 so the backward delta path (rowsum(dO·O)) stays accurate - attn_output = attn_output.float() - lse = lse.transpose(1, 2) - - lck = cache_keys.shape[1] - if lck > 1: - q_shape_expanded = ( - bsz, - q_len, - self.num_key_value_heads, - self.num_key_value_groups, - self.head_dim, - ) - attn_outputs = [attn_output.view(q_shape_expanded)] - lses = [lse.view(q_shape_expanded[:-1])] - - for i in range(1, lck): - ki = cache_keys[:, i].unsqueeze(-2) - qi = query_states.view(q_shape_expanded) - vi = cache_values[:, i].unsqueeze(-2) - - attn_outputs.append(vi.float()) - lses.append((qi.float() * ki.float()).sum(-1) / math.sqrt(self.head_dim)) - - lse = torch.logsumexp(torch.stack(lses, dim=-1), dim=-1) - attn_output = sum( - attn_outputi * torch.exp(lsei - lse).unsqueeze(-1) - for attn_outputi, lsei in zip(attn_outputs, lses) - ) - attn_output = attn_output.to(self.o_proj.weight.dtype) attn_output = attn_output.reshape(bsz, q_len, self.head_dim * self.num_heads) - attn_output = self.o_proj(attn_output) - return attn_output, cache_keys, cache_values @@ -1621,6 +1637,323 @@ def forward( return attn_output, cache_keys, cache_values +def _standard_flash_attn_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale: float, + causal: bool, +) -> tuple[torch.Tensor, torch.Tensor]: + if _std_flash_attn_mod is not None and _std_flash_attn_mod.__version__ < "2.6.3": + out, _, _, _, _, lse, _, _ = _std_flash_attn_forward( + q, + k, + v, + dropout_p=0.0, + softmax_scale=softmax_scale, + causal=causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + return_softmax=False, + ) + else: + out, lse, _, _ = _std_flash_attn_forward( + q, + k, + v, + dropout_p=0.0, + softmax_scale=softmax_scale, + causal=causal, + window_size_left=-1, + window_size_right=-1, + softcap=0.0, + alibi_slopes=None, + return_softmax=False, + ) + return out, lse + + +def _standard_flash_attn_backward_call( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor, + softmax_scale: float, + causal: bool, +) -> None: + if _std_flash_attn_mod is not None and _std_flash_attn_mod.__version__ < "2.6.3": + _std_flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + 0.0, + softmax_scale, + causal, + (-1, -1), + 0.0, + None, + False, + None, + ) + else: + _std_flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + 0.0, + softmax_scale, + causal, + -1, + -1, + 0.0, + None, + False, + None, + ) + + +def _standard_flash_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attention_mask: torch.Tensor, + softmax_scale: float, + causal: bool, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + from flash_attn import flash_attn_varlen_func + + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, _ = _std_flash_unpad_input(q, attention_mask) + k_unpad, _, cu_seqlens_k, max_seqlen_k, _ = _std_flash_unpad_input(k, attention_mask) + v_unpad, _, _, _, _ = _std_flash_unpad_input(v, attention_mask) + out_unpad, lse_unpad, _ = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=0.0, + softmax_scale=softmax_scale, + causal=causal, + return_attn_probs=True, + ) + out = _std_flash_pad_input(out_unpad, indices_q, q.shape[0], q.shape[1]) + lse_padded = _std_flash_pad_input( + lse_unpad.transpose(0, 1), indices_q, q.shape[0], q.shape[1] + ).transpose(1, 2) + return out, lse_padded, indices_q, cu_seqlens_q + + +def _standard_flash_attn_varlen_backward_call( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + attention_mask: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor, + softmax_scale: float, + causal: bool, +) -> None: + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, _ = _std_flash_unpad_input(q, attention_mask) + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, _ = _std_flash_unpad_input(k, attention_mask) + v_unpad, _, _, _, _ = _std_flash_unpad_input(v, attention_mask) + dout_unpad, _, _, _, _ = _std_flash_unpad_input(dout, attention_mask) + out_unpad, _, _, _, _ = _std_flash_unpad_input(out, attention_mask) + lse_unpad, _, _, _, _ = _std_flash_unpad_input(softmax_lse.transpose(1, 2), attention_mask) + lse_unpad = lse_unpad.transpose(0, 1).contiguous() + + dq_unpad = torch.empty_like(q_unpad) + dk_unpad = torch.empty_like(k_unpad) + dv_unpad = torch.empty_like(v_unpad) + if _std_flash_attn_mod is not None and _std_flash_attn_mod.__version__ < "2.6.3": + _std_flash_attn_varlen_backward( + dout_unpad, + q_unpad, + k_unpad, + v_unpad, + out_unpad, + lse_unpad, + dq_unpad, + dk_unpad, + dv_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + 0.0, + softmax_scale, + causal, + (-1, -1), + 0.0, + None, + False, + None, + ) + else: + _std_flash_attn_varlen_backward( + dout_unpad, + q_unpad, + k_unpad, + v_unpad, + out_unpad, + lse_unpad, + dq_unpad, + dk_unpad, + dv_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + 0.0, + softmax_scale, + causal, + -1, + -1, + 0.0, + None, + False, + None, + ) + dq.copy_(_std_flash_pad_input(dq_unpad, indices_q, q.shape[0], q.shape[1])) + dk.copy_(_std_flash_pad_input(dk_unpad, indices_k, k.shape[0], k.shape[1])) + dv.copy_(_std_flash_pad_input(dv_unpad, indices_k, v.shape[0], v.shape[1])) + + +class _FlashCachedMergeFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, q, cache_k, cache_v, valid_lengths, softmax_scale: float): + bsz, q_len, num_heads, head_dim = q.shape + num_blocks = cache_k.shape[1] + num_kv_heads = cache_k.shape[3] + num_groups = num_heads // num_kv_heads + q_expanded = q.view(bsz, q_len, num_kv_heads, num_groups, head_dim) + valid_lengths = valid_lengths.to(device=q.device, dtype=torch.long).clamp_(0, q_len) + valid_mask = ( + torch.arange(q_len, device=q.device).unsqueeze(0) < valid_lengths.unsqueeze(1) + ).view(bsz, q_len, 1, 1) + attention_mask = valid_mask.view(bsz, q_len) + + k0 = cache_k[:, 0].contiguous() + v0 = cache_v[:, 0].contiguous() + out0, lse0_kernel, _, _ = _standard_flash_attn_varlen_forward( + q.contiguous(), + k0, + v0, + attention_mask, + softmax_scale=softmax_scale, + causal=True, + ) + out0_expanded = out0.view(bsz, q_len, num_kv_heads, num_groups, head_dim).float() + neg_inf = torch.tensor(float("-inf"), device=q.device, dtype=torch.float32) + lse0 = lse0_kernel.transpose(1, 2).view(bsz, q_len, num_kv_heads, num_groups).float() + lse0 = torch.where(valid_mask, lse0, neg_inf) + lse_terms = [lse0] + attn_terms = [out0_expanded] + for i in range(1, num_blocks): + ki = cache_k[:, i].unsqueeze(-2).float() + vi = cache_v[:, i].unsqueeze(-2).float() + lse_i = (q_expanded.float() * ki).sum(-1) * softmax_scale + lse_terms.append(torch.where(valid_mask, lse_i, neg_inf)) + attn_terms.append(vi.expand_as(out0_expanded)) + + merged_lse = torch.logsumexp(torch.stack(lse_terms, dim=-1), dim=-1) + out = sum( + term * torch.exp(lse - merged_lse).unsqueeze(-1) + for term, lse in zip(attn_terms, lse_terms) + ) + out = torch.where(valid_mask.unsqueeze(-1), out, 0.0) + merged_lse = torch.where(valid_mask, merged_lse, 0.0) + ctx.save_for_backward(q, cache_k, cache_v, out, merged_lse, valid_lengths) + ctx.softmax_scale = softmax_scale + return out.to(q.dtype).reshape_as(q) + + @staticmethod + def backward(ctx, grad_out): + q, cache_k, cache_v, out, merged_lse, valid_lengths = ctx.saved_tensors + bsz, q_len, num_heads, head_dim = q.shape + num_blocks = cache_k.shape[1] + num_kv_heads = cache_k.shape[3] + num_groups = num_heads // num_kv_heads + scale = ctx.softmax_scale + if grad_out.ndim == 3: + grad_out = grad_out.view(bsz, q_len, num_heads, head_dim) + valid_lengths = valid_lengths.to(device=q.device, dtype=torch.long) + valid_mask = ( + torch.arange(q_len, device=q.device).unsqueeze(0) < valid_lengths.unsqueeze(1) + ).view(bsz, q_len, 1, 1) + attention_mask = valid_mask.view(bsz, q_len) + grad_out = torch.where(valid_mask, grad_out, 0.0) + + grad_out_f = grad_out.float().view(bsz, q_len, num_kv_heads, num_groups, head_dim) + q_f = q.float() + q_expanded = q_f.view(bsz, q_len, num_kv_heads, num_groups, head_dim) + out_f = out.float() + out_expanded = out_f.view(bsz, q_len, num_kv_heads, num_groups, head_dim) + out_q = out.to(q.dtype).reshape(bsz, q_len, num_heads, head_dim) + + dq = torch.zeros_like(q_f) + dcache_k = torch.zeros_like(cache_k.float()) + dcache_v = torch.zeros_like(cache_v.float()) + + dq0 = torch.zeros_like(q) + dk0 = torch.zeros_like(cache_k[:, 0]) + dv0 = torch.zeros_like(cache_v[:, 0]) + merged_lse_kernel = merged_lse.reshape(bsz, q_len, num_heads).transpose(1, 2).contiguous() + _standard_flash_attn_varlen_backward_call( + grad_out.contiguous(), + q.contiguous(), + cache_k[:, 0].contiguous(), + cache_v[:, 0].contiguous(), + out_q.contiguous(), + merged_lse_kernel, + attention_mask, + dq0, + dk0, + dv0, + softmax_scale=scale, + causal=True, + ) + dq += dq0.float() + dcache_k[:, 0] += dk0.float() + dcache_v[:, 0] += dv0.float() + + for i in range(1, num_blocks): + ki = cache_k[:, i].float().unsqueeze(-2) + vi = cache_v[:, i].float().unsqueeze(-2) + lse_i = (q_expanded * ki).sum(-1) * scale + wi = torch.where(valid_mask, torch.exp(lse_i - merged_lse), 0.0) + d_out_i = grad_out_f * wi.unsqueeze(-1) + d_lse_i = wi * (grad_out_f * (vi.expand_as(out_expanded) - out_expanded)).sum(-1) + dq += (d_lse_i.unsqueeze(-1) * scale * ki).reshape_as(q) + dcache_k[:, i] += (d_lse_i.unsqueeze(-1) * scale * q_expanded).sum(dim=3) + dcache_v[:, i] += d_out_i.sum(dim=3) + + return dq.to(q.dtype), dcache_k.to(cache_k.dtype), dcache_v.to(cache_v.dtype), None, None + + def warmup_flash_attention_masked( q_len: int, num_heads: int, @@ -1769,7 +2102,7 @@ def __init__(self, config, attention_backend: str = "sdpa"): self.self_attn = LlamaFlexAttention(config=config) elif attention_backend == "fa4": self.self_attn = LlamaFlashAttentionMasked(config=config) - elif attention_backend == "fa_low_acc": + elif attention_backend in ("fa", "fa_low_acc"): self.self_attn = LlamaFlashAttention(config=config) else: raise ValueError(f"Unknown attention backend {attention_backend}") diff --git a/torchspec/utils/logging.py b/torchspec/utils/logging.py index 71d91040..fc37325b 100644 --- a/torchspec/utils/logging.py +++ b/torchspec/utils/logging.py @@ -27,6 +27,7 @@ from torchspec.utils import wandb as wandb_utils _LOG_FORMAT = "[%(asctime)s] %(filename)s:%(lineno)d %(levelname)s %(message)s" +_tb_writer = None def _get_logger_level(): @@ -117,7 +118,29 @@ def print_on_rank0(message): def init_tracking(args, primary: bool = True, **kwargs): + global _tb_writer if primary: wandb_utils.init_wandb_primary(args, **kwargs) + if getattr(args, "use_tensorboard", False) and getattr(args, "output_dir", None): + from torch.utils.tensorboard import SummaryWriter + + tb_log_dir = os.path.join(args.output_dir, "runs") + os.makedirs(tb_log_dir, exist_ok=True) + _tb_writer = SummaryWriter(log_dir=tb_log_dir) + logger.info(f"TensorBoard writer initialized at {tb_log_dir}") else: wandb_utils.init_wandb_secondary(args, **kwargs) + + +def get_tb_writer(): + """Return the module-level TensorBoard SummaryWriter, or None if not initialized.""" + return _tb_writer + + +def close_tb_writer(): + """Flush and close the TensorBoard writer.""" + global _tb_writer + if _tb_writer is not None: + _tb_writer.flush() + _tb_writer.close() + _tb_writer = None From 4efdda29433affe6d426d5a509d6f4178d2a6082 Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Wed, 15 Apr 2026 20:55:23 +0800 Subject: [PATCH 2/5] Remove legacy flash-attn compatibility branches Signed-off-by: Yu Feng --- torchspec/models/draft/llama3_eagle.py | 174 ++++++++----------------- 1 file changed, 57 insertions(+), 117 deletions(-) diff --git a/torchspec/models/draft/llama3_eagle.py b/torchspec/models/draft/llama3_eagle.py index e4d77d0d..a21b5d26 100644 --- a/torchspec/models/draft/llama3_eagle.py +++ b/torchspec/models/draft/llama3_eagle.py @@ -1644,33 +1644,19 @@ def _standard_flash_attn_forward( softmax_scale: float, causal: bool, ) -> tuple[torch.Tensor, torch.Tensor]: - if _std_flash_attn_mod is not None and _std_flash_attn_mod.__version__ < "2.6.3": - out, _, _, _, _, lse, _, _ = _std_flash_attn_forward( - q, - k, - v, - dropout_p=0.0, - softmax_scale=softmax_scale, - causal=causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - return_softmax=False, - ) - else: - out, lse, _, _ = _std_flash_attn_forward( - q, - k, - v, - dropout_p=0.0, - softmax_scale=softmax_scale, - causal=causal, - window_size_left=-1, - window_size_right=-1, - softcap=0.0, - alibi_slopes=None, - return_softmax=False, - ) + out, lse, _, _ = _std_flash_attn_forward( + q, + k, + v, + dropout_p=0.0, + softmax_scale=softmax_scale, + causal=causal, + window_size_left=-1, + window_size_right=-1, + softcap=0.0, + alibi_slopes=None, + return_softmax=False, + ) return out, lse @@ -1687,47 +1673,26 @@ def _standard_flash_attn_backward_call( softmax_scale: float, causal: bool, ) -> None: - if _std_flash_attn_mod is not None and _std_flash_attn_mod.__version__ < "2.6.3": - _std_flash_attn_backward( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - 0.0, - softmax_scale, - causal, - (-1, -1), - 0.0, - None, - False, - None, - ) - else: - _std_flash_attn_backward( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - 0.0, - softmax_scale, - causal, - -1, - -1, - 0.0, - None, - False, - None, - ) + _std_flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + 0.0, + softmax_scale, + causal, + -1, + -1, + 0.0, + None, + False, + None, + ) def _standard_flash_attn_varlen_forward( @@ -1788,55 +1753,30 @@ def _standard_flash_attn_varlen_backward_call( dq_unpad = torch.empty_like(q_unpad) dk_unpad = torch.empty_like(k_unpad) dv_unpad = torch.empty_like(v_unpad) - if _std_flash_attn_mod is not None and _std_flash_attn_mod.__version__ < "2.6.3": - _std_flash_attn_varlen_backward( - dout_unpad, - q_unpad, - k_unpad, - v_unpad, - out_unpad, - lse_unpad, - dq_unpad, - dk_unpad, - dv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - 0.0, - softmax_scale, - causal, - (-1, -1), - 0.0, - None, - False, - None, - ) - else: - _std_flash_attn_varlen_backward( - dout_unpad, - q_unpad, - k_unpad, - v_unpad, - out_unpad, - lse_unpad, - dq_unpad, - dk_unpad, - dv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - 0.0, - softmax_scale, - causal, - -1, - -1, - 0.0, - None, - False, - None, - ) + _std_flash_attn_varlen_backward( + dout_unpad, + q_unpad, + k_unpad, + v_unpad, + out_unpad, + lse_unpad, + dq_unpad, + dk_unpad, + dv_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + 0.0, + softmax_scale, + causal, + -1, + -1, + 0.0, + None, + False, + None, + ) dq.copy_(_std_flash_pad_input(dq_unpad, indices_q, q.shape[0], q.shape[1])) dk.copy_(_std_flash_pad_input(dk_unpad, indices_k, k.shape[0], k.shape[1])) dv.copy_(_std_flash_pad_input(dv_unpad, indices_k, v.shape[0], v.shape[1])) From 9753b9dd509c97e7e7848f1c078d493d24305593 Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Thu, 16 Apr 2026 10:40:36 +0800 Subject: [PATCH 3/5] xxx Signed-off-by: Yu Feng --- torchspec/models/draft/llama3_eagle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchspec/models/draft/llama3_eagle.py b/torchspec/models/draft/llama3_eagle.py index a21b5d26..ffa2fddf 100644 --- a/torchspec/models/draft/llama3_eagle.py +++ b/torchspec/models/draft/llama3_eagle.py @@ -2042,7 +2042,7 @@ def __init__(self, config, attention_backend: str = "sdpa"): self.self_attn = LlamaFlexAttention(config=config) elif attention_backend == "fa4": self.self_attn = LlamaFlashAttentionMasked(config=config) - elif attention_backend in ("fa", "fa_low_acc"): + elif attention_backend == "fa": self.self_attn = LlamaFlashAttention(config=config) else: raise ValueError(f"Unknown attention backend {attention_backend}") From 941bc15f138016281f4d96732d07767061162cd5 Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Thu, 16 Apr 2026 10:54:14 +0800 Subject: [PATCH 4/5] Refactor standard flash-attn varlen import Signed-off-by: Yu Feng --- torchspec/models/draft/llama3_eagle.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchspec/models/draft/llama3_eagle.py b/torchspec/models/draft/llama3_eagle.py index ffa2fddf..ed92b3e9 100644 --- a/torchspec/models/draft/llama3_eagle.py +++ b/torchspec/models/draft/llama3_eagle.py @@ -72,6 +72,7 @@ def _import_standard_flash_attn(): try: import flash_attn as mod + from flash_attn import flash_attn_varlen_func as varlen_func from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.flash_attn_interface import ( _flash_attn_backward as backward, @@ -86,13 +87,14 @@ def _import_standard_flash_attn(): _flash_attn_varlen_forward as varlen_forward, ) except ImportError as exc: - return None, None, None, None, None, None, None, exc + return None, None, None, None, None, None, None, None, exc - return mod, pad_input, unpad_input, forward, backward, varlen_forward, varlen_backward, None + return mod, varlen_func, pad_input, unpad_input, forward, backward, varlen_forward, varlen_backward, None ( _std_flash_attn_mod, + _std_flash_attn_varlen_func, _std_flash_pad_input, _std_flash_unpad_input, _std_flash_attn_forward, @@ -1703,12 +1705,10 @@ def _standard_flash_attn_varlen_forward( softmax_scale: float, causal: bool, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - from flash_attn import flash_attn_varlen_func - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, _ = _std_flash_unpad_input(q, attention_mask) k_unpad, _, cu_seqlens_k, max_seqlen_k, _ = _std_flash_unpad_input(k, attention_mask) v_unpad, _, _, _, _ = _std_flash_unpad_input(v, attention_mask) - out_unpad, lse_unpad, _ = flash_attn_varlen_func( + out_unpad, lse_unpad, _ = _std_flash_attn_varlen_func( q_unpad, k_unpad, v_unpad, From b70fd17107d5ff1ec32f3e086f1a8b122f382e99 Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Thu, 16 Apr 2026 11:26:06 +0800 Subject: [PATCH 5/5] modify format Signed-off-by: Yu Feng --- torchspec/models/draft/llama3_eagle.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/torchspec/models/draft/llama3_eagle.py b/torchspec/models/draft/llama3_eagle.py index ed92b3e9..6200c9ce 100644 --- a/torchspec/models/draft/llama3_eagle.py +++ b/torchspec/models/draft/llama3_eagle.py @@ -89,7 +89,17 @@ def _import_standard_flash_attn(): except ImportError as exc: return None, None, None, None, None, None, None, None, exc - return mod, varlen_func, pad_input, unpad_input, forward, backward, varlen_forward, varlen_backward, None + return ( + mod, + varlen_func, + pad_input, + unpad_input, + forward, + backward, + varlen_forward, + varlen_backward, + None, + ) (