Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 176 additions & 1 deletion tests/test_flex_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@
import torch._dynamo as dynamo
from transformers import LlamaConfig

import torchspec.models.draft.llama3_eagle as llama_mod
from tests.utils import norm_tensor
from torchspec.models.draft.base import prepare_decoder_attention_mask
from torchspec.models.draft.llama3_eagle import LlamaAttention, LlamaFlexAttention
from torchspec.models.draft.llama3_eagle import (
LlamaAttention,
LlamaFlashAttention,
LlamaFlexAttention,
)
from torchspec.models.ops.flex_attention import (
compile_friendly_create_block_mask,
compile_friendly_flex_attention,
Expand All @@ -18,6 +23,11 @@
TTT_LENGTH = 7
torch.manual_seed(0)

try:
    # Optional dependency: the cached-path gradient test compares against the
    # real flash-attn kernel and is skipped (via skipUnless) when unavailable.
    from flash_attn import flash_attn_func as standard_flash_attn_func
except ImportError:
    standard_flash_attn_func = None


class TestFlexAttention(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -271,5 +281,170 @@ def test_eagle3_flex_mask(self):
compile_friendly_flex_attention(query, key_cache, value_cache, block_mask=block_mask)


@unittest.skipUnless(standard_flash_attn_func is not None, "flash_attn not installed")
class TestFlashAttentionCachedPath(unittest.TestCase):
    """Gradient-parity check for the flash-attention cached (KV-cache) path.

    Temporarily swaps the module-level flash-attn hooks in ``llama_mod`` for a
    reference implementation (the public ``flash_attn_func`` forward plus a
    pure-PyTorch fp32 backward), then verifies that two cached decode steps
    through ``LlamaFlashAttention`` yield the same projection-weight gradients
    as ``LlamaFlexAttention`` given identical weights and inputs.
    """

    def test_cached_path_gradients_match_flex_attention(self):
        # Small GQA config: 8 query heads sharing 2 KV heads, bf16 weights.
        cfg = LlamaConfig(
            hidden_size=128,
            num_attention_heads=8,
            num_key_value_heads=2,
            max_position_embeddings=4096,
            rms_norm_eps=1e-05,
            vocab_size=32000,
            intermediate_size=688,
            hidden_act="silu",
            num_hidden_layers=1,
            torch_dtype="bfloat16",
        )
        dtype = torch.bfloat16
        seq_len = 128
        batch_size = 2
        # Attention input is twice the model width — presumably the EAGLE draft
        # layer consumes concatenated hidden states; TODO confirm against the
        # LlamaAttention implementation.
        hidden_size = cfg.hidden_size * 2
        position_ids = torch.arange(seq_len, device="cuda").unsqueeze(0).repeat(batch_size, 1)
        # Second batch element is right-padded from position 96 onward, so the
        # test also covers variable-length sequences.
        attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device="cuda")
        attention_mask[1, 96:] = False

        # Save the module-level flash-attn hooks for restoration in `finally`,
        # and clear the cached module handle so the wrappers below take effect.
        old_std_fwd = llama_mod._std_flash_attn_forward
        old_std_bwd = llama_mod._std_flash_attn_backward
        old_std_mod = llama_mod._std_flash_attn_mod
        llama_mod._std_flash_attn_mod = None

        def _standard_flash_attn_forward_wrapper(*args, **kwargs):
            # Adapt the public flash_attn_func API to the internal forward-hook
            # signature: drop kwargs the public API does not accept and request
            # the log-sum-exp needed by the reference backward below.
            kwargs.pop("window_size_left", None)
            kwargs.pop("window_size_right", None)
            kwargs.pop("return_softmax", None)
            out, lse, _ = standard_flash_attn_func(
                *args,
                return_attn_probs=True,
                **kwargs,
            )
            # Internal hook expects (out, softmax_lse, S_dmask, rng_state).
            return out, lse, None, None

        def _standard_flash_attn_backward_wrapper(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            dropout_p,
            softmax_scale,
            causal,
            window_size_left,
            window_size_right,
            softcap,
            alibi_slopes,
            deterministic,
            rng_state,
        ):
            """Pure-PyTorch fp32 reference backward for flash attention.

            Recomputes softmax probabilities from the saved logsumexp and
            writes gradients into ``dq``/``dk``/``dv`` in place. Dropout,
            sliding windows, softcap and ALiBi are ignored (deleted below),
            so this reference is only valid when those features are unused.
            """
            del dropout_p, window_size_left, window_size_right
            del softcap, alibi_slopes, deterministic, rng_state
            # flash-attn layout (b, s, h, d) -> math layout (b, h, s, d), fp32.
            qh = q.permute(0, 2, 1, 3).float()
            kh = k.permute(0, 2, 1, 3).float()
            vh = v.permute(0, 2, 1, 3).float()
            if kh.shape[1] != qh.shape[1]:
                # GQA: replicate each KV head so K/V match the query head count.
                repeat = qh.shape[1] // kh.shape[1]
                kh = kh.repeat_interleave(repeat, dim=1)
                vh = vh.repeat_interleave(repeat, dim=1)
            do = dout.permute(0, 2, 1, 3).float()
            oh = out.permute(0, 2, 1, 3).float()
            lse = softmax_lse.float()

            # Recompute attention scores; causal masking must match the forward.
            scores = torch.matmul(qh, kh.transpose(-1, -2)) * softmax_scale
            if causal:
                q_len = scores.shape[-2]
                k_len = scores.shape[-1]
                mask = torch.triu(
                    torch.ones(q_len, k_len, device=scores.device, dtype=torch.bool),
                    diagonal=1,
                )
                scores = scores.masked_fill(mask, float("-inf"))

            # Softmax probabilities recovered from the saved logsumexp:
            # P = exp(S - lse).
            probs = torch.exp(scores - lse.unsqueeze(-1))
            # Standard attention backward algebra:
            #   dV = Pᵀ dO;  dP = dO Vᵀ;  dS = P ⊙ (dP - rowsum(dO ⊙ O));
            #   dQ = dS K · scale;  dK = dSᵀ Q · scale.
            d_v = torch.matmul(probs.transpose(-1, -2), do)
            d_p = torch.matmul(do, vh.transpose(-1, -2))
            row_dot = (do * oh).sum(dim=-1, keepdim=True)
            d_s = probs * (d_p - row_dot)
            d_q = torch.matmul(d_s, kh) * softmax_scale
            d_k = torch.matmul(d_s.transpose(-1, -2), qh) * softmax_scale
            if dk.shape[2] != d_k.shape[1]:
                # Undo the GQA replication: sum gradients over the query heads
                # that share each KV head. The (kv_heads, repeat) grouping of
                # consecutive heads matches repeat_interleave above.
                kv_heads = dk.shape[2]
                repeat = d_k.shape[1] // kv_heads
                d_k = d_k.view(d_k.shape[0], kv_heads, repeat, d_k.shape[2], d_k.shape[3]).sum(
                    dim=2
                )
                d_v = d_v.view(d_v.shape[0], kv_heads, repeat, d_v.shape[2], d_v.shape[3]).sum(
                    dim=2
                )

            # Write back in flash-attn (b, s, h, d) layout and caller dtype,
            # in place into the provided gradient buffers.
            dq.copy_(d_q.permute(0, 2, 1, 3).to(dq.dtype))
            dk.copy_(d_k.permute(0, 2, 1, 3).to(dk.dtype))
            dv.copy_(d_v.permute(0, 2, 1, 3).to(dv.dtype))

        # Install the reference hooks for the duration of the test.
        llama_mod._std_flash_attn_forward = _standard_flash_attn_forward_wrapper
        llama_mod._std_flash_attn_backward = _standard_flash_attn_backward_wrapper

        try:
            flex_attention = LlamaFlexAttention(cfg).to("cuda").to(dtype)
            flash_attention = LlamaFlashAttention(cfg).to("cuda").to(dtype)
            # Tie both modules to identical projection weights so any gradient
            # difference is attributable to the attention implementations.
            with torch.no_grad():
                for proj_name in ("q_proj", "k_proj", "v_proj", "o_proj"):
                    getattr(flash_attention, proj_name).weight.copy_(
                        getattr(flex_attention, proj_name).weight
                    )

            loss_mask = attention_mask.to(dtype)
            flex_cache_keys = flex_cache_values = None
            flash_cache_keys = flash_cache_values = None
            hidden_states_list = [
                norm_tensor((batch_size, seq_len, hidden_size), device="cuda", dtype=dtype)
                for _ in range(2)
            ]
            flex_losses = []
            flash_losses = []

            # Two forward passes: the second reuses the KV cache returned by
            # the first, exercising the cached path under test.
            for idx in range(2):
                hidden_states = hidden_states_list[idx]
                flex_out, flex_cache_keys, flex_cache_values = flex_attention(
                    hidden_states=hidden_states.clone(),
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    cache_keys=flex_cache_keys,
                    cache_values=flex_cache_values,
                    use_cache=True,
                )
                flash_out, flash_cache_keys, flash_cache_values = flash_attention(
                    hidden_states=hidden_states.clone(),
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    cache_keys=flash_cache_keys,
                    cache_values=flash_cache_values,
                    use_cache=True,
                )
                # Masked reduction so padded positions contribute no gradient.
                flex_losses.append((flex_out * loss_mask[..., None]).sum().mean())
                flash_losses.append((flash_out * loss_mask[..., None]).sum().mean())
                if idx == 0:
                    # Shift the loss mask left by one for the second step —
                    # presumably mirrors the draft-model token shift; TODO
                    # confirm against the training loop.
                    loss_mask = torch.nn.functional.pad(loss_mask[:, 1:], (0, 1))

            # Average the per-step losses and backprop through each module.
            (sum(flex_losses) / 2).backward()
            (sum(flash_losses) / 2).backward()

            # All projection-weight gradients must agree within bf16 tolerances.
            for proj_name in ("q_proj", "k_proj", "v_proj", "o_proj"):
                torch.testing.assert_close(
                    getattr(flash_attention, proj_name).weight.grad,
                    getattr(flex_attention, proj_name).weight.grad,
                    atol=5e-2,
                    rtol=1e-2,
                    msg=f"{proj_name} grad mismatch on cached path",
                )
        finally:
            # Always restore the original module-level hooks, even on failure,
            # so later tests see an unpatched module.
            llama_mod._std_flash_attn_forward = old_std_fwd
            llama_mod._std_flash_attn_backward = old_std_bwd
            llama_mod._std_flash_attn_mod = old_std_mod


# Allow running this test module directly (python tests/test_flex_attention.py).
if __name__ == "__main__":
    unittest.main(verbosity=2)
8 changes: 7 additions & 1 deletion torchspec/controller/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
setup_eval,
update_checkpoint_eval_meta,
)
from torchspec.utils.logging import logger
from torchspec.utils.logging import get_tb_writer, logger


def _maybe_sync_draft_weights(args, completed_steps, train_group, inference_engines):
Expand Down Expand Up @@ -349,6 +349,12 @@ def training_loop(
if getattr(wandb, "run", None) is not None:
wandb.log(metrics)

tb_writer = get_tb_writer()
if tb_writer is not None:
for key, value in metrics.items():
if isinstance(value, (int, float)):
tb_writer.add_scalar(key, value, completed_steps)

# ── Eval at explicit interval (if configured) ─────────
# Skip if a checkpoint save is about to run (it will eval anyway)
save_due = _is_save_interval_step(completed_steps, args.save_interval)
Expand Down
Loading
Loading