From a17f3589f4aaaa1afe0b4c9ce8cd09a2082a2dfb Mon Sep 17 00:00:00 2001 From: ganyi Date: Wed, 4 Mar 2026 07:25:17 +0000 Subject: [PATCH 1/3] optimize metadata prepare Signed-off-by: ganyi --- atom/plugin/attention.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/atom/plugin/attention.py b/atom/plugin/attention.py index cb46e6b5a..cc1873bf9 100644 --- a/atom/plugin/attention.py +++ b/atom/plugin/attention.py @@ -21,7 +21,6 @@ @dataclass class AiterFlashAttentionPhaseMetadata: max_query_len: int - min_query_len: int max_seq_len: int query_start_loc: torch.Tensor @@ -58,7 +57,6 @@ class AiterChunkContextMetadata: @dataclass class AiterFlashAttentionChunkPrefillMetadata: max_query_len: int - min_query_len: int max_seq_len: int query_start_loc: torch.Tensor chunk_context_metadata: AiterChunkContextMetadata @@ -300,17 +298,23 @@ def build( num_extend_tokens, num_prefill_tokens, ) = split_ret + prefill_only = num_decodes == 0 and num_extends == 0 and num_prefills > 0 + decode_only = num_decodes > 0 and num_extends == 0 and num_prefills == 0 + mixed_request = not (prefill_only or decode_only) query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - seq_lens = common_attn_metadata.seq_lens.cpu() - query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + if mixed_request: + seq_lens = common_attn_metadata.seq_lens.cpu() + query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + else: + seq_lens = None + query_lens_cpu = None decode_metadata = None if num_decodes > 0: decode_metadata = AiterFlashAttentionDecodeMetadata( - max_query_len=query_lens_cpu[:num_decodes].max().item(), - min_query_len=query_lens_cpu[:num_decodes].min().item(), - max_seq_len=seq_lens[:num_decodes].max().item(), + max_query_len=common_attn_metadata.max_query_len if decode_only else query_lens_cpu[:num_decodes].max().item(), + max_seq_len=common_attn_metadata.max_seq_len if decode_only else 
seq_lens[:num_decodes].max().item(), query_start_loc=common_attn_metadata.query_start_loc[: num_decodes + 1], ) @@ -435,7 +439,6 @@ def build( ) extend_metadata = AiterFlashAttentionChunkPrefillMetadata( max_query_len=query_lens_for_extend.max().item(), - min_query_len=query_lens_for_extend.min().item(), max_seq_len=seq_lens[num_extends_slice].max().item(), query_start_loc=query_start_loc_device - query_start_loc_device[0], chunk_context_metadata=chunk_context_metadata, @@ -443,18 +446,17 @@ def build( prefill_metadata = None if num_prefills > 0: - query_lens_for_prefill = query_lens_cpu[num_decodes + num_extends :] query_start_loc_device = common_attn_metadata.query_start_loc[ num_decodes + num_extends : ] prefill_metadata = AiterFlashAttentionPrefillMetadata( - max_query_len=query_lens_for_prefill.max().item(), - min_query_len=query_lens_for_prefill.min().item(), - max_seq_len=seq_lens[num_decodes + num_extends :].max().item(), + max_query_len=common_attn_metadata.max_query_len if prefill_only else query_lens_cpu[:num_decodes].max().item(), + max_seq_len=common_attn_metadata.max_seq_len if prefill_only else query_lens_cpu[:num_decodes].max().item(), query_start_loc=query_start_loc_device - query_start_loc_device[0], ) - num_actual_kv_tokens = torch.sum(seq_lens).item() + # num_actual_kv_tokens = torch.sum(seq_lens).item() + num_actual_kv_tokens = 0 use_cascade = False From 8369d99b45d28e3ba7f208eebbfe60e59b0512bd Mon Sep 17 00:00:00 2001 From: ganyi Date: Wed, 4 Mar 2026 07:25:38 +0000 Subject: [PATCH 2/3] black Signed-off-by: ganyi --- atom/plugin/attention.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/atom/plugin/attention.py b/atom/plugin/attention.py index cc1873bf9..9e231f25d 100644 --- a/atom/plugin/attention.py +++ b/atom/plugin/attention.py @@ -313,8 +313,16 @@ def build( decode_metadata = None if num_decodes > 0: decode_metadata = AiterFlashAttentionDecodeMetadata( - 
max_query_len=common_attn_metadata.max_query_len if decode_only else query_lens_cpu[:num_decodes].max().item(), - max_seq_len=common_attn_metadata.max_seq_len if decode_only else seq_lens[:num_decodes].max().item(), + max_query_len=( + common_attn_metadata.max_query_len + if decode_only + else query_lens_cpu[:num_decodes].max().item() + ), + max_seq_len=( + common_attn_metadata.max_seq_len + if decode_only + else seq_lens[:num_decodes].max().item() + ), query_start_loc=common_attn_metadata.query_start_loc[: num_decodes + 1], ) @@ -450,8 +458,16 @@ def build( num_decodes + num_extends : ] prefill_metadata = AiterFlashAttentionPrefillMetadata( - max_query_len=common_attn_metadata.max_query_len if prefill_only else query_lens_cpu[:num_decodes].max().item(), - max_seq_len=common_attn_metadata.max_seq_len if prefill_only else query_lens_cpu[:num_decodes].max().item(), + max_query_len=( + common_attn_metadata.max_query_len + if prefill_only + else query_lens_cpu[:num_decodes].max().item() + ), + max_seq_len=( + common_attn_metadata.max_seq_len + if prefill_only + else query_lens_cpu[:num_decodes].max().item() + ), query_start_loc=query_start_loc_device - query_start_loc_device[0], ) From 40722b7c5f27e1f6a8fbf5a516586a30f55aea1a Mon Sep 17 00:00:00 2001 From: ganyi Date: Thu, 5 Mar 2026 03:07:08 +0000 Subject: [PATCH 3/3] fix prefill prepare bug Signed-off-by: ganyi --- atom/plugin/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/atom/plugin/attention.py b/atom/plugin/attention.py index 9e231f25d..321c7e29e 100644 --- a/atom/plugin/attention.py +++ b/atom/plugin/attention.py @@ -461,12 +461,12 @@ def build( max_query_len=( common_attn_metadata.max_query_len if prefill_only - else query_lens_cpu[:num_decodes].max().item() + else query_lens_cpu[num_decodes + num_extends :].max().item() ), max_seq_len=( common_attn_metadata.max_seq_len if prefill_only - else query_lens_cpu[:num_decodes].max().item() + else seq_lens[num_decodes + 
num_extends :].max().item() ), query_start_loc=query_start_loc_device - query_start_loc_device[0], )