From 0334383f4e65e8fea39443164a7707b2a93fe167 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 10 Nov 2025 08:18:17 -0500 Subject: [PATCH 01/11] working changes Signed-off-by: Thomas Parnell --- examples/kv_hacking.py | 46 +++++++ examples/offline_inference/spans/spans.py | 2 +- examples/test_stream.py | 20 +++ .../attention/ops/triton_unified_attention.py | 130 +++++++++++++++--- .../layers/rotary_embedding/base.py | 14 ++ vllm/v1/attention/backends/triton_attn.py | 6 + vllm/v1/attention/backends/utils.py | 2 + vllm/v1/worker/gpu_model_runner.py | 14 +- 8 files changed, 211 insertions(+), 23 deletions(-) create mode 100644 examples/kv_hacking.py create mode 100644 examples/test_stream.py diff --git a/examples/kv_hacking.py b/examples/kv_hacking.py new file mode 100644 index 0000000000..2c9657b07c --- /dev/null +++ b/examples/kv_hacking.py @@ -0,0 +1,46 @@ +import os + +from vllm import LLM, SamplingParams + +os.environ["VLLM_USE_V1"] = "1" + +llm = LLM( + model="facebook/opt-125m", + gpu_memory_utilization=0.4, + enforce_eager=True, + block_size=16, +) + +doc1 = "The Arsenal Football Club, commonly known as simply Arsenal, is a professional football club based in Islington, North London, England. They compete in the Premier League, the top tier of English football. In domestic football, Arsenal have won 13 league titles (including one unbeaten title), a record 14 FA Cups, two League Cups, 17 FA Community Shields, and a Football League Centenary Trophy. In European football, they have one European Cup Winners' Cup and one Inter-Cities Fairs Cup. In terms of trophies won, it is the third-most successful club in English football.[2]" +doc2 = "Switzerland,[d] officially the Swiss Confederation,[e] is a landlocked country located in west-central Europe.[f][13] It is bordered by Italy to the south, France to the west, Germany to the north, and Austria and Liechtenstein to the east. Switzerland is geographically divided among the Swiss Plateau, the Alps and the Jura; the Alps occupy the greater part of the territory, whereas most of the country's nearly 9 million people are concentrated on the plateau, which hosts its largest cities and economic centres, including Zurich, Geneva, and Lausanne.[14]" + +tokenizer = llm.get_tokenizer() + +N_BLOCKS = 2 + +tok1 = tokenizer(doc1)["input_ids"][: N_BLOCKS * 16] +tok2 = tokenizer(doc2)["input_ids"][: (N_BLOCKS * 16)] + +assert len(tok1) == N_BLOCKS * 16 +assert len(tok2) == N_BLOCKS * 16 + +prompt1 = {"prompt_token_ids": tok1} + +prompt2 = {"prompt_token_ids": tok2} + +prompt12 = {"prompt_token_ids": tok1 + tok2} + +# only do prefill +prefill_params = SamplingParams(temperature=0.0, max_tokens=1) + +print("----------- PREFILL PROMPT1 --------------:") +output = llm.generate(prompt1, prefill_params) +print(output) + +print("----------- PREFILL PROMPT2 --------------:") +output = llm.generate(prompt2, prefill_params) +print(output) + +print("----------- PREFILL PROMPT12 --------------:") +output = llm.generate(prompt12, prefill_params) +print(output) diff --git a/examples/offline_inference/spans/spans.py b/examples/offline_inference/spans/spans.py index ebe42f7ba7..2a2207718f 100644 --- a/examples/offline_inference/spans/spans.py +++ b/examples/offline_inference/spans/spans.py @@ -62,7 +62,7 @@ def main(): # enables block attention # -> when this line is not commented, we expect a speedup # in the execution of the last two .generate calls - os.environ["VLLM_V1_SPANS_ENABLED"] = "True" + os.environ["VLLM_V1_SPANS_ENABLED"] = "False" # the token that tells vLLM "this is the beginning of a span" os.environ["VLLM_V1_SPANS_TOKEN_PLUS"] = str(SPAN_TOK_PLUS) diff --git a/examples/test_stream.py b/examples/test_stream.py new file mode 100644 index 0000000000..62475c6d92 --- /dev/null +++ b/examples/test_stream.py @@ -0,0 +1,20 @@ +import asyncio +import time +from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams + +engine_args = AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True) +model = AsyncLLMEngine.from_engine_args(engine_args) + +def generate_streaming(prompt): + results_generator = model.generate( + prompt, + SamplingParams(temperature=0.0, logprobs=1), + request_id=time.monotonic() + ) + for request_output in results_generator: + text = request_output.outputs[0].text + tokens = request_output.outputs[0].token_ids + logprobs = request_output.outputs[0].logprobs + print(text, tokens, logprobs) + +generate_streaming("hello") diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 565be1c39b..4d3d7dc0ec 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -64,6 +64,7 @@ def kernel_unified_attention_2d( seq_lens_ptr, # [num_seqs] alibi_slopes_ptr, # [num_query_heads] qq_bias_ptr, # [num_query_tokens, num_query_tokens] + cos_sin_cache_ptr, # [max_model_len, head_size] scale, # float32 k_scale, # float32 v_scale, # float32 @@ -94,6 +95,8 @@ def kernel_unified_attention_2d( stride_v_cache_1: tl.int64, # int stride_v_cache_2: tl.int64, # int stride_v_cache_3: tl.constexpr, # int + stride_cs_cache_0: tl.int64, # int + stride_cs_cache_1: tl.constexpr, # int query_start_len_ptr, # [num_seqs+1] BLOCK_Q: tl.constexpr, # int num_seqs: tl.int32, @@ -102,6 +105,8 @@ def kernel_unified_attention_2d( FP8_MIN: tl.constexpr = float8_info.min, FP8_MAX: tl.constexpr = float8_info.max, ): + + q_block_global_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) @@ -128,20 +133,43 @@ def kernel_unified_attention_2d( query_offset_0 = cur_batch_in_all_start_index + query_pos query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv - query_offset = ( + + offs_d_new = tl.arange(0, HEAD_SIZE_PADDED // 2) + + query_offset_even = ( query_offset_0[:, None] * query_stride_0 + query_offset_1[:, None] * query_stride_1 - + offs_d[None, :] + + 2*offs_d_new[None, :] + ) + + query_offset_odd = ( + query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + 2*offs_d_new[None, :] + 1 ) dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) + + dim_mask_even = tl.where(2*offs_d_new < HEAD_SIZE, 1, 0).to(tl.int1) + dim_mask_odd = tl.where((2*offs_d_new+1) < HEAD_SIZE, 1, 0).to(tl.int1) + + dim_mask_a = tl.where(offs_d_new < HEAD_SIZE, 1, 0).to(tl.int1) + dim_mask_b = tl.where((HEAD_SIZE_PADDED // 2 + offs_d_new) < HEAD_SIZE, 1, 0).to(tl.int1) + query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) - # Q : (BLOCK_M, HEAD_SIZE_PADDED) - Q = tl.load( - query_ptr + query_offset, - mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + # Q_1 : (BLOCK_M, HEAD_SIZE_PADDED // 2) + Q_1 = tl.load( + query_ptr + query_offset_even, + mask=dim_mask_even[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + + # Q_2 : (BLOCK_M, HEAD_SIZE_PADDED // 2) + Q_2 = tl.load( + query_ptr + query_offset_odd, + mask=dim_mask_odd[None, :] & query_mask_0[:, None] & query_mask_1[:, None], other=0.0, ) @@ -225,7 +253,6 @@ def kernel_unified_attention_2d( physical_block_idx = tl.load( block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE ).to(tl.int64) - v_offset = ( physical_block_idx[:, None] * stride_v_cache_0 + kv_head_idx * stride_v_cache_2 @@ -233,27 +260,81 @@ def kernel_unified_attention_2d( + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1 ) - k_offset = ( + + k_offset_even = ( physical_block_idx[None, :] * stride_k_cache_0 + kv_head_idx * stride_k_cache_2 - + offs_d[:, None] * stride_k_cache_3 + + 2*offs_d_new[:, None] * stride_k_cache_3 + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 ) - # K : (HEAD_SIZE, TILE_SIZE) - K_load = tl.load( - key_cache_ptr + k_offset, - mask=dim_mask[:, None] & tile_mask[None, :], + k_offset_odd = ( + physical_block_idx[None, :] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + (2*offs_d_new[:, None]+1) * stride_k_cache_3 + + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 + ) + + + # K_even : (HEAD_SIZE, TILE_SIZE) + K_even_load = tl.load( + key_cache_ptr + k_offset_even, + mask=dim_mask_even[:, None] & tile_mask[None, :], other=0.0, ) - if K_load.dtype.is_fp8(): - if Q.dtype.is_fp8(): - K = K_load + if K_even_load.dtype.is_fp8(): + if Q_1.dtype.is_fp8(): + K_even = K_even_load else: - K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + K_even = (K_even_load.to(tl.float32) * tl.load(k_scale)).to(Q_1.dtype) else: - K = K_load + K_even = K_even_load + + + # K_odd : (HEAD_SIZE, TILE_SIZE) + K_odd_load = tl.load( + key_cache_ptr + k_offset_odd, + mask=dim_mask_odd[:, None] & tile_mask[None, :], + other=0.0, + ) + + if K_odd_load.dtype.is_fp8(): + if Q_2.dtype.is_fp8(): + K_odd = K_odd_load + else: + K_odd = (K_odd_load.to(tl.float32) * tl.load(k_scale)).to(Q_2.dtype) + else: + K_odd = K_odd_load + + + cos_cache_offset = ( + seq_offset[None, :] * stride_cs_cache_0 + + offs_d_new[:, None] * stride_cs_cache_1 + ) + + sin_cache_offset = ( + seq_offset[None, :] * stride_cs_cache_0 + + (HEAD_SIZE_PADDED // 2 + offs_d_new[:, None]) * stride_cs_cache_1 + ) + + cos = tl.load( + cos_sin_cache_ptr + cos_cache_offset, + mask=dim_mask_a[:, None] & tile_mask[None, :], + other=0.0 + ).to(K_even.dtype) + + sin = tl.load( + cos_sin_cache_ptr + sin_cache_offset, + mask=dim_mask_b[:, None] & tile_mask[None, :], + other=0.0 + ).to(K_even.dtype) + + K_rot1 = K_even * cos - K_odd * sin + K_rot2 = K_odd * cos + K_even * sin + + #K_rot1 = K_even + #K_rot2 = K_odd # V : (TILE_SIZE, HEAD_SIZE) V_load = tl.load( @@ -274,8 +355,8 @@ def kernel_unified_attention_2d( # S : (BLOCK_M, TILE_SIZE) S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) - - S += scale * tl.dot(Q, K) + S += scale * tl.dot(Q_1, K_rot1) + S += scale * tl.dot(Q_2, K_rot2) if USE_SOFTCAP: S = apply_softcap(S, softcap) @@ -754,7 +835,9 @@ def unified_attention( qq_bias=None, # Optional tensor for sinks sinks=None, + cos_sin_cache=None, ): + assert causal, "Only causal attention is supported" assert q_descale is None, "Q scales not supported" @@ -794,7 +877,9 @@ def unified_attention( TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32 # if batch contains a prefill - if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: + #if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: + if True: + kernel_unified_attention_2d[ ( total_num_q_blocks, @@ -810,6 +895,7 @@ def unified_attention( seq_lens_ptr=seqused_k, alibi_slopes_ptr=alibi_slopes, qq_bias_ptr=qq_bias, + cos_sin_cache_ptr=cos_sin_cache, scale=softmax_scale, k_scale=k_descale, v_scale=v_descale, @@ -840,6 +926,8 @@ def unified_attention( stride_v_cache_1=v.stride(1), stride_v_cache_2=v.stride(2), stride_v_cache_3=v.stride(3), + stride_cs_cache_0=cos_sin_cache.stride(0), + stride_cs_cache_1=cos_sin_cache.stride(1), query_start_len_ptr=cu_seqlens_q, BLOCK_Q=BLOCK_Q, num_seqs=num_seqs, diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 2fc00130da..74989a13fa 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -112,7 +112,14 @@ def forward_native( positions = positions.flatten() num_tokens = positions.shape[0] cos_sin = self.cos_sin_cache.index_select(0, positions) + + print("cos_sin: ", cos_sin.shape) + cos, sin = cos_sin.chunk(2, dim=-1) + + print("cos.shape: ", cos.shape) + print("sin.shape: ", sin.shape) + if invert_rotation_angle: sin = -sin @@ -123,6 +130,7 @@ def forward_native( query_rot = apply_rotary_emb_torch(query_rot, cos, sin, self.is_neox_style) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + ''' # key may be None in some cases, e.g. cross-layer KV sharing if key is not None: key_shape = key.shape @@ -131,6 +139,8 @@ def forward_native( key_pass = key[..., self.rotary_dim :] key_rot = apply_rotary_emb_torch(key_rot, cos, sin, self.is_neox_style) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + ''' + return query, key def forward_cuda( @@ -139,7 +149,11 @@ def forward_cuda( query: torch.Tensor, key: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: + + return self.forward_native(positions, query, key) + if self.use_flashinfer: + assert False torch.ops.vllm.flashinfer_rotary_embedding( positions, query, diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 0590a87bf8..0ef72bc48d 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -59,6 +59,8 @@ class TritonAttentionMetadata: prefix_kv_lens: torch.Tensor | None suffix_kv_lens: torch.Tensor | None + cos_sin_cache: torch.Tensor | None = None + # Optional aot scheduling scheduler_metadata: torch.Tensor | None = None prefix_scheduler_metadata: torch.Tensor | None = None @@ -141,6 +143,7 @@ def build( prefix_kv_lens=prefix_kv_lens, suffix_kv_lens=suffix_kv_lens, prefix_scheduler_metadata=prefix_scheduler_metadata, + cos_sin_cache=common_attn_metadata.cos_sin_cache, ) return attn_metadata @@ -343,6 +346,8 @@ def forward( descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) + cos_sin_cache = attn_metadata.cos_sin_cache + unified_attention( q=query[:num_actual_tokens], k=key_cache, @@ -363,6 +368,7 @@ def forward( v_descale=layer._v_scale.expand(descale_shape), sinks=self.sinks, output_scale=output_scale, + cos_sin_cache=cos_sin_cache, ) return output diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 07dfbc766a..04985c4bfd 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -82,6 +82,8 @@ class CommonAttentionMetadata: block_table_tensor: torch.Tensor slot_mapping: torch.Tensor + cos_sin_cache: torch.Tensor | None = None + causal: bool = True # Needed by FastPrefillAttentionBuilder diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d92b541b14..3a4cb2e241 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1413,6 +1413,17 @@ def _build_attention_metadata( # graph mode. blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1) + if not hasattr(self, 'rotate'): + if not isinstance(self.model.model.layers[0], PPMissingLayer): + self.rotate = self.model.model.layers[ + 0].self_attn.rotary_emb + else: + for lay in self.model.model.layers: + if not isinstance(lay, PPMissingLayer): + self.rotate = lay.self_attn.rotary_emb + break + print(self.rotate) + common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, query_start_loc_cpu=query_start_loc_cpu, @@ -1430,6 +1441,7 @@ def _build_attention_metadata( causal=True, encoder_seq_lens=encoder_seq_lens, dcp_local_seq_lens=dcp_local_seq_lens, + cos_sin_cache=self.rotate.cos_sin_cache ) if self.speculative_config and spec_decode_common_attn_metadata is None: @@ -2639,7 +2651,7 @@ def execute_model( with record_function_or_nullcontext("Preprocess"): # NOTE(tdoublep): should this be inside context below? # handle repositioning requests - self._perform_repositioning(scheduler_output) + #self._perform_repositioning(scheduler_output) with self.synchronize_input_prep(): # Update persistent batch states. From d0d952ffd237404ff626bcfe573baea47f38cfe6 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 10 Nov 2025 08:21:52 -0500 Subject: [PATCH 02/11] remove some prints Signed-off-by: Thomas Parnell --- examples/offline_inference/spans/spans.py | 2 +- vllm/model_executor/layers/rotary_embedding/base.py | 5 ----- vllm/v1/worker/gpu_model_runner.py | 1 - 3 files changed, 1 insertion(+), 7 deletions(-) diff --git a/examples/offline_inference/spans/spans.py b/examples/offline_inference/spans/spans.py index 2a2207718f..908bf05efd 100644 --- a/examples/offline_inference/spans/spans.py +++ b/examples/offline_inference/spans/spans.py @@ -72,7 +72,7 @@ def main(): os.environ["VLLM_V1_SPANS_TOKEN_CROSS"] = str(SPAN_TOK_CROSS) # will print every step of the span process if set to true - os.environ["VLLM_V1_SPANS_DEBUG"] = "True" + os.environ["VLLM_V1_SPANS_DEBUG"] = "False" # will disable the adjustment of positional encodings when a KV cache # block is loaded to a different position than it was stored diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 74989a13fa..334d9d04e3 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -113,13 +113,8 @@ def forward_native( num_tokens = positions.shape[0] cos_sin = self.cos_sin_cache.index_select(0, positions) - print("cos_sin: ", cos_sin.shape) - cos, sin = cos_sin.chunk(2, dim=-1) - print("cos.shape: ", cos.shape) - print("sin.shape: ", sin.shape) - if invert_rotation_angle: sin = -sin diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3a4cb2e241..bb72061aed 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1422,7 +1422,6 @@ def _build_attention_metadata( if not isinstance(lay, PPMissingLayer): self.rotate = lay.self_attn.rotary_emb break - print(self.rotate) common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, From 72261c2608ad4ebfb98cb2c2c2711763aec5e6d1 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 11 Nov 2025 03:53:59 -0500 Subject: [PATCH 03/11] Use neox style in kernel Signed-off-by: Thomas Parnell --- .../attention/ops/triton_unified_attention.py | 115 ++++++++---------- 1 file changed, 52 insertions(+), 63 deletions(-) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 4d3d7dc0ec..45916a5edc 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -64,7 +64,7 @@ def kernel_unified_attention_2d( seq_lens_ptr, # [num_seqs] alibi_slopes_ptr, # [num_query_heads] qq_bias_ptr, # [num_query_tokens, num_query_tokens] - cos_sin_cache_ptr, # [max_model_len, head_size] + cos_sin_cache_ptr, # [max_model_len, head_size] scale, # float32 k_scale, # float32 v_scale, # float32 @@ -95,8 +95,8 @@ def kernel_unified_attention_2d( stride_v_cache_1: tl.int64, # int stride_v_cache_2: tl.int64, # int stride_v_cache_3: tl.constexpr, # int - stride_cs_cache_0: tl.int64, # int - stride_cs_cache_1: tl.constexpr, # int + stride_cs_cache_0: tl.int64, # int + stride_cs_cache_1: tl.constexpr, # int query_start_len_ptr, # [num_seqs+1] BLOCK_Q: tl.constexpr, # int num_seqs: tl.int32, @@ -105,8 +105,6 @@ def kernel_unified_attention_2d( FP8_MIN: tl.constexpr = float8_info.min, FP8_MAX: tl.constexpr = float8_info.max, ): - - q_block_global_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) @@ -136,40 +134,40 @@ def kernel_unified_attention_2d( offs_d_new = tl.arange(0, HEAD_SIZE_PADDED // 2) - query_offset_even = ( + query_offset_a = ( query_offset_0[:, None] * query_stride_0 + query_offset_1[:, None] * query_stride_1 - + 2*offs_d_new[None, :] + + offs_d_new[None, :] ) - query_offset_odd = ( + query_offset_b = ( query_offset_0[:, None] * query_stride_0 + query_offset_1[:, None] * query_stride_1 - + 2*offs_d_new[None, :] + 1 + + offs_d_new[None, :] + + HEAD_SIZE_PADDED // 2 ) dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) - dim_mask_even = tl.where(2*offs_d_new < HEAD_SIZE, 1, 0).to(tl.int1) - dim_mask_odd = tl.where((2*offs_d_new+1) < HEAD_SIZE, 1, 0).to(tl.int1) - dim_mask_a = tl.where(offs_d_new < HEAD_SIZE, 1, 0).to(tl.int1) - dim_mask_b = tl.where((HEAD_SIZE_PADDED // 2 + offs_d_new) < HEAD_SIZE, 1, 0).to(tl.int1) + dim_mask_b = tl.where((HEAD_SIZE_PADDED // 2 + offs_d_new) < HEAD_SIZE, 1, 0).to( + tl.int1 + ) query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) - # Q_1 : (BLOCK_M, HEAD_SIZE_PADDED // 2) - Q_1 = tl.load( - query_ptr + query_offset_even, - mask=dim_mask_even[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + # Q_a : (BLOCK_M, HEAD_SIZE_PADDED // 2) + Q_a = tl.load( + query_ptr + query_offset_a, + mask=dim_mask_a[None, :] & query_mask_0[:, None] & query_mask_1[:, None], other=0.0, ) - # Q_2 : (BLOCK_M, HEAD_SIZE_PADDED // 2) - Q_2 = tl.load( - query_ptr + query_offset_odd, - mask=dim_mask_odd[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + # Q_b : (BLOCK_M, HEAD_SIZE_PADDED // 2) + Q_b = tl.load( + query_ptr + query_offset_b, + mask=dim_mask_b[None, :] & query_mask_0[:, None] & query_mask_1[:, None], other=0.0, ) @@ -260,53 +258,49 @@ def kernel_unified_attention_2d( + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1 ) - - k_offset_even = ( + k_offset_a = ( physical_block_idx[None, :] * stride_k_cache_0 + kv_head_idx * stride_k_cache_2 - + 2*offs_d_new[:, None] * stride_k_cache_3 + + offs_d_new[:, None] * stride_k_cache_3 + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 ) - k_offset_odd = ( + k_offset_b = ( physical_block_idx[None, :] * stride_k_cache_0 + kv_head_idx * stride_k_cache_2 - + (2*offs_d_new[:, None]+1) * stride_k_cache_3 + + (HEAD_SIZE_PADDED // 2 + offs_d_new[:, None]) * stride_k_cache_3 + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 ) - - # K_even : (HEAD_SIZE, TILE_SIZE) - K_even_load = tl.load( - key_cache_ptr + k_offset_even, - mask=dim_mask_even[:, None] & tile_mask[None, :], + # K_a : (HEAD_SIZE, TILE_SIZE) + K_a_load = tl.load( + key_cache_ptr + k_offset_a, + mask=dim_mask_a[:, None] & tile_mask[None, :], other=0.0, ) - if K_even_load.dtype.is_fp8(): - if Q_1.dtype.is_fp8(): - K_even = K_even_load + if K_a_load.dtype.is_fp8(): + if Q_a.dtype.is_fp8(): + K_a = K_a_load else: - K_even = (K_even_load.to(tl.float32) * tl.load(k_scale)).to(Q_1.dtype) + K_a = (K_a_load.to(tl.float32) * tl.load(k_scale)).to(Q_a.dtype) else: - K_even = K_even_load + K_a = K_a_load - - # K_odd : (HEAD_SIZE, TILE_SIZE) - K_odd_load = tl.load( - key_cache_ptr + k_offset_odd, - mask=dim_mask_odd[:, None] & tile_mask[None, :], + # K_b : (HEAD_SIZE, TILE_SIZE) + K_b_load = tl.load( + key_cache_ptr + k_offset_b, + mask=dim_mask_b[:, None] & tile_mask[None, :], other=0.0, ) - if K_odd_load.dtype.is_fp8(): - if Q_2.dtype.is_fp8(): - K_odd = K_odd_load + if K_b_load.dtype.is_fp8(): + if Q_b.dtype.is_fp8(): + K_b = K_b_load else: - K_odd = (K_odd_load.to(tl.float32) * tl.load(k_scale)).to(Q_2.dtype) + K_b = (K_b_load.to(tl.float32) * tl.load(k_scale)).to(Q_b.dtype) else: - K_odd = K_odd_load - + K_b = K_b_load cos_cache_offset = ( seq_offset[None, :] * stride_cs_cache_0 @@ -321,20 +315,17 @@ def kernel_unified_attention_2d( cos = tl.load( cos_sin_cache_ptr + cos_cache_offset, mask=dim_mask_a[:, None] & tile_mask[None, :], - other=0.0 - ).to(K_even.dtype) + other=0.0, + ).to(K_a.dtype) sin = tl.load( cos_sin_cache_ptr + sin_cache_offset, mask=dim_mask_b[:, None] & tile_mask[None, :], - other=0.0 - ).to(K_even.dtype) - - K_rot1 = K_even * cos - K_odd * sin - K_rot2 = K_odd * cos + K_even * sin + other=0.0, + ).to(K_b.dtype) - #K_rot1 = K_even - #K_rot2 = K_odd + K_rot_a = K_a * cos - K_b * sin + K_rot_b = K_b * cos + K_a * sin # V : (TILE_SIZE, HEAD_SIZE) V_load = tl.load( @@ -344,10 +335,10 @@ def kernel_unified_attention_2d( ) if V_load.dtype.is_fp8(): - if Q.dtype.is_fp8(): + if Q_a.dtype.is_fp8(): V = V_load else: - V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q_a.dtype) else: V = V_load @@ -355,8 +346,8 @@ def kernel_unified_attention_2d( # S : (BLOCK_M, TILE_SIZE) S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) - S += scale * tl.dot(Q_1, K_rot1) - S += scale * tl.dot(Q_2, K_rot2) + S += scale * tl.dot(Q_a, K_rot_a) + S += scale * tl.dot(Q_b, K_rot_b) if USE_SOFTCAP: S = apply_softcap(S, softcap) @@ -837,7 +828,6 @@ def unified_attention( sinks=None, cos_sin_cache=None, ): - assert causal, "Only causal attention is supported" assert q_descale is None, "Q scales not supported" @@ -877,9 +867,8 @@ def unified_attention( TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32 # if batch contains a prefill - #if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: + # if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: if True: - kernel_unified_attention_2d[ ( total_num_q_blocks, From 3c34a5b9531ce00bb4b14fb1eb6a664ef6be6fe4 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 11 Nov 2025 04:38:55 -0500 Subject: [PATCH 04/11] remove file Signed-off-by: Thomas Parnell --- examples/kv_hacking.py | 46 ------------------------------------------ 1 file changed, 46 deletions(-) delete mode 100644 examples/kv_hacking.py diff --git a/examples/kv_hacking.py b/examples/kv_hacking.py deleted file mode 100644 index 2c9657b07c..0000000000 --- a/examples/kv_hacking.py +++ /dev/null @@ -1,46 +0,0 @@ -import os - -from vllm import LLM, SamplingParams - -os.environ["VLLM_USE_V1"] = "1" - -llm = LLM( - model="facebook/opt-125m", - gpu_memory_utilization=0.4, - enforce_eager=True, - block_size=16, -) - -doc1 = "The Arsenal Football Club, commonly known as simply Arsenal, is a professional football club based in Islington, North London, England. They compete in the Premier League, the top tier of English football. In domestic football, Arsenal have won 13 league titles (including one unbeaten title), a record 14 FA Cups, two League Cups, 17 FA Community Shields, and a Football League Centenary Trophy. In European football, they have one European Cup Winners' Cup and one Inter-Cities Fairs Cup. In terms of trophies won, it is the third-most successful club in English football.[2]" -doc2 = "Switzerland,[d] officially the Swiss Confederation,[e] is a landlocked country located in west-central Europe.[f][13] It is bordered by Italy to the south, France to the west, Germany to the north, and Austria and Liechtenstein to the east. Switzerland is geographically divided among the Swiss Plateau, the Alps and the Jura; the Alps occupy the greater part of the territory, whereas most of the country's nearly 9 million people are concentrated on the plateau, which hosts its largest cities and economic centres, including Zurich, Geneva, and Lausanne.[14]" - -tokenizer = llm.get_tokenizer() - -N_BLOCKS = 2 - -tok1 = tokenizer(doc1)["input_ids"][: N_BLOCKS * 16] -tok2 = tokenizer(doc2)["input_ids"][: (N_BLOCKS * 16)] - -assert len(tok1) == N_BLOCKS * 16 -assert len(tok2) == N_BLOCKS * 16 - -prompt1 = {"prompt_token_ids": tok1} - -prompt2 = {"prompt_token_ids": tok2} - -prompt12 = {"prompt_token_ids": tok1 + tok2} - -# only do prefill -prefill_params = SamplingParams(temperature=0.0, max_tokens=1) - -print("----------- PREFILL PROMPT1 --------------:") -output = llm.generate(prompt1, prefill_params) -print(output) - -print("----------- PREFILL PROMPT2 --------------:") -output = llm.generate(prompt2, prefill_params) -print(output) - -print("----------- PREFILL PROMPT12 --------------:") -output = llm.generate(prompt12, prefill_params) -print(output) From 99fa372517a83a9e85055d9074e2c3da341670ba Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 11 Nov 2025 04:40:17 -0500 Subject: [PATCH 05/11] remove file Signed-off-by: Thomas Parnell --- examples/test_stream.py | 20 -------------------- 1 file changed, 20 deletions(-) delete mode 100644 examples/test_stream.py diff --git a/examples/test_stream.py b/examples/test_stream.py deleted file mode 100644 index 62475c6d92..0000000000 --- a/examples/test_stream.py +++ /dev/null @@ -1,20 +0,0 @@ -import asyncio -import time -from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams - -engine_args = AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True) -model = AsyncLLMEngine.from_engine_args(engine_args) - -def generate_streaming(prompt): - results_generator = model.generate( - prompt, - SamplingParams(temperature=0.0, logprobs=1), - request_id=time.monotonic() - ) - for request_output in results_generator: - text = request_output.outputs[0].text - tokens = request_output.outputs[0].token_ids - logprobs = request_output.outputs[0].logprobs - print(text, tokens, logprobs) - -generate_streaming("hello") From 23d30ccdd29155d6841b0154fbcc95e0b96b9c27 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 11 Nov 2025 05:35:24 -0500 Subject: [PATCH 06/11] adapt forward_cuda Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/rotary_embedding/base.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 334d9d04e3..c9f369bbd2 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -125,7 +125,7 @@ def forward_native( query_rot = apply_rotary_emb_torch(query_rot, cos, sin, self.is_neox_style) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) - ''' + """ # key may be None in some cases, e.g. cross-layer KV sharing if key is not None: key_shape = key.shape @@ -134,7 +134,7 @@ def forward_native( key_pass = key[..., self.rotary_dim :] key_rot = apply_rotary_emb_torch(key_rot, cos, sin, self.is_neox_style) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) - ''' + """ return query, key @@ -144,11 +144,7 @@ def forward_cuda( query: torch.Tensor, key: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - - return self.forward_native(positions, query, key) - if self.use_flashinfer: - assert False torch.ops.vllm.flashinfer_rotary_embedding( positions, query, @@ -168,7 +164,7 @@ def forward_cuda( ops.rotary_embedding( positions, query, - key, + None, self.head_size, self.cos_sin_cache, self.is_neox_style, From 28f631113ceb34facebe653844297b500cc4dfea Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 11 Nov 2025 10:22:02 -0500 Subject: [PATCH 07/11] Fuse rope into 3D kernel Signed-off-by: Thomas Parnell --- .../attention/ops/triton_unified_attention.py | 126 ++++++++++++++---- vllm/v1/core/block_pool.py | 68 +--------- vllm/v1/core/kv_cache_manager.py | 88 ++---------- vllm/v1/core/sched/output.py | 5 - vllm/v1/core/sched/scheduler.py | 18 +-- 5 files changed, 112 insertions(+), 193 deletions(-) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 45916a5edc..edd832e220 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -272,7 +272,7 @@ def kernel_unified_attention_2d( + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 ) - # K_a : (HEAD_SIZE, TILE_SIZE) + # K_a : (HEAD_SIZE_PADDED // 2, TILE_SIZE) K_a_load = tl.load( key_cache_ptr + k_offset_a, mask=dim_mask_a[:, None] & tile_mask[None, :], @@ -287,7 +287,7 @@ def kernel_unified_attention_2d( else: K_a = K_a_load - # K_b : (HEAD_SIZE, TILE_SIZE) + # K_b : (HEAD_SIZE_PADDED // 2, TILE_SIZE) K_b_load = tl.load( key_cache_ptr + k_offset_b, mask=dim_mask_b[:, None] & tile_mask[None, :], @@ -438,6 +438,7 @@ def kernel_unified_attention_3d( seq_lens_ptr, # [num_seqs] alibi_slopes_ptr, # [num_query_heads] qq_bias_ptr, # [num_query_tokens, num_query_tokens] + cos_sin_cache_ptr, # [max_model_len, head_size] scale, # float32 k_scale, # float32 v_scale, # float32 @@ -465,6 +466,8 @@ def kernel_unified_attention_3d( stride_v_cache_1: tl.int64, # int stride_v_cache_2: tl.int64, # int stride_v_cache_3: tl.constexpr, # int + stride_cs_cache_0: tl.int64, # int + stride_cs_cache_1: tl.constexpr, # int query_start_len_ptr, # [num_seqs+1] BLOCK_Q: tl.constexpr, # int num_seqs: tl.int32, @@ -508,20 +511,43 @@ def kernel_unified_attention_3d( query_offset_0 = cur_batch_in_all_start_index + query_pos query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv - query_offset = ( + + offs_d_new = tl.arange(0, HEAD_SIZE_PADDED // 2) + + query_offset_a = ( query_offset_0[:, None] * query_stride_0 + query_offset_1[:, None] * query_stride_1 - + offs_d[None, :] + + offs_d_new[None, :] + ) + + query_offset_b = ( + query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + offs_d_new[None, :] + + HEAD_SIZE_PADDED // 2 ) dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) + + dim_mask_a = tl.where(offs_d_new < HEAD_SIZE, 1, 0).to(tl.int1) + dim_mask_b = tl.where((HEAD_SIZE_PADDED // 2 + offs_d_new) < HEAD_SIZE, 1, 0).to( + tl.int1 + ) + query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) - # Q : (BLOCK_M, HEAD_SIZE_PADDED) - Q = tl.load( - query_ptr + query_offset, - mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + # Q_a : (BLOCK_M, HEAD_SIZE_PADDED // 2) + Q_a = tl.load( + query_ptr + query_offset_a, + mask=dim_mask_a[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + + # Q_b : (BLOCK_M, HEAD_SIZE_PADDED // 2) + Q_b = tl.load( + query_ptr + query_offset_b, + mask=dim_mask_b[None, :] & query_mask_0[:, None] & query_mask_1[:, None], other=0.0, ) @@ -594,29 +620,76 @@ def kernel_unified_attention_3d( + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1 ) - k_offset = ( + k_offset_a = ( + physical_block_idx[None, :] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d_new[:, None] * stride_k_cache_3 + + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 + ) + + k_offset_b = ( physical_block_idx[None, :] * stride_k_cache_0 + kv_head_idx * stride_k_cache_2 - + offs_d[:, None] * stride_k_cache_3 + + (HEAD_SIZE_PADDED // 2 + offs_d_new[:, None]) * stride_k_cache_3 + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 ) - # K : (HEAD_SIZE, TILE_SIZE) - K_load = tl.load( - key_cache_ptr + k_offset, - mask=dim_mask[:, None] & tile_mask[None, :], + # K_a : (HEAD_SIZE_PADDED // 2, TILE_SIZE) + K_a_load = tl.load( + key_cache_ptr + k_offset_a, + mask=dim_mask_a[:, None] & tile_mask[None, :], + other=0.0, + ) + + if K_a_load.dtype.is_fp8(): + if Q_a.dtype.is_fp8(): + K_a = K_a_load + else: + K_a = (K_a_load.to(tl.float32) * tl.load(k_scale)).to(Q_a.dtype) + else: + K_a = K_a_load + + # K_b : (HEAD_SIZE_PADDED // 2, TILE_SIZE) + K_b_load = tl.load( + key_cache_ptr + k_offset_b, + mask=dim_mask_b[:, None] & tile_mask[None, :], other=0.0, ) - if K_load.dtype.is_fp8(): - if Q.dtype.is_fp8(): - K = K_load + if K_b_load.dtype.is_fp8(): + if Q_b.dtype.is_fp8(): + K_b = K_b_load else: - K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + K_b = (K_b_load.to(tl.float32) * tl.load(k_scale)).to(Q_b.dtype) else: - K = K_load + K_b = K_b_load - # V : (TILE_SIZE, HEAD_SIZE) + cos_cache_offset = ( + seq_offset[None, :] * stride_cs_cache_0 + + offs_d_new[:, None] * stride_cs_cache_1 + ) + + sin_cache_offset = ( + seq_offset[None, :] * stride_cs_cache_0 + + (HEAD_SIZE_PADDED // 2 + offs_d_new[:, None]) * stride_cs_cache_1 + ) + + cos = tl.load( + cos_sin_cache_ptr + cos_cache_offset, + mask=dim_mask_a[:, None] & tile_mask[None, :], + other=0.0, + ).to(K_a.dtype) + + sin = tl.load( + cos_sin_cache_ptr + sin_cache_offset, + mask=dim_mask_b[:, None] & tile_mask[None, :], + other=0.0, + ).to(K_b.dtype) + + K_rot_a = K_a * cos - K_b * sin + K_rot_b = K_b * cos + K_a * sin + + # V : (TILE_SIZE, HEAD_SIZE_PADDED) V_load = tl.load( value_cache_ptr + v_offset, mask=dim_mask[None, :] & tile_mask[:, None], @@ -624,10 +697,10 @@ def kernel_unified_attention_3d( ) if V_load.dtype.is_fp8(): - if Q.dtype.is_fp8(): + if Q_a.dtype.is_fp8(): V = V_load else: - V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q_a.dtype) else: V = V_load @@ -635,7 +708,8 @@ def kernel_unified_attention_3d( # S : (BLOCK_M, TILE_SIZE) S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) - S += scale * tl.dot(Q, K) + S += scale * tl.dot(Q_a, K_rot_a) + S += scale * tl.dot(Q_b, K_rot_b) if USE_SOFTCAP: S = apply_softcap(S, softcap) @@ -867,8 +941,7 @@ def unified_attention( TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32 # if batch contains a prefill - # if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: - if True: + if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: kernel_unified_attention_2d[ ( total_num_q_blocks, @@ -963,6 +1036,7 @@ def unified_attention( seq_lens_ptr=seqused_k, alibi_slopes_ptr=alibi_slopes, qq_bias_ptr=qq_bias, + cos_sin_cache_ptr=cos_sin_cache, scale=softmax_scale, k_scale=k_descale, v_scale=v_descale, @@ -990,6 +1064,8 @@ def unified_attention( stride_v_cache_1=v.stride(1), stride_v_cache_2=v.stride(2), stride_v_cache_3=v.stride(3), + stride_cs_cache_0=cos_sin_cache.stride(0), + stride_cs_cache_1=cos_sin_cache.stride(1), query_start_len_ptr=cu_seqlens_q, BLOCK_Q=BLOCK_Q, num_seqs=num_seqs, diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 62be679309..55710ad5cc 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -3,7 +3,6 @@ from collections.abc import Iterable, Sequence from typing import Any -import vllm.envs as envs from vllm.distributed.kv_events import ( MEDIUM_GPU, AllBlocksCleared, @@ -242,8 +241,6 @@ def cache_full_blocks( if new_hashes is not None: new_hashes.append(maybe_convert_block_hash(block_hash)) - self._set_block_positions(new_full_blocks, blocks, request) - if self.enable_kv_cache_events: if num_cached_blocks == 0: parent_block_hash: ExternalBlockHash | None = None @@ -269,58 +266,6 @@ def cache_full_blocks( ) ) - def _set_block_positions( - self, - new_full_blocks: list[KVCacheBlock], - blocks: list[KVCacheBlock], - request: Request, - ): - """Sets the positions of new full blocks in the KV cache. - - This function assigns positions to newly filled blocks based - on their order within the provided block list. The position - corresponds to the location embedded in K vectors (if using RoPE) - in the KV cache and is critical for maintaining correct alignment, - especially when prompt positions differ between requests. - - Args: - new_full_blocks: List of KVCacheBlock objects that have been newly - filled and require position assignment. - blocks: List of all blocks associated with the current request, - used to determine the order in which positions are assigned. - request: The Request object containing token information for - debugging purposes. - - Note: - When VLLM_V1_SPANS_DEBUG is enabled, this function includes - debug logging that prints each block's tokens, to help - debug span-related workflows. - """ - pos = 0 - for blk in blocks: - if blk in new_full_blocks: - blk.position = pos - if envs.VLLM_V1_SPANS_DEBUG: - # this prints the tokens assigned to a new block - # in the KV cache - blk_tks = request.all_token_ids[pos : pos + 16] - assert blk.block_hash is not None - bhash = str(blk.block_hash)[:4] if blk.block_hash else None - print( - "[SPANS -> block_pool] assigning to pos", - pos, - "with hash", - bhash, - "block: ", - blk_tks, - ) - pos += 16 - if envs.VLLM_V1_SPANS_DEBUG: - print( - "[SPANS -> block_pool] assigned block count now ->", - len([b for b in self.blocks if b._block_hash]), - ) - def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: """Get new blocks from the free block pool. @@ -413,19 +358,8 @@ def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: blocks_list = list(ordered_blocks) for block in blocks_list: block.ref_cnt -= 1 - # remove duplicates (blocks can now appear twice) - block_ids = set() - blocks_list_filtered = [] - for block in blocks_list: - if block.block_id not in block_ids: - blocks_list_filtered.append(block) - block_ids.add(block.block_id) self.free_block_queue.append_n( - [ - block - for block in blocks_list_filtered - if block.ref_cnt == 0 and not block.is_null - ] + [block for block in blocks_list if block.ref_cnt == 0 and not block.is_null] ) def reset_prefix_cache(self) -> bool: diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 67eafe5d82..63a1ff06e4 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -6,7 +6,6 @@ from dataclasses import dataclass from typing import Literal, overload -import vllm.envs as envs from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator @@ -18,13 +17,6 @@ logger = init_logger(__name__) -@dataclass -class BlockRepositionRequest: - block_id: int - kvc_pos: int - prompt_pos: int - - @dataclass class KVCacheBlocks: """ @@ -34,7 +26,6 @@ class KVCacheBlocks: """ blocks: tuple[Sequence[KVCacheBlock], ...] - blocks_to_reposition: list[BlockRepositionRequest] """ `blocks[i][j]` refers to the i-th kv_cache_group and the j-th block of tokens.We don't use block of @@ -55,8 +46,7 @@ def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks": tuple( list(itertools.chain(blk1, blk2)) for blk1, blk2 in zip(self.blocks, other.blocks) - ), - self.blocks_to_reposition + other.blocks_to_reposition, + ) ) @overload @@ -97,7 +87,7 @@ def new_empty(self) -> "KVCacheBlocks": """ Creates a new KVCacheBlocks instance with no blocks. """ - return KVCacheBlocks(tuple(() for _ in range(len(self.blocks))), []) + return KVCacheBlocks(tuple(() for _ in range(len(self.blocks)))) class KVCacheManager: @@ -159,7 +149,7 @@ def __init__( # # We use nested tuples to ensure the empty KVCacheBlocks is immutable. self.empty_kv_cache_blocks = KVCacheBlocks( - tuple(() for _ in range(self.num_kv_cache_groups)), [] + tuple(() for _ in range(self.num_kv_cache_groups)) ) @property @@ -215,58 +205,6 @@ def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: request.block_hashes, max_cache_hit_length ) ) - if envs.VLLM_V1_SPANS_DEBUG: - print( - "[SPANS -> kv_cache_manager] here's the blocks hashed in this request:", - [str(b)[:4] for b in request.block_hashes], - ) - kvcache_contents = [ - str(b.block_hash)[:4] if b.block_hash else None - for b in self.block_pool.blocks - if b._block_hash - ] - if len(kvcache_contents) > 32: - kvcache_contents = kvcache_contents[:32] + [ - "... (too long to print it all)" - ] - print( - "[SPANS -> kv_cache_manager] here's the contents of the kv cache:", - kvcache_contents, - ) - print( - "[SPANS -> kv_cache_manager] here's the number of blocks " - "that hit the cache:", - [ - str(b.block_hash)[:4] if b.block_hash else None - for b in computed_blocks[0] - ], - ) - - blocks_to_reposition = [] - if envs.VLLM_V1_SPANS_ENABLED: - # Spans does yet not support hybrid models - assert len(computed_blocks) == 1 - for i, b in enumerate(computed_blocks[0]): - prompt_pos = i * 16 - kvc_pos = b.position - if envs.VLLM_V1_SPANS_DEBUG: - print( - f"[SPANS -> kv_cache_manager] checking block " - f"{b.block_id} with prompot pos {prompt_pos} " - f"and kv pos {kvc_pos}" - ) - assert isinstance(kvc_pos, int) - if kvc_pos != prompt_pos: - if envs.VLLM_V1_SPANS_DEBUG: - print( - f"[SPANS -> kv_cache_manager] from pos: {kvc_pos} " - f"to prompt pos: {prompt_pos} repositioning needed" - ) - - blocks_to_reposition.append( - BlockRepositionRequest(b.block_id, kvc_pos, prompt_pos) - ) - b.position = int(prompt_pos) if self.log_stats: assert self.prefix_cache_stats is not None @@ -276,9 +214,7 @@ def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: preempted=request.num_preemptions > 0, ) - return self.create_kv_cache_blocks( - computed_blocks, blocks_to_reposition - ), num_new_computed_tokens + return self.create_kv_cache_blocks(computed_blocks), num_new_computed_tokens def allocate_slots( self, @@ -384,7 +320,7 @@ def allocate_slots( # P/D: delay caching blocks if we have to recv from # remote. Update state for locally cached blocks. if not self.enable_caching or delay_cache_blocks: - return self.create_kv_cache_blocks(new_blocks, []) + return self.create_kv_cache_blocks(new_blocks) # NOTE(woosuk): We want to commit (cache) up to num_computed_tokens + # num_new_tokens, but must exclude "non-committable" tokens (e.g., @@ -395,7 +331,7 @@ def allocate_slots( ) self.coordinator.cache_blocks(request, num_tokens_to_cache) - return self.create_kv_cache_blocks(new_blocks, []) + return self.create_kv_cache_blocks(new_blocks) def free(self, request: Request) -> None: """Free the blocks allocated for the request. @@ -467,7 +403,7 @@ def take_events(self) -> list[KVCacheEvent]: def get_blocks(self, request_id: str) -> KVCacheBlocks: """Get the blocks of a request.""" - return self.create_kv_cache_blocks(self.coordinator.get_blocks(request_id), []) + return self.create_kv_cache_blocks(self.coordinator.get_blocks(request_id)) def get_block_ids(self, request_id: str) -> tuple[list[int], ...]: """Get the block ids of a request.""" @@ -479,13 +415,7 @@ def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: self.coordinator.cache_blocks(request, num_computed_tokens) def create_kv_cache_blocks( - self, - blocks: tuple[list[KVCacheBlock], ...], - blocks_to_reposition: list[BlockRepositionRequest], + self, blocks: tuple[list[KVCacheBlock], ...] ) -> KVCacheBlocks: # Only create new KVCacheBlocks for non-empty blocks - return ( - KVCacheBlocks(blocks, blocks_to_reposition) - if any(blocks) - else self.empty_kv_cache_blocks - ) + return KVCacheBlocks(blocks) if any(blocks) else self.empty_kv_cache_blocks diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index dff7ae389a..866136648b 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -19,7 +19,6 @@ from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams - from vllm.v1.core.kv_cache_manager import BlockRepositionRequest from vllm.v1.request import Request else: KVConnectorMetadata = object @@ -28,7 +27,6 @@ PoolingParams = object SamplingParams = object Request = object - BlockRepositionRequest = object @bc_linter_include @@ -183,9 +181,6 @@ class SchedulerOutput: # freed from the encoder cache. free_encoder_mm_hashes: list[str] - # for KV cache repositioning (as part of Block-Attention implementation) - blocks_to_reposition: list[BlockRepositionRequest] - # Whether the scheduled requests have all the output tokens they # need to perform grammar bitmask computation. pending_structured_output_tokens: bool = False diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index a03c47e1c6..c17b19b58c 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -6,7 +6,6 @@ from collections.abc import Iterable from typing import Any -import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory @@ -22,11 +21,7 @@ EncoderCacheManager, compute_encoder_budget, ) -from vllm.v1.core.kv_cache_manager import ( - BlockRepositionRequest, - KVCacheBlocks, - KVCacheManager, -) +from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import ( CachedRequestData, @@ -359,7 +354,6 @@ def schedule(self) -> SchedulerOutput: skipped_waiting_requests = create_request_queue(self.policy) # Next, schedule the WAITING requests. - blocks_to_reposition: list[BlockRepositionRequest] = [] if not preempted_reqs: while self.waiting and token_budget > 0: if len(self.running) == self.max_num_running_reqs: @@ -417,15 +411,6 @@ def schedule(self) -> SchedulerOutput: self.kv_cache_manager.get_computed_blocks(request) ) - # handle repositioning requests - if ( - envs.VLLM_V1_SPANS_ENABLED - and len(new_computed_blocks.blocks_to_reposition) > 0 - ): - blocks_to_reposition.extend( - new_computed_blocks.blocks_to_reposition - ) - # Get externally-cached tokens if using a KVConnector. if self.connector is not None: ext_tokens, load_kv_async = ( @@ -655,7 +640,6 @@ def schedule(self) -> SchedulerOutput: # the previous and the current steps. finished_req_ids=self.finished_req_ids, free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), - blocks_to_reposition=blocks_to_reposition, ) # NOTE(Kuntai): this function is designed for multiple purposes: From f71ef8baf445c54854d728a453f3400665df12eb Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 11 Nov 2025 10:25:20 -0500 Subject: [PATCH 08/11] Remove block repos Signed-off-by: Thomas Parnell --- vllm/envs.py | 6 - .../layers/rotary_embedding/base.py | 5 - .../layers/rotary_embedding/mrope.py | 1 - vllm/v1/core/kv_cache_utils.py | 2 - vllm/v1/worker/gpu_model_runner.py | 121 +----------------- 5 files changed, 3 insertions(+), 132 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 67b508cb2d..97c5b77b9a 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -228,7 +228,6 @@ VLLM_V1_SPANS_DEBUG: bool = False VLLM_V1_SPANS_TOKEN_PLUS: int = -1 VLLM_V1_SPANS_TOKEN_CROSS: int = -1 - VLLM_V1_SPANS_DISABLE_REPOSITION: bool = False def get_default_cache_root(): @@ -1504,11 +1503,6 @@ def get_vllm_port() -> int | None: "VLLM_V1_SPANS_TOKEN_CROSS": lambda: int( os.environ.get("VLLM_V1_SPANS_TOKEN_CROSS", "-1") ), - # for block-attention, detected spans will be loaded but not repositioned - "VLLM_V1_SPANS_DISABLE_REPOSITION": lambda: os.environ.get( - "VLLM_V1_SPANS_DISABLE_REPOSITION", "False" - ) - == "True", } # --8<-- [end:env-vars-definition] diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index c9f369bbd2..74784808cc 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -106,18 +106,13 @@ def forward_native( positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor | None = None, - invert_rotation_angle: bool = False, # <- to unrope kv's ) -> tuple[torch.Tensor, torch.Tensor | None]: """A PyTorch-native implementation of forward().""" positions = positions.flatten() num_tokens = positions.shape[0] cos_sin = self.cos_sin_cache.index_select(0, positions) - cos, sin = cos_sin.chunk(2, dim=-1) - if invert_rotation_angle: - sin = -sin - query_shape = query.shape query = query.view(num_tokens, -1, self.head_size) query_rot = query[..., : self.rotary_dim] diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index 34280c2d37..0592aa8f96 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -265,7 +265,6 @@ def forward_native( query: torch.Tensor, key: torch.Tensor | None = None, offsets: torch.Tensor | None = None, - invert_rotation_angle: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """PyTorch-native implementation equivalent to forward(). diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index d6e2250f59..e2089cbf5f 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -111,8 +111,6 @@ class KVCacheBlock: block_id: int # Reference count. ref_cnt: int = 0 - # Position (corresponds to positional encodings position) - position: int | None = None # The hash key (block hash + group id) of the block, only available # when the block is full and cached. _block_hash: BlockHashWithGroupId | None = None diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bb72061aed..1690db9f34 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1413,10 +1413,9 @@ def _build_attention_metadata( # graph mode. blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1) - if not hasattr(self, 'rotate'): + if not hasattr(self, "rotate"): if not isinstance(self.model.model.layers[0], PPMissingLayer): - self.rotate = self.model.model.layers[ - 0].self_attn.rotary_emb + self.rotate = self.model.model.layers[0].self_attn.rotary_emb else: for lay in self.model.model.layers: if not isinstance(lay, PPMissingLayer): @@ -1440,7 +1439,7 @@ def _build_attention_metadata( causal=True, encoder_seq_lens=encoder_seq_lens, dcp_local_seq_lens=dcp_local_seq_lens, - cos_sin_cache=self.rotate.cos_sin_cache + cos_sin_cache=self.rotate.cos_sin_cache, ) if self.speculative_config and spec_decode_common_attn_metadata is None: @@ -2192,116 +2191,6 @@ def _pool( pooler_output=pooler_output, ) - def _perform_repositioning(self, scheduler_output: "SchedulerOutput") -> None: - """ - Repositions KV cache blocks based on the scheduler's instructions. - - This method handles the repositioning of attention block - vectors in the KV cache when their positions in the KV cache - and in the prompt differ. It applies rotary embedding - transformations to adjust the positions. - - Args: - scheduler_output: The output from the scheduler containing blocks - to reposition. - """ - blocks_to_reposition = scheduler_output.blocks_to_reposition - if envs.VLLM_V1_SPANS_DEBUG: - ts_repo = time.time() - repo_count = len(blocks_to_reposition) - if len(blocks_to_reposition) > 0: - bs = 512 - for i in range(0, len(blocks_to_reposition), bs): - repo_batch = blocks_to_reposition[i : i + bs] - self._repositionings_handler(repo_batch) - if envs.VLLM_V1_SPANS_DEBUG and repo_count > 0: - torch.cuda.synchronize() - t_repo = time.time() - ts_repo - print( - f"[SPANS -> gpu_model_runner] repositioning" - f" speed: {repo_count / t_repo:.2f} (blocks/s)" - f" (total {repo_count})" - ) - - @torch.inference_mode() - def _repositionings_handler(self, blocks_to_reposition): - num_repos = len(blocks_to_reposition) - if envs.VLLM_V1_SPANS_DEBUG and num_repos > 0: - print(f"[SPANS -> gpu_model_runner] reposition block count: {num_repos}") - if not envs.VLLM_V1_SPANS_DISABLE_REPOSITION: - kvc_positions = torch.tensor( - [d.kvc_pos for d in blocks_to_reposition], - dtype=torch.long, - device=self.kv_caches[0].device, - ).unsqueeze(-1) - prt_positions = torch.tensor( - [d.prompt_pos for d in blocks_to_reposition], - dtype=torch.long, - device=self.kv_caches[0].device, - ).unsqueeze(-1) - block_ids = torch.tensor( - [d.block_id for d in blocks_to_reposition], - dtype=torch.long, - device=self.kv_caches[0].device, - ) - - # (self.kv_caches shape): - # [nlay, kv, maxblocks, blocksize, headcount, headsize] - concerned_vectors = [ - x[0, block_ids, :, :, :] for x in self.kv_caches - ] # -> [nlay, blockids, blocksize, headcount, headsize] - bids, bsize, hcount, hsize = concerned_vectors[0].shape - - template_tensor = torch.arange( - bsize, dtype=torch.long, device=self.kv_caches[0].device - ).unsqueeze(0) - pos_depos = kvc_positions + template_tensor - pos_repos = prt_positions + template_tensor - - # precision highly affects the outputs - PRECISION = torch.float32 - DEF_PRECISION = self.kv_caches[0].dtype - - # do the rotation - # note: PPMissingLayer is for pipeline parallel support - if not hasattr(self, "rotate"): - if not isinstance(self.model.model.layers[0], PPMissingLayer): - self.rotate = self.model.model.layers[0].self_attn.rotary_emb - else: - for lay in self.model.model.layers: - if not isinstance(lay, PPMissingLayer): - self.rotate = lay.self_attn.rotary_emb - break - assert pos_depos.shape[0] == concerned_vectors[0].shape[0] - - if num_repos > 100: - for i, k_vectors in enumerate(concerned_vectors): - k_vectors_tmp, _ = self.rotate.forward_native( - pos_depos, k_vectors.to(PRECISION), invert_rotation_angle=True - ) - k_vectors_tmp, _ = self.rotate.forward_native( - pos_repos, k_vectors_tmp - ) - self.kv_caches[i][0, block_ids, ...] = k_vectors_tmp.to( - DEF_PRECISION - ) - else: - nlays = len(concerned_vectors) - kvecs = torch.cat(concerned_vectors, dim=0).to(PRECISION) - k_vectors_tmp, _ = self.rotate.forward_native( - pos_depos.repeat(nlays, 1), kvecs, invert_rotation_angle=True - ) - k_vectors_tmp, _ = self.rotate.forward_native( - pos_repos.repeat(nlays, 1), k_vectors_tmp - ) - k_vectors_tmp = k_vectors_tmp.reshape( - nlays, *concerned_vectors[0].shape - ) - for i in range(len(self.kv_caches)): - self.kv_caches[i][0, block_ids, ...] = k_vectors_tmp[i].to( - DEF_PRECISION - ) - def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: if ( self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE @@ -2648,10 +2537,6 @@ def execute_model( ) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens with record_function_or_nullcontext("Preprocess"): - # NOTE(tdoublep): should this be inside context below? - # handle repositioning requests - #self._perform_repositioning(scheduler_output) - with self.synchronize_input_prep(): # Update persistent batch states. self._update_states(scheduler_output) From 2f6a7b24bd0398d5fd387feeb0d90a29841ad4eb Mon Sep 17 00:00:00 2001 From: Nathan Ordonez Date: Wed, 12 Nov 2025 10:22:56 -0500 Subject: [PATCH 09/11] bugfix: free block queue was being corrupted Signed-off-by: Nathan Ordonez --- vllm/v1/core/block_pool.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 55710ad5cc..9bfd8b69ee 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -358,8 +358,11 @@ def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: blocks_list = list(ordered_blocks) for block in blocks_list: block.ref_cnt -= 1 + # Remove duplicates while preserving order + dedup_bl = list({block.block_id: block for + block in blocks_list}.values()) self.free_block_queue.append_n( - [block for block in blocks_list if block.ref_cnt == 0 and not block.is_null] + [block for block in dedup_bl if block.ref_cnt == 0 and not block.is_null] ) def reset_prefix_cache(self) -> bool: From aa1416c244c62cafad7b798bc18e2c6b55eb167d Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 13 Nov 2025 09:04:15 -0500 Subject: [PATCH 10/11] Fused RoPE only when spans enabled Signed-off-by: Thomas Parnell --- .../attention/ops/triton_unified_attention.py | 101 ++++++++++-------- .../layers/rotary_embedding/base.py | 7 +- vllm/v1/attention/backends/triton_attn.py | 8 +- 3 files changed, 67 insertions(+), 49 deletions(-) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index edd832e220..80154ea448 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -87,6 +87,7 @@ def kernel_unified_attention_2d( USE_SOFTCAP: tl.constexpr, # bool USE_SINKS: tl.constexpr, # bool SLIDING_WINDOW: tl.constexpr, # int + FUSE_ROPE: tl.constexpr, # bool stride_k_cache_0: tl.int64, # int stride_k_cache_1: tl.int64, # int stride_k_cache_2: tl.int64, # int @@ -302,30 +303,34 @@ def kernel_unified_attention_2d( else: K_b = K_b_load - cos_cache_offset = ( - seq_offset[None, :] * stride_cs_cache_0 - + offs_d_new[:, None] * stride_cs_cache_1 - ) + if FUSE_ROPE: + cos_cache_offset = ( + seq_offset[None, :] * stride_cs_cache_0 + + offs_d_new[:, None] * stride_cs_cache_1 + ) - sin_cache_offset = ( - seq_offset[None, :] * stride_cs_cache_0 - + (HEAD_SIZE_PADDED // 2 + offs_d_new[:, None]) * stride_cs_cache_1 - ) + sin_cache_offset = ( + seq_offset[None, :] * stride_cs_cache_0 + + (HEAD_SIZE_PADDED // 2 + offs_d_new[:, None]) * stride_cs_cache_1 + ) - cos = tl.load( - cos_sin_cache_ptr + cos_cache_offset, - mask=dim_mask_a[:, None] & tile_mask[None, :], - other=0.0, - ).to(K_a.dtype) + cos = tl.load( + cos_sin_cache_ptr + cos_cache_offset, + mask=dim_mask_a[:, None] & tile_mask[None, :], + other=0.0, + ).to(K_a.dtype) - sin = tl.load( - cos_sin_cache_ptr + sin_cache_offset, - mask=dim_mask_b[:, None] & tile_mask[None, :], - other=0.0, - ).to(K_b.dtype) + sin = tl.load( + cos_sin_cache_ptr + sin_cache_offset, + mask=dim_mask_b[:, None] & tile_mask[None, :], + other=0.0, + ).to(K_b.dtype) - K_rot_a = K_a * cos - K_b * sin - K_rot_b = K_b * cos + K_a * sin + K_rot_a = K_a * cos - K_b * sin + K_rot_b = K_b * cos + K_a * sin + else: + K_rot_a = K_a + K_rot_b = K_b # V : (TILE_SIZE, HEAD_SIZE) V_load = tl.load( @@ -458,6 +463,7 @@ def kernel_unified_attention_3d( USE_SOFTCAP: tl.constexpr, # bool USE_SINKS: tl.constexpr, # bool SLIDING_WINDOW: tl.constexpr, # int + FUSE_ROPE: tl.constexpr, # bool stride_k_cache_0: tl.int64, # int stride_k_cache_1: tl.int64, # int stride_k_cache_2: tl.int64, # int @@ -664,30 +670,34 @@ def kernel_unified_attention_3d( else: K_b = K_b_load - cos_cache_offset = ( - seq_offset[None, :] * stride_cs_cache_0 - + offs_d_new[:, None] * stride_cs_cache_1 - ) + if FUSE_ROPE: + cos_cache_offset = ( + seq_offset[None, :] * stride_cs_cache_0 + + offs_d_new[:, None] * stride_cs_cache_1 + ) - sin_cache_offset = ( - seq_offset[None, :] * stride_cs_cache_0 - + (HEAD_SIZE_PADDED // 2 + offs_d_new[:, None]) * stride_cs_cache_1 - ) + sin_cache_offset = ( + seq_offset[None, :] * stride_cs_cache_0 + + (HEAD_SIZE_PADDED // 2 + offs_d_new[:, None]) * stride_cs_cache_1 + ) - cos = tl.load( - cos_sin_cache_ptr + cos_cache_offset, - mask=dim_mask_a[:, None] & tile_mask[None, :], - other=0.0, - ).to(K_a.dtype) + cos = tl.load( + cos_sin_cache_ptr + cos_cache_offset, + mask=dim_mask_a[:, None] & tile_mask[None, :], + other=0.0, + ).to(K_a.dtype) - sin = tl.load( - cos_sin_cache_ptr + sin_cache_offset, - mask=dim_mask_b[:, None] & tile_mask[None, :], - other=0.0, - ).to(K_b.dtype) + sin = tl.load( + cos_sin_cache_ptr + sin_cache_offset, + mask=dim_mask_b[:, None] & tile_mask[None, :], + other=0.0, + ).to(K_b.dtype) - K_rot_a = K_a * cos - K_b * sin - K_rot_b = K_b * cos + K_a * sin + K_rot_a = K_a * cos - K_b * sin + K_rot_b = K_b * cos + K_a * sin + else: + K_rot_a = K_a + K_rot_b = K_b # V : (TILE_SIZE, HEAD_SIZE_PADDED) V_load = tl.load( @@ -910,6 +920,7 @@ def unified_attention( use_alibi_slopes = alibi_slopes is not None use_qq_bias = qq_bias is not None + fuse_rope = cos_sin_cache is not None block_size = v.shape[1] num_seqs = len(seqused_k) @@ -980,6 +991,7 @@ def unified_attention( USE_SOFTCAP=(softcap > 0), USE_SINKS=(sinks is not None), SLIDING_WINDOW=(1 + window_size[0]), + FUSE_ROPE=fuse_rope, stride_k_cache_0=k.stride(0), stride_k_cache_1=k.stride(1), stride_k_cache_2=k.stride(2), @@ -988,8 +1000,8 @@ def unified_attention( stride_v_cache_1=v.stride(1), stride_v_cache_2=v.stride(2), stride_v_cache_3=v.stride(3), - stride_cs_cache_0=cos_sin_cache.stride(0), - stride_cs_cache_1=cos_sin_cache.stride(1), + stride_cs_cache_0=cos_sin_cache.stride(0) if fuse_rope else 0, + stride_cs_cache_1=cos_sin_cache.stride(1) if fuse_rope else 0, query_start_len_ptr=cu_seqlens_q, BLOCK_Q=BLOCK_Q, num_seqs=num_seqs, @@ -1056,6 +1068,7 @@ def unified_attention( USE_SOFTCAP=(softcap > 0), USE_SINKS=(sinks is not None), SLIDING_WINDOW=(1 + window_size[0]), + FUSE_ROPE=fuse_rope, stride_k_cache_0=k.stride(0), stride_k_cache_1=k.stride(1), stride_k_cache_2=k.stride(2), @@ -1064,8 +1077,8 @@ def unified_attention( stride_v_cache_1=v.stride(1), stride_v_cache_2=v.stride(2), stride_v_cache_3=v.stride(3), - stride_cs_cache_0=cos_sin_cache.stride(0), - stride_cs_cache_1=cos_sin_cache.stride(1), + stride_cs_cache_0=cos_sin_cache.stride(0) if fuse_rope else 0, + stride_cs_cache_1=cos_sin_cache.stride(1) if fuse_rope else 0, query_start_len_ptr=cu_seqlens_q, BLOCK_Q=BLOCK_Q, num_seqs=num_seqs, diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 74784808cc..50fd6e71ec 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -4,6 +4,7 @@ import torch +import vllm.envs as envs from vllm.model_executor.custom_op import CustomOp from .common import apply_rotary_emb_torch @@ -120,16 +121,14 @@ def forward_native( query_rot = apply_rotary_emb_torch(query_rot, cos, sin, self.is_neox_style) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) - """ # key may be None in some cases, e.g. cross-layer KV sharing - if key is not None: + if key is not None and not envs.VLLM_V1_SPANS_ENABLED: key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) key_rot = key[..., : self.rotary_dim] key_pass = key[..., self.rotary_dim :] key_rot = apply_rotary_emb_torch(key_rot, cos, sin, self.is_neox_style) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) - """ return query, key @@ -159,7 +158,7 @@ def forward_cuda( ops.rotary_embedding( positions, query, - None, + None if envs.VLLM_V1_SPANS_ENABLED else key, self.head_size, self.cos_sin_cache, self.is_neox_style, diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 0ef72bc48d..609e763077 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -7,6 +7,7 @@ import torch +import vllm.envs as envs from vllm.attention.backends.abstract import ( AttentionBackend, AttentionImpl, @@ -129,6 +130,11 @@ def build( suffix_kv_lens = None prefix_scheduler_metadata = None + if envs.VLLM_V1_SPANS_ENABLED: + cos_sin_cache = common_attn_metadata.cos_sin_cache + else: + cos_sin_cache = None + attn_metadata = TritonAttentionMetadata( num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, @@ -143,7 +149,7 @@ def build( prefix_kv_lens=prefix_kv_lens, suffix_kv_lens=suffix_kv_lens, prefix_scheduler_metadata=prefix_scheduler_metadata, - cos_sin_cache=common_attn_metadata.cos_sin_cache, + cos_sin_cache=cos_sin_cache, ) return attn_metadata From e9d302c1972a96a113ea573053853260ceef3766 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 20 Nov 2025 14:06:27 -0500 Subject: [PATCH 11/11] Minor things Signed-off-by: Thomas Parnell --- examples/offline_inference/spans/spans.py | 4 ++-- vllm/model_executor/layers/rotary_embedding/base.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/offline_inference/spans/spans.py b/examples/offline_inference/spans/spans.py index 908bf05efd..ebe42f7ba7 100644 --- a/examples/offline_inference/spans/spans.py +++ b/examples/offline_inference/spans/spans.py @@ -62,7 +62,7 @@ def main(): # enables block attention # -> when this line is not commented, we expect a speedup # in the execution of the last two .generate calls - os.environ["VLLM_V1_SPANS_ENABLED"] = "False" + os.environ["VLLM_V1_SPANS_ENABLED"] = "True" # the token that tells vLLM "this is the beginning of a span" os.environ["VLLM_V1_SPANS_TOKEN_PLUS"] = str(SPAN_TOK_PLUS) @@ -72,7 +72,7 @@ def main(): os.environ["VLLM_V1_SPANS_TOKEN_CROSS"] = str(SPAN_TOK_CROSS) # will print every step of the span process if set to true - os.environ["VLLM_V1_SPANS_DEBUG"] = "False" + os.environ["VLLM_V1_SPANS_DEBUG"] = "True" # will disable the adjustment of positional encodings when a KV cache # block is loaded to a different position than it was stored diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 50fd6e71ec..d85ded63fa 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -129,7 +129,6 @@ def forward_native( key_pass = key[..., self.rotary_dim :] key_rot = apply_rotary_emb_torch(key_rot, cos, sin, self.is_neox_style) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) - return query, key def forward_cuda(