From ef59a8fc0881a05dea251c48cd8ee7da3cc91f83 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 9 Sep 2025 13:57:40 -0400 Subject: [PATCH 1/9] Initial supports for spans/block-attention. Co-authored-by: Nathan Ordonez Signed-off-by: Thomas Parnell Signed-off-by: Nathan Ordonez --- README.md | 6 + examples/offline_inference/spans/spans.py | 192 ++++++++++++++++++ vllm/envs.py | 31 +++ .../layers/rotary_embedding/base.py | 3 + .../layers/rotary_embedding/mrope.py | 1 + vllm/v1/core/block_pool.py | 53 ++++- vllm/v1/core/kv_cache_manager.py | 78 ++++++- vllm/v1/core/kv_cache_utils.py | 30 +++ vllm/v1/core/sched/output.py | 4 + vllm/v1/core/sched/scheduler.py | 12 +- vllm/v1/worker/gpu_model_runner.py | 112 ++++++++++ 11 files changed, 512 insertions(+), 10 deletions(-) create mode 100644 examples/offline_inference/spans/spans.py diff --git a/README.md b/README.md index 4e03df758c..0764bc9c6b 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,12 @@ Easy, fast, and cheap LLM serving for everyone --- +## What is the purpose of this fork? + +This is a fork of vLLM which we are using to develop support for *span semantics*. + +--- + *Latest News* 🔥 - [2025/08] We hosted [vLLM Shenzhen Meetup](https://mp.weixin.qq.com/s/k8ZBO1u2_2odgiKWH_GVTQ) focusing on the ecosystem around vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Ua2SVKVSu-wp5vou_6ElraDt2bnKhiEA). 
diff --git a/examples/offline_inference/spans/spans.py b/examples/offline_inference/spans/spans.py new file mode 100644 index 0000000000..ebe42f7ba7 --- /dev/null +++ b/examples/offline_inference/spans/spans.py @@ -0,0 +1,192 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os +import time + +# to ensure deterministic behaviour +os.environ["TOKENIZERS_PARALLELISM"] = "False" + +# standard imports +from vllm import LLM, SamplingParams +from vllm.inputs import TokensPrompt + + +# helper functions +def pad(toklist, padtok): + return toklist[:-1] + [padtok] * ((16 - len(toklist)) % 16) + toklist[-1:] + + +def avg(list_of_numbers): + return sum(list_of_numbers) / max(len(list_of_numbers), 1) + + +def wrap(prompt): + if isinstance(prompt[0], list): + return [TokensPrompt(prompt_token_ids=p) for p in prompt] + return TokensPrompt(prompt_token_ids=prompt) + + +def initialize_vllm( + model, temp=0.6, logprobs=None, max_toks=32768, max_generated_toks=1 +): + # boot up vLLM + samp_params_preload = SamplingParams(temperature=temp, max_tokens=1) + samp_params_generate = SamplingParams( + temperature=temp, max_tokens=max_generated_toks, logprobs=logprobs + ) + llm = LLM( + model=model, + gpu_memory_utilization=0.9, + enforce_eager=True, # <- so it boots faster + block_size=16, + ) + tok = llm.get_tokenizer() + tok_fun = lambda x: tok.convert_tokens_to_ids(tok.tokenize(x)) + return samp_params_preload, samp_params_generate, tok_fun, llm + + +def main(): + model_names = [ + "ldsjmdy/Tulu3-Block-FT", # <- finetuned to handle block-attention + "ldsjmdy/Tulu3-RAG", # <- baseline + ] + model_name = model_names[0] + + # tokens that need to be set to perform block-attention + PAD_TOK = 27 # <- "<" + SPAN_TOK_PLUS = 10 # <- "+" + SPAN_TOK_CROSS = 31 # <- "@" + + # vLLM-specific env vars + + # enables block attention + # -> when this line is not commented, we expect a speedup + # in the execution of the last two 
.generate calls + os.environ["VLLM_V1_SPANS_ENABLED"] = "True" + + # the token that tells vLLM "this is the beginning of a span" + os.environ["VLLM_V1_SPANS_TOKEN_PLUS"] = str(SPAN_TOK_PLUS) + + # token that tells vLLM: + # "from here on, recompute KV vectors if any previous tokens differ" + os.environ["VLLM_V1_SPANS_TOKEN_CROSS"] = str(SPAN_TOK_CROSS) + + # will print every step of the span process if set to true + os.environ["VLLM_V1_SPANS_DEBUG"] = "True" + + # will disable the adjustment of positional encodings when a KV cache + # block is loaded to a different position than it was stored + # -> when this line is not commented, + # spans overlap in their positional encodings + os.environ["VLLM_V1_SPANS_DISABLE_REPOSITION"] = "True" + + # general env vars + + # now we instantiate the model + samp_params_preload, samp_params_generate, tok, llm = initialize_vllm( + model_name, max_generated_toks=128, max_toks=10_000, temp=0.0 + ) + + # components of the prompt template + prefix = pad( + tok( + "<|system|>\nYou are an intelligent AI assistant. " + "Please answer questions based on the user's instructions. " + "Below are some reference documents that may help you in " + "answering the user's question." + ), + PAD_TOK, + ) + midfx = [SPAN_TOK_CROSS] + tok( + "<|user|>\nPlease write a high-quality answer for the " + "given question using only the provided search documents " + "(some of which might be irrelevant).\nQuestion: " + ) + postfx = tok("""\n<|assistant|>\n""") + + print("---->", postfx) + + # task-specific documents + doc_a = pad( + [SPAN_TOK_PLUS] + + tok( + "[0] The Template-Assisted " + "Selective Epitaxy (TASE) method, developed at " + "IBM Research Europe – Zurich, permits to " + "create a homogeneous integration route for " + "various semiconductor materials which is " + "compatible with the CMOS process." 
+ ), + PAD_TOK, + ) + + doc_b = pad( + [SPAN_TOK_PLUS] + + tok( + "[1] The dominant sequence transduction " + "models are based on complex recurrent or " + "convolutional neural networks in an encoder-decoder " + "configuration. " + ), + PAD_TOK, + ) + + # # alt-docs (purely to check performance on longer documents) + """ + a_toks = tok("Sequence Transduction Models") + b_toks = tok("Template-Assisted Selective Epitaxy") + doc_a = pad( + [SPAN_TOK_PLUS] + + [a_toks[idx % len(a_toks)] for idx in range(10_000)], + PAD_TOK, + ) + doc_b = pad( + [SPAN_TOK_PLUS] + + [b_toks[idx % len(a_toks)] for idx in range(10_000)], + PAD_TOK, + ) + """ + + # user query + query = ( + midfx + + tok( + "Tell me which one concerns deep learning. " + "Indicate your answer with a number in brackets." + ) + + postfx + ) + + # preload documents + ts_pre = time.time() + llm.generate( + [wrap(doc_a), wrap(doc_b), wrap(prefix)], sampling_params=samp_params_preload + ) + te_pre = time.time() - ts_pre + + ts_gen = time.time() + + # this now will load prefix, doc_a, doc_b, + # from the KV cache regardless of the order + model_response_1 = llm.generate( + wrap(prefix + doc_a + doc_b + query), + sampling_params=samp_params_generate, + use_tqdm=False, + ) + + # this should also run faster: + model_response_2 = llm.generate( + wrap(prefix + doc_b + doc_a + query), + sampling_params=samp_params_generate, + use_tqdm=False, + ) + + te_gen = time.time() - ts_gen + + print(f"doc preload time / TTFT : {te_pre:.4f} / {te_gen:.4f} (s)") + print("model output 1 was:", model_response_1[0].outputs[0].text) + print("model output 2 was:", model_response_2[0].outputs[0].text) + + +if __name__ == "__main__": + main() diff --git a/vllm/envs.py b/vllm/envs.py index 8d199da45b..5d88c1d39f 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -172,6 +172,12 @@ VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True + # spans vars + 
VLLM_V1_SPANS_ENABLED: bool = False + VLLM_V1_SPANS_DEBUG: bool = False + VLLM_V1_SPANS_TOKEN_PLUS: int = -1 + VLLM_V1_SPANS_TOKEN_CROSS: int = -1 + VLLM_V1_SPANS_DISABLE_REPOSITION: bool = False def get_default_cache_root(): @@ -1221,6 +1227,31 @@ def get_vllm_port() -> Optional[int]: # raw bytes. Defaults to True for backward compatibility. "VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES": lambda: bool(int(os.getenv("VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES", "1"))), + + # whether to enable block-attention (span detection, fan-in, repositioning) + "VLLM_V1_SPANS_ENABLED": + lambda: os.environ.get("VLLM_V1_SPANS_ENABLED", "False") == "True", + + # whether to print details pertaining to the block-attention + # implementation + "VLLM_V1_SPANS_DEBUG": + lambda: os.environ.get("VLLM_V1_SPANS_DEBUG", "False") == "True", + + # for block-attention, the token that will be used in order to + # indicate the beginning of a span (needed for it to work) + "VLLM_V1_SPANS_TOKEN_PLUS": + lambda: int(os.environ.get("VLLM_V1_SPANS_TOKEN_PLUS", "-1")), + + # for block-attention, a token that signals the beginning of a + # span which needs to depend on all previous tokens + "VLLM_V1_SPANS_TOKEN_CROSS": + lambda: int(os.environ.get("VLLM_V1_SPANS_TOKEN_CROSS", "-1")), + + # for block-attention, detected spans will be loaded but not repositioned + "VLLM_V1_SPANS_DISABLE_REPOSITION": + lambda: os.environ.get("VLLM_V1_SPANS_DISABLE_REPOSITION", "False" + ) == "True", + } # --8<-- [end:env-vars-definition] diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index be25e90abf..69eef7ba9e 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -63,6 +63,7 @@ def forward_native( query: torch.Tensor, key: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, + invert_rotation_angle: bool = False # <- to unrope kv's ) -> tuple[torch.Tensor, 
Optional[torch.Tensor]]: """A PyTorch-native implementation of forward().""" if offsets is not None: @@ -71,6 +72,8 @@ def forward_native( num_tokens = positions.shape[0] cos_sin = self.cos_sin_cache.index_select(0, positions) cos, sin = cos_sin.chunk(2, dim=-1) + if invert_rotation_angle: + sin = -sin query_shape = query.shape query = query.view(num_tokens, -1, self.head_size) diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index 0ab4bc5375..4601f147ea 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -230,6 +230,7 @@ def forward_native( query: torch.Tensor, key: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, + invert_rotation_angle: bool = False, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """PyTorch-native implementation equivalent to forward(). diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index d1e1c1c8d0..9069a364db 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -4,6 +4,7 @@ from collections.abc import Iterable from typing import Optional +import vllm.envs as envs from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared, BlockRemoved, BlockStored, KVCacheEvent) @@ -145,6 +146,8 @@ def cache_full_blocks( if new_hashes is not None: new_hashes.append(maybe_convert_block_hash(block_hash)) + self._set_block_positions(new_full_blocks, blocks, request) + if self.enable_kv_cache_events: if num_cached_blocks == 0: parent_block_hash: Optional[ExternalBlockHash] = None @@ -167,6 +170,47 @@ def cache_full_blocks( medium=MEDIUM_GPU, )) + def _set_block_positions(self, new_full_blocks: list[KVCacheBlock], + blocks: list[KVCacheBlock], request: Request): + """Sets the positions of new full blocks in the KV cache. + + This function assigns positions to newly filled blocks based + on their order within the provided block list. 
The position + corresponds to the location embedded in K vectors (if using RoPE) + in the KV cache and is critical for maintaining correct alignment, + especially when prompt positions differ between requests. + + Args: + new_full_blocks: List of KVCacheBlock objects that have been newly + filled and require position assignment. + blocks: List of all blocks associated with the current request, + used to determine the order in which positions are assigned. + request: The Request object containing token information for + debugging purposes. + + Note: + When VLLM_V1_SPANS_DEBUG is enabled, this function includes + debug logging that prints each block's tokens, to help + debug span-related workflows. + """ + pos = 0 + for blk in blocks: + if blk in new_full_blocks: + blk.position = pos + if envs.VLLM_V1_SPANS_DEBUG: + # this prints the tokens assigned to a new block + # in the KV cache + blk_tks = request.all_token_ids[pos:pos + 16] + assert blk.block_hash is not None + bhash = str(abs(blk.block_hash.block_hash.hash_value) + )[:4] if blk.block_hash.block_hash else None + print('[SPANS -> block_pool] assigning to pos', pos, + 'with hash', bhash, 'block: ', blk_tks) + pos += 16 + if envs.VLLM_V1_SPANS_DEBUG: + print('[SPANS -> block_pool] assigned block count now ->', + len([b for b in self.blocks if b._block_hash])) + def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: """Get new blocks from the free block pool. 
@@ -261,8 +305,15 @@ def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: blocks_list = list(ordered_blocks) for block in blocks_list: block.ref_cnt -= 1 + # remove duplicates (blocks can now appear twice) + block_ids = set() + blocks_list_filtered = [] + for block in blocks_list: + if block.block_id not in block_ids: + blocks_list_filtered.append(block) + block_ids.add(block.block_id) self.free_block_queue.append_n([ - block for block in blocks_list + block for block in blocks_list_filtered if block.ref_cnt == 0 and not block.is_null ]) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 3a0fbb5e5c..a6628cfc55 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import Literal, Optional, overload +import vllm.envs as envs from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator @@ -15,6 +16,13 @@ logger = init_logger(__name__) +@dataclass +class BlockRepositionRequest: + block_id: int + kvc_pos: int + prompt_pos: int + + @dataclass class KVCacheBlocks: """ @@ -23,6 +31,7 @@ class KVCacheBlocks: structure from the Scheduler. """ blocks: tuple[list[KVCacheBlock], ...] + blocks_to_reposition: list[BlockRepositionRequest] """ blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens. 
We don't use block of tokens as the outer dimension because it assumes all @@ -35,7 +44,8 @@ def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks": """Adds two KVCacheBlocks instances.""" return KVCacheBlocks( tuple(blk1 + blk2 - for blk1, blk2 in zip(self.blocks, other.blocks))) + for blk1, blk2 in zip(self.blocks, other.blocks)), + self.blocks_to_reposition + other.blocks_to_reposition) @overload def get_block_ids( @@ -78,7 +88,7 @@ def get_unhashed_block_ids(self) -> list[int]: def new_empty(self) -> "KVCacheBlocks": """Creates a new KVCacheBlocks instance with no blocks.""" - return KVCacheBlocks(tuple([] for _ in range(len(self.blocks)))) + return KVCacheBlocks(tuple([] for _ in range(len(self.blocks))), []) class KVCacheManager: @@ -180,6 +190,57 @@ def get_computed_blocks(self, computed_blocks, num_new_computed_tokens = ( self.coordinator.find_longest_cache_hit(request.block_hashes, max_cache_hit_length)) + if envs.VLLM_V1_SPANS_DEBUG: + print( + "[SPANS -> kv_cache_manager] here's the blocks hashed in " \ + "this request:", + [str(abs(b.hash_value))[:4] for b in request.block_hashes]) + kvcache_contents = [ + str(abs(b.block_hash.block_hash.hash_value))[:4] + if b.block_hash else None for b in self.block_pool.blocks + if b._block_hash + ] + if len(kvcache_contents) > 32: + kvcache_contents = kvcache_contents[:32] + [ + '... 
(too long to print it all)' + ] + print( + "[SPANS -> kv_cache_manager] here's the contents of the " \ + "kv cache:", + kvcache_contents) + print( + "[SPANS -> kv_cache_manager] here's the number of blocks " \ + "that hit the cache:", + [ + str(abs(b.block_hash.block_hash.hash_value))[:4] + if b.block_hash else None for b in computed_blocks[0] + ]) + + blocks_to_reposition = [] + if envs.VLLM_V1_SPANS_ENABLED: + # Spans does yet not support hybrid models + assert len(computed_blocks) == 1 + for i, b in enumerate(computed_blocks[0]): + prompt_pos = i * 16 + kvc_pos = b.position + if envs.VLLM_V1_SPANS_DEBUG: + print( + f"[SPANS -> kv_cache_manager] checking block " \ + f"{b.block_id} with prompot pos {prompt_pos} " \ + f"and kv pos {kvc_pos}" + ) + assert isinstance(kvc_pos, int) + if kvc_pos != prompt_pos: + if envs.VLLM_V1_SPANS_DEBUG: + print( + f"[SPANS -> kv_cache_manager] from pos: {kvc_pos} "\ + f"to prompt pos: {prompt_pos} repositioning needed" + ) + + blocks_to_reposition.append( + BlockRepositionRequest(b.block_id, kvc_pos, + prompt_pos)) + b.position = int(prompt_pos) if self.log_stats: assert self.prefix_cache_stats is not None @@ -187,7 +248,8 @@ def get_computed_blocks(self, self.prefix_cache_stats.queries += request.num_tokens self.prefix_cache_stats.hits += num_new_computed_tokens - return KVCacheBlocks(computed_blocks), num_new_computed_tokens + return KVCacheBlocks(computed_blocks, blocks_to_reposition),\ + num_new_computed_tokens def allocate_slots( self, @@ -290,7 +352,7 @@ def allocate_slots( # P/D: delay caching blocks if we have to recv from # remote. Update state for locally cached blocks. 
if not self.enable_caching or delay_cache_blocks: - return KVCacheBlocks(new_blocks) + return KVCacheBlocks(new_blocks, []) # NOTE(woosuk): We want to commit (cache) up to num_computed_tokens + # num_new_tokens, but must exclude "non-committable" tokens (e.g., @@ -300,7 +362,7 @@ def allocate_slots( request.num_tokens) self.coordinator.cache_blocks(request, num_tokens_to_cache) - return KVCacheBlocks(new_blocks) + return KVCacheBlocks(new_blocks, []) def free(self, request: Request) -> None: """Free the blocks allocated for the request. @@ -381,7 +443,7 @@ def take_events(self) -> list[KVCacheEvent]: def get_blocks(self, request_id: str) -> KVCacheBlocks: """Get the blocks of a request.""" - return KVCacheBlocks(self.coordinator.get_blocks(request_id)) + return KVCacheBlocks(self.coordinator.get_blocks(request_id), []) def get_block_ids(self, request_id: str) -> tuple[list[int], ...]: """Get the block ids of a request.""" @@ -394,5 +456,5 @@ def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: def create_empty_block_list(self) -> KVCacheBlocks: """Creates a new KVCacheBlocks instance with no blocks.""" - return KVCacheBlocks(tuple([] - for _ in range(self.num_kv_cache_groups))) + return KVCacheBlocks( + tuple([] for _ in range(self.num_kv_cache_groups)), []) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 2c0eac3ddd..077b09f4ab 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -162,6 +162,8 @@ class KVCacheBlock: block_id: int # Reference count. ref_cnt: int = 0 + # Position (corresponds to positional encodings position) + position: Optional[int] = None # The hash key (block hash + group id) of the block, only available # when the block is full and cached. 
_block_hash: Optional[BlockHashWithGroupId] = None @@ -559,12 +561,38 @@ def hash_block_tokens( if not parent_block_hash: parent_block_hash = NONE_HASH + if envs.VLLM_V1_SPANS_ENABLED: + if envs.VLLM_V1_SPANS_TOKEN_PLUS == -1: + raise Exception( + '[SPANS -> kv_cache_utils]: span separator token undefined!') + # if a block starts with the span separator token, then its hash + # should be independent of previous tokens + firstok = curr_block_token_ids[0] + if firstok == envs.VLLM_V1_SPANS_TOKEN_PLUS: + if envs.VLLM_V1_SPANS_DEBUG: + print('[SPANS -> kv_cache_utils] detected span separator ' + f'token {envs.VLLM_V1_SPANS_TOKEN_PLUS} -> enable fan-in' + ) + parent_block_hash = NONE_HASH + curr_block_token_ids_tuple = tuple(curr_block_token_ids) return BlockHash( hash_function( (parent_block_hash, curr_block_token_ids_tuple, extra_keys))) +def recompute_token_handler( + block_first_token: int, tokens_up_to_block: list[int], + extra_keys: Union[tuple[Any, ...], + None]) -> Union[tuple[Any, ...], None]: + if envs.VLLM_V1_SPANS_ENABLED and \ + block_first_token == envs.VLLM_V1_SPANS_TOKEN_CROSS: + tok_tuple = tuple(tokens_up_to_block) + extra_keys = (*extra_keys, tok_tuple) if extra_keys \ + else tok_tuple + return extra_keys + + def get_request_block_hasher( block_size: int, caching_hash_fn: Callable[[Any], bytes], @@ -600,6 +628,8 @@ def request_block_hasher(request: Request) -> list[BlockHash]: # Compute the hash of the current block block_tokens = request.all_token_ids[start_token_idx:end_token_idx] + extra_keys = recompute_token_handler(block_tokens[0], + request.all_token_ids[:start_token_idx], extra_keys) block_hash = hash_block_tokens(caching_hash_fn, prev_block_hash_value, block_tokens, extra_keys) diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index b5cd6c5c8a..eacfaaa1ff 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -16,6 +16,7 @@ from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange from
vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams + from vllm.v1.core.kv_cache_manager import BlockRepositionRequest from vllm.v1.request import Request @@ -153,5 +154,8 @@ class SchedulerOutput: # the bitmask for the whole batch grammar_bitmask: Optional[npt.NDArray[np.int32]] + # for KV cache repositioning (as part of Block-Attention implementation) + blocks_to_reposition: list[BlockRepositionRequest] + # KV Cache Connector metadata. kv_connector_metadata: Optional[KVConnectorMetadata] = None diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 2d40e96632..dbcb7ed39f 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -9,6 +9,7 @@ from collections.abc import Iterable from typing import Any, Optional, Union +import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch from vllm.distributed.kv_transfer.kv_connector.factory import ( @@ -19,7 +20,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) -from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager +from vllm.v1.core.kv_cache_manager import (BlockRepositionRequest, + KVCacheBlocks, KVCacheManager) from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) @@ -330,6 +332,7 @@ def schedule(self) -> SchedulerOutput: skipped_waiting_requests = create_request_queue(self.policy) # Next, schedule the WAITING requests. 
+ blocks_to_reposition: list[BlockRepositionRequest] = [] if not preempted_reqs: while self.waiting and token_budget > 0: if len(self.running) == self.max_num_running_reqs: @@ -381,6 +384,12 @@ def schedule(self) -> SchedulerOutput: self.kv_cache_manager.get_computed_blocks( request) + # handle repositioning requests + if envs.VLLM_V1_SPANS_ENABLED and \ + len(new_computed_blocks.blocks_to_reposition) > 0: + blocks_to_reposition.extend( + new_computed_blocks.blocks_to_reposition) + # Get externally-cached tokens if using a KVConnector. if self.connector is not None: num_external_computed_tokens, load_kv_async = ( @@ -589,6 +598,7 @@ def schedule(self) -> SchedulerOutput: get_freed_mm_hashes(), structured_output_request_ids=structured_output_request_ids, grammar_bitmask=grammar_bitmask, + blocks_to_reposition=blocks_to_reposition, ) # NOTE(Kuntai): this function is designed for multiple purposes: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 549c5dd2bb..420df37f14 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -44,6 +44,7 @@ supports_transcription) from vllm.model_executor.models.interfaces_base import ( VllmModelForPooling, is_pooling_model, is_text_generation_model) +from vllm.model_executor.models.utils import PPMissingLayer from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem, PlaceholderRange) @@ -1587,6 +1588,113 @@ def _pool( kv_connector_output=kv_connector_output, ) + def _perform_repositioning(self, + scheduler_output: "SchedulerOutput") -> None: + """ + Repositions KV cache blocks based on the scheduler's instructions. + + This method handles the repositioning of attention block + vectors in the KV cache when their positions in the KV cache + and in the prompt differ. It applies rotary embedding + transformations to adjust the positions. 
+ + Args: + scheduler_output: The output from the scheduler containing blocks + to reposition. + """ + blocks_to_reposition = scheduler_output.blocks_to_reposition + if envs.VLLM_V1_SPANS_DEBUG: + ts_repo = time.time() + repo_count = len(blocks_to_reposition) + if len(blocks_to_reposition) < 600: + self._repositionings_handler(blocks_to_reposition) + else: + # walk the list in consecutive, non-overlapping batches so + # that every block is repositioned exactly once + bs = 400 + for i in range(0, len(blocks_to_reposition), bs): + repo_batch = blocks_to_reposition[i:i+bs] + self._repositionings_handler(repo_batch) + if envs.VLLM_V1_SPANS_DEBUG and repo_count > 0: + torch.cuda.synchronize() + t_repo = time.time() - ts_repo + print(f'[SPANS -> gpu_model_runner] repositioning' \ + f' speed: {repo_count/t_repo:.2f} (blocks/s)') + + @torch.inference_mode() + def _repositionings_handler(self, blocks_to_reposition): + num_repos = len(blocks_to_reposition) + if envs.VLLM_V1_SPANS_DEBUG and num_repos > 0: + print( + f'[SPANS -> gpu_model_runner] ' \ + f'reposition block count: {num_repos}' + ) + if not envs.VLLM_V1_SPANS_DISABLE_REPOSITION and num_repos > 0: + kvc_positions = torch.tensor( + [d.kvc_pos for d in blocks_to_reposition], + dtype=torch.long, + device=self.kv_caches[0].device).unsqueeze(-1) + prt_positions = torch.tensor( + [d.prompt_pos for d in blocks_to_reposition], + dtype=torch.long, + device=self.kv_caches[0].device).unsqueeze(-1) + block_ids = torch.tensor( + [d.block_id for d in blocks_to_reposition], + dtype=torch.long, + device=self.kv_caches[0].device) + + # (self.kv_caches shape): + # [nlay, kv, maxblocks, blocksize, headcount, headsize] + concerned_vectors = [ + x[0, block_ids, :, :, :] for x in self.kv_caches + ] # -> [nlay, blockids, blocksize, headcount, headsize] + bids, bsize, hcount, hsize = concerned_vectors[0].shape + + template_tensor = torch.arange( + bsize, dtype=torch.long, + device=self.kv_caches[0].device).unsqueeze(0) + pos_depos = kvc_positions + template_tensor + pos_repos =
prt_positions + template_tensor + + # precision highly affects the outputs + PRECISION = torch.float32 + DEF_PRECISION = self.kv_caches[0].dtype + + # do the rotation + # note: PPMissingLayer is for pipeline parallel support + if not hasattr(self, 'rotate'): + if not isinstance(self.model.model.layers[0], PPMissingLayer): + self.rotate = self.model.model.layers[ + 0].self_attn.rotary_emb + else: + for lay in self.model.model.layers: + if not isinstance(lay, PPMissingLayer): + self.rotate = lay.self_attn.rotary_emb + break + assert pos_depos.shape[0] == concerned_vectors[0].shape[0] + + if num_repos > 100: + for i, k_vectors in enumerate(concerned_vectors): + k_vectors_tmp, _ = self.rotate.forward_native( + pos_depos, + k_vectors.to(PRECISION), + invert_rotation_angle=True) + k_vectors_tmp, _ = self.rotate.forward_native( + pos_repos, k_vectors_tmp) + self.kv_caches[i][0, block_ids, ...] = \ + k_vectors_tmp.to(DEF_PRECISION) + else: + k_vectors_tmp, _ = self.rotate.forward_native( + pos_depos, + torch.cat([k.unsqueeze(0) for k in concerned_vectors], + dim=0).to(PRECISION), + invert_rotation_angle=True) + k_vectors_tmp, _ = self.rotate.forward_native( + pos_repos, k_vectors_tmp) + for i in range(len(self.kv_caches)): + self.kv_caches[i][0, block_ids, ...] = \ + k_vectors_tmp[i].to(DEF_PRECISION) + def _preprocess( self, scheduler_output: "SchedulerOutput", @@ -1850,6 +1958,10 @@ def execute_model( intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: with record_function_or_nullcontext("Preprocess"): + + # handle repositioning requests + self._perform_repositioning(scheduler_output) + self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: if not has_kv_transfer_group(): From 6491b428172ffac00c6cc9957f5e2647671e92d0 Mon Sep 17 00:00:00 2001 From: "Nathan A. 
Ordonez Cardenas" Date: Wed, 17 Sep 2025 06:39:48 -0400 Subject: [PATCH 2/9] initial impl (runs, but accuracy dropped) Signed-off-by: Nathan Ordonez --- vllm/v1/core/block_pool.py | 21 ++- vllm/v1/core/kv_cache_manager.py | 82 ++++++---- vllm/v1/core/sched/scheduler.py | 9 +- vllm/v1/core/single_type_kv_cache_manager.py | 5 +- vllm/v1/worker/gpu_model_runner.py | 148 +++++++++++++++++-- 5 files changed, 212 insertions(+), 53 deletions(-) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 9069a364db..908e1852f0 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -73,9 +73,18 @@ def __init__( self.enable_kv_cache_events = enable_kv_cache_events self.kv_event_queue: list[KVCacheEvent] = [] + def _closest_cache_hit( + self, cached_blocks: dict[int, KVCacheBlock], + position: int, + ) -> dict[int, KVCacheBlock]: + return sorted( + list(cached_blocks.values()), + key=lambda x: abs(x.position - position))[0] + def get_cached_block( self, block_hash: BlockHash, - kv_cache_group_ids: list[int]) -> Optional[list[KVCacheBlock]]: + kv_cache_group_ids: list[int], + position: Optional[int] = None) -> Optional[list[KVCacheBlock]]: """Get the cached block by the block hash for each group in `kv_cache_group_ids`, or None if cache miss for any group. If there are duplicated blocks, we return the first block in the cache. 
@@ -95,7 +104,11 @@ def get_cached_block( block_hash_with_group_id) if not cached_blocks_one_group: return None - first_block = next(iter(cached_blocks_one_group.values())) + if position is not None and len(cached_blocks_one_group) > 1: + first_block = self._closest_cache_hit(cached_blocks_one_group, + position) + else: + first_block = next(iter(cached_blocks_one_group.values())) cached_blocks.append(first_block) return cached_blocks @@ -202,8 +215,8 @@ def _set_block_positions(self, new_full_blocks: list[KVCacheBlock], # in the KV cache blk_tks = request.all_token_ids[pos:pos + 16] assert blk.block_hash is not None - bhash = str(abs(blk.block_hash.block_hash.hash_value) - )[:4] if blk.block_hash.block_hash else None + bhash = str(blk.block_hash + )[:4] if blk.block_hash else None print('[SPANS -> block_pool] assigning to pos', pos, 'with hash', bhash, 'block: ', blk_tks) pos += 16 diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index a6628cfc55..0e63cb0f9a 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -18,9 +18,11 @@ @dataclass class BlockRepositionRequest: - block_id: int - kvc_pos: int prompt_pos: int + cached_pos: int + cached_blockid: int + prompt_blockpos: int + prompt_reqid: str @dataclass @@ -190,13 +192,47 @@ def get_computed_blocks(self, computed_blocks, num_new_computed_tokens = ( self.coordinator.find_longest_cache_hit(request.block_hashes, max_cache_hit_length)) + + # now we check how many of those computed blocks have incorrect or are + # after an incorrect position match + # our own positions are clear, now we need to compare that to cached + # positions + repo_reqs = [] + non_match_idx = -1 + non_match_found = False + for i, block in enumerate(computed_blocks[0]): + if block.is_null: # null blocks don't have meaningful position + continue + prompt_pos = self.block_size * i + cached_pos = block.position + # find first block id where pos didn't match + if prompt_pos != 
cached_pos and not non_match_found: + non_match_found = True + non_match_idx = i + # record from then on and after, repo requests + if non_match_found: + repo_reqs.append( + BlockRepositionRequest( + prompt_pos, + cached_pos, + block.block_id, + i, + request.request_id)) + # if any repo is needed, we need to exclude that from the + # computed blocks and num_new_computed_tokens, so that + # new blocks get allocated that we can copy kv values to + if non_match_found: + computed_blocks = (computed_blocks[0][:non_match_idx],) + num_new_computed_tokens = len(computed_blocks[0]) * self.block_size + + if envs.VLLM_V1_SPANS_DEBUG: print( "[SPANS -> kv_cache_manager] here's the blocks hashed in " \ "this request:", - [str(abs(b.hash_value))[:4] for b in request.block_hashes]) + [str(b)[-4:] for b in request.block_hashes]) kvcache_contents = [ - str(abs(b.block_hash.block_hash.hash_value))[:4] + str(b.block_hash)[-4:] if b.block_hash else None for b in self.block_pool.blocks if b._block_hash ] @@ -212,35 +248,17 @@ def get_computed_blocks(self, "[SPANS -> kv_cache_manager] here's the number of blocks " \ "that hit the cache:", [ - str(abs(b.block_hash.block_hash.hash_value))[:4] + str(b.block_hash)[-4:] if b.block_hash else None for b in computed_blocks[0] ]) - - blocks_to_reposition = [] - if envs.VLLM_V1_SPANS_ENABLED: - # Spans does yet not support hybrid models - assert len(computed_blocks) == 1 - for i, b in enumerate(computed_blocks[0]): - prompt_pos = i * 16 - kvc_pos = b.position - if envs.VLLM_V1_SPANS_DEBUG: - print( - f"[SPANS -> kv_cache_manager] checking block " \ - f"{b.block_id} with prompot pos {prompt_pos} " \ - f"and kv pos {kvc_pos}" - ) - assert isinstance(kvc_pos, int) - if kvc_pos != prompt_pos: - if envs.VLLM_V1_SPANS_DEBUG: - print( - f"[SPANS -> kv_cache_manager] from pos: {kvc_pos} "\ - f"to prompt pos: {prompt_pos} repositioning needed" - ) - - blocks_to_reposition.append( - BlockRepositionRequest(b.block_id, kvc_pos, - prompt_pos)) - b.position = 
int(prompt_pos) + # for block duplication + num_repo = len([r for r in repo_reqs + if r.prompt_pos != r.cached_pos]) + num_copy = len(repo_reqs) - num_repo + print( + "[SPANS -> kv_cache_manager] here's the number of blocks", + f"total: {len(repo_reqs)} to reposition: {num_repo},", + f"to copy: {num_copy}") if self.log_stats: assert self.prefix_cache_stats is not None @@ -248,7 +266,7 @@ def get_computed_blocks(self, self.prefix_cache_stats.queries += request.num_tokens self.prefix_cache_stats.hits += num_new_computed_tokens - return KVCacheBlocks(computed_blocks, blocks_to_reposition),\ + return KVCacheBlocks(computed_blocks, repo_reqs),\ num_new_computed_tokens def allocate_slots( diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index dbcb7ed39f..c525323481 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -389,6 +389,11 @@ def schedule(self) -> SchedulerOutput: len(new_computed_blocks.blocks_to_reposition) > 0: blocks_to_reposition.extend( new_computed_blocks.blocks_to_reposition) + + # TODO find something smarter to do than this + token_budget += \ + len(new_computed_blocks.blocks_to_reposition) \ + * self.block_size # Get externally-cached tokens if using a KVConnector. if self.connector is not None: @@ -545,8 +550,10 @@ def schedule(self) -> SchedulerOutput: self.waiting.prepend_requests(skipped_waiting_requests) # Check if the scheduling constraints are satisfied. 
+ # TODO make this smarter for spans total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) - assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + \ + len(blocks_to_reposition) * self.block_size assert token_budget >= 0 assert len(self.running) <= self.max_num_running_reqs # Since some requests in the RUNNING queue may not be scheduled in diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 8159349e46..77a0ecd9d4 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -269,12 +269,13 @@ def find_longest_cache_hit( if dcp_world_size > 1: block_size *= dcp_world_size max_num_blocks = max_length // block_size - for block_hash in itertools.islice(block_hashes, max_num_blocks): + for pidx, block_hash in enumerate( + itertools.islice(block_hashes, max_num_blocks)): # block_hashes is a chain of block hashes. If a block hash is not # in the cached_block_hash_to_id, the following block hashes are # not computed yet for sure. 
if cached_block := block_pool.get_cached_block( - block_hash, kv_cache_group_ids): + block_hash, kv_cache_group_ids, position=pidx): for computed, cached in zip(computed_blocks, cached_block): computed.append(cached) else: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 420df37f14..0926310cb6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -85,6 +85,8 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import ( KVConnectorModelRunnerMixin, KVConnectorOutput) from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin +from vllm.v1.core.kv_cache_manager import BlockRepositionRequest +from vllm.v1.core.sched.output import NewRequestData from .utils import (AttentionGroup, MultiModalBudget, add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache, @@ -378,6 +380,10 @@ def __init__( device="cpu", pin_memory=self.pin_memory) + # self.reposition_request_cache: dict[str, BlockRepositionRequest] = {} + self.reposition_request_cache: \ + defaultdict[str, list[BlockRepositionRequest]] = defaultdict(list) + def _make_buffer(self, *size: Union[int, torch.SymInt], dtype: torch.dtype, @@ -1587,9 +1593,100 @@ def _pool( pooler_output=pooler_output, kv_connector_output=kv_connector_output, ) + def _copy_blocks(self, + blocks_to_copy: list[BlockRepositionRequest], + newreqs_by_id: dict[str, NewRequestData]) -> None: + from_ids = [] + to_ids = [] + for req in blocks_to_copy: + from_ids.append(req.cached_blockid) + # find out the block to copy to + # 1. find the relevant new request + # 2. then find the block at position prompt_blockpos + to_ids.append(newreqs_by_id[req.prompt_reqid]\ + .block_ids[0][req.prompt_blockpos]) + # perform copies + args = dict(dtype=torch.long, device=self.kv_caches[0].device) + for i in range(len(self.kv_caches)): + self.kv_caches[i][:, torch.tensor(to_ids, **args), ...] = \ + self.kv_caches[i][:, torch.tensor(from_ids, **args), ...] 
+ + def _custom_cache_manipulations(self, + scheduler_output: "SchedulerOutput") \ + -> None: + + # only allow as many reposition requests + # as a request has tokens scheduled + for req in scheduler_output.blocks_to_reposition: + self.reposition_request_cache[req.prompt_reqid].append(req) + # 1. find out how many repo requests + # we are scheduled to make + scheduled_reposition_reqs = [] + for rid, ntoks in scheduler_output.num_scheduled_tokens.items(): + cached_rreqs = self.reposition_request_cache[rid] + n_cached_rreqs = len(cached_rreqs) + if n_cached_rreqs > 0: + # take as many as can be scheduled + nsched_rreqs = min( + ntoks // self.cache_config.block_size, + n_cached_rreqs) + scheduled_reposition_reqs.extend( + cached_rreqs[:nsched_rreqs]) + if nsched_rreqs < n_cached_rreqs: + self.reposition_request_cache[rid] =\ + cached_rreqs[nsched_rreqs:] + else: + self.reposition_request_cache[rid] = [] + # and then we adjust the rest of this function so it only uses + # the scheduled repo requests + + # 0. sort requests + blocks_to_copy = [] + blocks_to_repo = [] + [(blocks_to_copy if req.cached_pos == req.prompt_pos + else blocks_to_repo).append(req) + for req in scheduled_reposition_reqs] + newreqs_by_id = {r.req_id: r for r + in + scheduler_output.scheduled_new_reqs + \ + [self.requests[rid] for rid in \ + scheduler_output.scheduled_cached_reqs.req_ids]} + # 1. perform copies + self._copy_blocks(blocks_to_copy, newreqs_by_id) + # 2. do repositioning + self._perform_repositioning(blocks_to_repo, newreqs_by_id) + # 3. 
adjust relevant counters + # 3.1 num_scheduled_tokens + req_ntokens_to_skip = defaultdict(lambda: 0) + for rreq in scheduled_reposition_reqs: + req_ntokens_to_skip[rreq.prompt_reqid] += \ + self.cache_config.block_size + # 16 + for reqid, ntoks in req_ntokens_to_skip.items(): + scheduler_output.num_scheduled_tokens[reqid] -= ntoks + # 3.2 total_num_scheduled_tokens + scheduler_output.total_num_scheduled_tokens -= \ + len(scheduled_reposition_reqs) \ + * self.cache_config.block_size + # * 16 + # 3.3 scheduled_new_reqs (num_computed_tokens) + for i in range(len(scheduler_output.scheduled_new_reqs)): # TODO also do this for cached requests + sr = scheduler_output.scheduled_new_reqs[i] + sr.num_computed_tokens += req_ntokens_to_skip[sr.req_id] + # for rid in scheduler_output.scheduled_cached_reqs.req_ids: + # req = self.requests[rid] + scc = scheduler_output.scheduled_cached_reqs + for i in range(len(scc.req_ids)): + rid = scc.req_ids[i] + scc.num_computed_tokens[i] += req_ntokens_to_skip[rid] + # NOTE maybe PP is broken here because + # we don't manipulate new_token_ids + # in the cached request data + def _perform_repositioning(self, - scheduler_output: "SchedulerOutput") -> None: + blocks_to_reposition: list[BlockRepositionRequest], + newreqs_by_id: dict[str, NewRequestData]) -> None: """ Repositions KV cache blocks based on the scheduler's instructions. @@ -1602,19 +1699,33 @@ def _perform_repositioning(self, scheduler_output: The output from the scheduler containing blocks to reposition. 
""" - blocks_to_reposition = scheduler_output.blocks_to_reposition if envs.VLLM_V1_SPANS_DEBUG: ts_repo = time.time() repo_count = len(blocks_to_reposition) - if len(blocks_to_reposition) < 600: - self._repositionings_handler(blocks_to_reposition) + # figure out destination block IDs + dest_ids = [] + valid_blocks_to_reposition = [] + for req in blocks_to_reposition: + try: + dest_ids.append(newreqs_by_id[req.prompt_reqid]\ + .block_ids[0][req.prompt_blockpos]) + valid_blocks_to_reposition.append(req) + except IndexError as e: + # breakpoint() + print('INDEX_ERROR could not run reposition request:', req, e) + + if len(valid_blocks_to_reposition) < 600: + self._repositionings_handler(valid_blocks_to_reposition, + dest_ids) else: bs = 400 - for i in range(len(blocks_to_reposition) // bs): + for i in range(len(valid_blocks_to_reposition) // bs): j = bs if i + bs * 2 < len( - blocks_to_reposition) else i + bs * 2 - repo_batch = blocks_to_reposition[i:j] - self._repositionings_handler(repo_batch) + valid_blocks_to_reposition) else i + bs * 2 + repo_batch = valid_blocks_to_reposition[i:j] + dest_batch = dest_ids[i:j] + self._repositionings_handler(repo_batch, + dest_batch) if envs.VLLM_V1_SPANS_DEBUG and repo_count > 0: torch.cuda.synchronize() t_repo = time.time() - ts_repo @@ -1622,7 +1733,8 @@ def _perform_repositioning(self, f' speed: {repo_count/t_repo:.2f} (blocks/s)') @torch.inference_mode() - def _repositionings_handler(self, blocks_to_reposition): + def _repositionings_handler(self, blocks_to_reposition, + destination_block_ids): num_repos = len(blocks_to_reposition) if envs.VLLM_V1_SPANS_DEBUG and num_repos > 0: print( @@ -1631,7 +1743,7 @@ def _repositionings_handler(self, blocks_to_reposition): ) if not envs.VLLM_V1_SPANS_DISABLE_REPOSITION and num_repos > 0: kvc_positions = torch.tensor( - [d.kvc_pos for d in blocks_to_reposition], + [d.cached_pos for d in blocks_to_reposition], dtype=torch.long, device=self.kv_caches[0].device).unsqueeze(-1) prt_positions = 
torch.tensor( @@ -1639,7 +1751,11 @@ def _repositionings_handler(self, blocks_to_reposition): dtype=torch.long, device=self.kv_caches[0].device).unsqueeze(-1) block_ids = torch.tensor( - [d.block_id for d in blocks_to_reposition], + [d.cached_blockid for d in blocks_to_reposition], + dtype=torch.long, + device=self.kv_caches[0].device) + dest_block_ids = torch.tensor( + destination_block_ids, dtype=torch.long, device=self.kv_caches[0].device) @@ -1681,8 +1797,10 @@ def _repositionings_handler(self, blocks_to_reposition): invert_rotation_angle=True) k_vectors_tmp, _ = self.rotate.forward_native( pos_repos, k_vectors_tmp) - self.kv_caches[i][0, block_ids, ...] = \ + self.kv_caches[i][0, dest_block_ids, ...] = \ k_vectors_tmp.to(DEF_PRECISION) + self.kv_caches[i][1, dest_block_ids, ...] = \ + self.kv_caches[i][1, block_ids] else: k_vectors_tmp, _ = self.rotate.forward_native( pos_depos, @@ -1692,8 +1810,10 @@ def _repositionings_handler(self, blocks_to_reposition): k_vectors_tmp, _ = self.rotate.forward_native( pos_repos, k_vectors_tmp) for i in range(len(self.kv_caches)): - self.kv_caches[i][0, block_ids, ...] = \ + self.kv_caches[i][0, dest_block_ids, ...] = \ k_vectors_tmp[i].to(DEF_PRECISION) + self.kv_caches[i][1, dest_block_ids, ...] 
= \ + self.kv_caches[i][1, block_ids] def _preprocess( self, @@ -1960,7 +2080,7 @@ def execute_model( with record_function_or_nullcontext("Preprocess"): # handle repositioning requests - self._perform_repositioning(scheduler_output) + self._custom_cache_manipulations(scheduler_output) self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: From 7a4e46b8c30f6e9dec9a463868502c93500e86dc Mon Sep 17 00:00:00 2001 From: Nathan Ordonez Date: Thu, 18 Sep 2025 18:02:12 -0400 Subject: [PATCH 3/9] bug fix (block duplication seems to work) Signed-off-by: Nathan Ordonez --- vllm/v1/core/kv_cache_utils.py | 1 + vllm/v1/worker/gpu_model_runner.py | 10 ++++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 077b09f4ab..a9b7dabe5a 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -372,6 +372,7 @@ def append_n(self, blocks: list[KVCacheBlock]) -> None: """ if len(blocks) == 0: return + blocks = list({b.block_id: b for b in blocks}.values()) self.num_free_blocks += len(blocks) last_block = self.fake_free_list_tail.prev_free_block diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0926310cb6..e2e5c52638 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1646,6 +1646,9 @@ def _custom_cache_manipulations(self, [(blocks_to_copy if req.cached_pos == req.prompt_pos else blocks_to_repo).append(req) for req in scheduled_reposition_reqs] + if envs.VLLM_V1_SPANS_DISABLE_REPOSITION: + blocks_to_copy.extend(blocks_to_repo) + blocks_to_repo = [] newreqs_by_id = {r.req_id: r for r in scheduler_output.scheduled_new_reqs + \ @@ -1719,9 +1722,8 @@ def _perform_repositioning(self, dest_ids) else: bs = 400 - for i in range(len(valid_blocks_to_reposition) // bs): - j = bs if i + bs * 2 < len( - valid_blocks_to_reposition) else i + bs * 2 + for i in range(0, 
len(valid_blocks_to_reposition), bs): + j = i+bs repo_batch = valid_blocks_to_reposition[i:j] dest_batch = dest_ids[i:j] self._repositionings_handler(repo_batch, @@ -1741,7 +1743,7 @@ def _repositionings_handler(self, blocks_to_reposition, f'[SPANS -> gpu_model_runner] ' \ f'reposition block count: {num_repos}' ) - if not envs.VLLM_V1_SPANS_DISABLE_REPOSITION and num_repos > 0: + if num_repos > 0: kvc_positions = torch.tensor( [d.cached_pos for d in blocks_to_reposition], dtype=torch.long, From 92812efe1908e8c8767c29c0aa3910f8d3a16fe5 Mon Sep 17 00:00:00 2001 From: Nathan Ordonez Date: Tue, 23 Sep 2025 04:35:50 -0400 Subject: [PATCH 4/9] bugfix repositioning Signed-off-by: Nathan Ordonez --- vllm/v1/worker/gpu_model_runner.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e2e5c52638..ae9d03cbc4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1614,7 +1614,6 @@ def _copy_blocks(self, def _custom_cache_manipulations(self, scheduler_output: "SchedulerOutput") \ -> None: - # only allow as many reposition requests # as a request has tokens scheduled for req in scheduler_output.blocks_to_reposition: @@ -1791,6 +1790,7 @@ def _repositionings_handler(self, blocks_to_reposition, break assert pos_depos.shape[0] == concerned_vectors[0].shape[0] + if num_repos > 100: for i, k_vectors in enumerate(concerned_vectors): k_vectors_tmp, _ = self.rotate.forward_native( @@ -1798,19 +1798,23 @@ def _repositionings_handler(self, blocks_to_reposition, k_vectors.to(PRECISION), invert_rotation_angle=True) k_vectors_tmp, _ = self.rotate.forward_native( - pos_repos, k_vectors_tmp) + pos_repos, k_vectors_tmp) self.kv_caches[i][0, dest_block_ids, ...] = \ k_vectors_tmp.to(DEF_PRECISION) self.kv_caches[i][1, dest_block_ids, ...] 
= \ self.kv_caches[i][1, block_ids] else: + nlays = len(concerned_vectors) + kvecs = torch.cat(concerned_vectors, dim=0).to(PRECISION) k_vectors_tmp, _ = self.rotate.forward_native( - pos_depos, - torch.cat([k.unsqueeze(0) for k in concerned_vectors], - dim=0).to(PRECISION), + pos_depos.repeat(nlays, 1), + kvecs, invert_rotation_angle=True) k_vectors_tmp, _ = self.rotate.forward_native( - pos_repos, k_vectors_tmp) + pos_repos.repeat(nlays, 1), + k_vectors_tmp) + k_vectors_tmp = k_vectors_tmp.reshape(nlays, + *concerned_vectors[0].shape) for i in range(len(self.kv_caches)): self.kv_caches[i][0, dest_block_ids, ...] = \ k_vectors_tmp[i].to(DEF_PRECISION) From 3f95cd12e5b683affb3c4f65949d0f327790ffa9 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 19 Sep 2025 07:18:36 +0200 Subject: [PATCH 5/9] Initial support for span semantics (#88) An initial implementation of span semantics in vLLM. Please note that this has a known bug dealing with concurrent sequences that re-use the same span in different locations. We are working on a solution for this, but in the meantime accuracy may be negatively affected. n/a n/a ---
Essential Elements of an Effective PR Description Checklist - [x] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)". - [ ] The test plan, such as providing test command. - [ ] The test results, such as pasting the results comparison before and after, or e2e results - [ ] (Optional) The necessary documentation update, such as updating `supported_models.md` and `examples` for a new model. - [ ] (Optional) Release notes update. If your change is user facing, please update the release notes draft in the [Google Doc](https://docs.google.com/document/d/1YyVqrgX4gHTtrstbq8oWUImOyPCKSGnJ7xtTpmXzlRs/edit?tab=t.0).
--------- Signed-off-by: Thomas Parnell Signed-off-by: Nathan Ordonez Co-authored-by: Nathan Ordonez Co-authored-by: Nathan Ordonez --- vllm/v1/core/sched/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c525323481..591cc4e843 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -390,7 +390,7 @@ def schedule(self) -> SchedulerOutput: blocks_to_reposition.extend( new_computed_blocks.blocks_to_reposition) - # TODO find something smarter to do than this + # TODO (Nathan) find something smarter to do than this token_budget += \ len(new_computed_blocks.blocks_to_reposition) \ * self.block_size From 4f5c00f0bda77e247a81eab07aadf9b148eec016 Mon Sep 17 00:00:00 2001 From: Nathan Ordonez Date: Wed, 24 Sep 2025 06:04:51 -0400 Subject: [PATCH 6/9] bugfix, benefits now show up (and including benchmark that shows said benefits) Signed-off-by: Nathan Ordonez --- .../spans/spans_benchmark.py | 178 ++++++++++++++++++ vllm/v1/core/single_type_kv_cache_manager.py | 3 +- 2 files changed, 180 insertions(+), 1 deletion(-) create mode 100644 examples/offline_inference/spans/spans_benchmark.py diff --git a/examples/offline_inference/spans/spans_benchmark.py b/examples/offline_inference/spans/spans_benchmark.py new file mode 100644 index 0000000000..8b8eb7a561 --- /dev/null +++ b/examples/offline_inference/spans/spans_benchmark.py @@ -0,0 +1,178 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +import time +import random + +# necessary for spans to work +os.environ["VLLM_USE_V1"] = "1" +# to ensure deterministic behaviour +os.environ["TOKENIZERS_PARALLELISM"] = "False" + +# in case you need it +os.environ['VLLM_ATTENTION_BACKEND'] = "TRITON_ATTN_VLLM_V1" +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = '0' + +# standard imports +from vllm import LLM, SamplingParams 
+from vllm.inputs import TokensPrompt + + +# helper functions +def pad(toklist): + padtok = int(os.environ.get("VLLM_V1_SPANS_TOKEN_PAD", None)) + return toklist[:-1] + [padtok] * ((16 - len(toklist)) % 16) + toklist[-1:] + + +def avg(list_of_numbers): + return sum(list_of_numbers) / max(len(list_of_numbers), 1) + + +def wrap(prompt): + if isinstance(prompt[0], list): + return [TokensPrompt(prompt_token_ids=p) for p in prompt] + return TokensPrompt(prompt_token_ids=prompt) + +def initialize_vllm(model, + temp=0.6, + logprobs=None, + max_toks=131072, + max_generated_toks=1): + # boot up vLLM + samp_params_preload = SamplingParams(temperature=temp, max_tokens=1) + samp_params_generate = SamplingParams(temperature=temp, + max_tokens=max_generated_toks, + logprobs=logprobs) + llm = LLM( + model=model, + gpu_memory_utilization=0.9, + enforce_eager=True, # <- so it boots faster + block_size=16, + max_model_len=max_toks, + max_num_seqs=4, + ) + tok = llm.get_tokenizer() + tok_fun = lambda x: tok.convert_tokens_to_ids(tok.tokenize(x)) + return samp_params_preload, samp_params_generate, tok_fun, llm + + +def main(): + model_names = [ + "ldsjmdy/Tulu3-Block-FT", # <- finetuned to handle block-attention + "ldsjmdy/Tulu3-RAG", # <- baseline + ] + model_name = model_names[0] + + # tokens that need to be set to perform block-attention + PAD_TOK = 27 # <- "<" + SPAN_TOK = 10 # <- "+" + SPAN_RECOMP_TOK = 31 # <- "@" + + # vLLM-specific env vars + + # enables block attention + # -> when this line is not commented, we expect a speedup + # in the execution of the last two .generate calls + os.environ['VLLM_V1_SPANS_ENABLED'] = 'True' + + # the token that tells vLLM "this is the beginning of a span" + os.environ['VLLM_V1_SPANS_TOKEN_PLUS'] = str(SPAN_TOK) + + # token that tells vLLM: + # "from here on, recompute KV vectors if any previous tokens differ" + os.environ['VLLM_V1_SPANS_TOKEN_CROSS'] = str(SPAN_RECOMP_TOK) + + # will print every step of the span process if set to true + # 
os.environ['VLLM_V1_SPANS_DEBUG'] = 'True' + + # will disable the adjustment of positional encodings when a KV cache + # block is loaded to a different position than it was stored + # -> when this line is not commented, + # spans overlap in their positional encodings + os.environ['VLLM_V1_SPANS_DISABLE_REPOSITION'] = 'True' + + # general env vars + + # our helper function uses this token to pad spans + os.environ['VLLM_V1_SPANS_TOKEN_PAD'] = str(PAD_TOK) + + # now we instantiate the model + samp_params_preload, samp_params_generate, tok, llm = initialize_vllm( + model_name, max_generated_toks=1) + # model_name, max_generated_toks=1, max_toks=2048) + + # components of the prompt template + prefix = pad( + [SPAN_RECOMP_TOK] + tok("<|system|>\nYou are an intelligent AI assistant. " \ + "Please answer questions based on the user's instructions. " \ + "Below are some reference documents that may help you in " \ + "answering the user's question." + )) + midfx = [SPAN_RECOMP_TOK] + tok( + "<|user|>\nPlease write a high-quality answer for the " \ + "given question using only the provided search documents " \ + "(some of which might be irrelevant).\nQuestion: " + ) + postfx = tok('''\n<|assistant|>\n''') + + print("---->", postfx) + + times = [] + for ndocs in [1, 2, 4, 8]: + for dlen in [512, 1024, 2048, 4096, 8192]: + print(f" DOCLENGTH {dlen} NUMDOCS {ndocs}") + + doc_toks = tok( + "Sequence Transduction Models and Template-Assisted Selective Epitaxy") + docs = [pad([SPAN_TOK] + + random.choices(doc_toks, k=dlen)) + for _ in range(ndocs)] + + # user query + query = midfx + tok( + "Tell me which one concerns deep learning. " \ + "Indicate your answer with a number in brackets." 
+ ) + postfx + + for i in range(3): + print(f" ITERATION {i}") + + # preload documents + ts_pre = time.time() + llm.generate( + [wrap(d) for d in docs] + [wrap(prefix)], + sampling_params=samp_params_preload, use_tqdm=False) + te_pre = time.time() - ts_pre + + ts_gen = time.time() + + # this now will load prefix, doc_a, doc_b, + # from the KV cache regardless of the order + random.shuffle(docs) + llm.generate(wrap(prefix + \ + sum(docs, []) + \ + query), + sampling_params=samp_params_generate, use_tqdm=False) + + # this should also run faster: + random.shuffle(docs) + llm.generate(wrap(prefix + \ + sum(docs, []) + \ + query), + sampling_params=samp_params_generate, use_tqdm=False) + + te_gen = time.time() - ts_gen + + print(f"doc preload time / TTFT : {te_pre:.4f} / {te_gen:.4f} (s)") + times.append(dict( + preload_time=te_pre, + gen_time=te_gen, + it=i, + doc_len=dlen, + num_docs=ndocs, + )) + + +if __name__ == '__main__': + main() diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 77a0ecd9d4..6d14619e55 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -275,7 +275,8 @@ def find_longest_cache_hit( # in the cached_block_hash_to_id, the following block hashes are # not computed yet for sure. 
if cached_block := block_pool.get_cached_block( - block_hash, kv_cache_group_ids, position=pidx): + block_hash, kv_cache_group_ids, + position=pidx*kv_cache_spec.block_size): for computed, cached in zip(computed_blocks, cached_block): computed.append(cached) else: From 3998c6f725da17054c13120403ba1324af0582dd Mon Sep 17 00:00:00 2001 From: Nathan Ordonez Date: Wed, 24 Sep 2025 06:09:06 -0400 Subject: [PATCH 7/9] development folder Signed-off-by: Nathan Ordonez --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index b1df673e83..6e96289d87 100644 --- a/.gitignore +++ b/.gitignore @@ -218,3 +218,5 @@ csrc/moe/marlin_moe_wna16/kernel_* # Ignore ep_kernels_workspace folder ep_kernels_workspace/ + +dev/ \ No newline at end of file From cdae9f9c3ef633541242d56772bec3368846a4b6 Mon Sep 17 00:00:00 2001 From: Nathan Ordonez Date: Wed, 24 Sep 2025 10:25:04 -0400 Subject: [PATCH 8/9] bugfix Signed-off-by: Nathan Ordonez --- vllm/v1/core/kv_cache_manager.py | 59 ++++++++++++++++---------------- 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 0e63cb0f9a..5c902ae1b5 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -193,37 +193,38 @@ def get_computed_blocks(self, self.coordinator.find_longest_cache_hit(request.block_hashes, max_cache_hit_length)) - # now we check how many of those computed blocks have incorrect or are - # after an incorrect position match - # our own positions are clear, now we need to compare that to cached - # positions repo_reqs = [] - non_match_idx = -1 - non_match_found = False - for i, block in enumerate(computed_blocks[0]): - if block.is_null: # null blocks don't have meaningful position - continue - prompt_pos = self.block_size * i - cached_pos = block.position - # find first block id where pos didn't match - if prompt_pos != cached_pos and not non_match_found: - non_match_found = True - 
non_match_idx = i - # record from then on and after, repo requests + if envs.VLLM_V1_SPANS_ENABLED: + # now we check how many of those computed blocks have incorrect or are + # after an incorrect position match + # our own positions are clear, now we need to compare that to cached + # positions + non_match_idx = -1 + non_match_found = False + for i, block in enumerate(computed_blocks[0]): + if block.is_null: # null blocks don't have meaningful position + continue + prompt_pos = self.block_size * i + cached_pos = block.position + # find first block id where pos didn't match + if prompt_pos != cached_pos and not non_match_found: + non_match_found = True + non_match_idx = i + # record from then on and after, repo requests + if non_match_found: + repo_reqs.append( + BlockRepositionRequest( + prompt_pos, + cached_pos, + block.block_id, + i, + request.request_id)) + # if any repo is needed, we need to exclude that from the + # computed blocks and num_new_computed_tokens, so that + # new blocks get allocated that we can copy kv values to if non_match_found: - repo_reqs.append( - BlockRepositionRequest( - prompt_pos, - cached_pos, - block.block_id, - i, - request.request_id)) - # if any repo is needed, we need to exclude that from the - # computed blocks and num_new_computed_tokens, so that - # new blocks get allocated that we can copy kv values to - if non_match_found: - computed_blocks = (computed_blocks[0][:non_match_idx],) - num_new_computed_tokens = len(computed_blocks[0]) * self.block_size + computed_blocks = (computed_blocks[0][:non_match_idx],) + num_new_computed_tokens = len(computed_blocks[0]) * self.block_size if envs.VLLM_V1_SPANS_DEBUG: From 116b4572282dfb5c675b288d16c5bfd59e76b137 Mon Sep 17 00:00:00 2001 From: Nathan Ordonez Date: Mon, 29 Sep 2025 13:47:01 -0400 Subject: [PATCH 9/9] speed optimizations (from 6x to 1.3x overhead) --- vllm/v1/core/block_pool.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git 
a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 908e1852f0..fc753d55c6 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -77,9 +77,8 @@ def _closest_cache_hit( self, cached_blocks: dict[int, KVCacheBlock], position: int, ) -> dict[int, KVCacheBlock]: - return sorted( - list(cached_blocks.values()), - key=lambda x: abs(x.position - position))[0] + return min(list(cached_blocks.values()), + key=lambda x: abs(x.position - position)) def get_cached_block( self, block_hash: BlockHash, @@ -206,11 +205,13 @@ def _set_block_positions(self, new_full_blocks: list[KVCacheBlock], debug logging that prints each block's tokens, to help debug span-related workflows. """ + dbg = envs.VLLM_V1_SPANS_DEBUG pos = 0 + nfb_ids = {b.block_id for b in new_full_blocks} for blk in blocks: - if blk in new_full_blocks: + if blk.block_id in nfb_ids: blk.position = pos - if envs.VLLM_V1_SPANS_DEBUG: + if dbg: # this prints the tokens assigned to a new block # in the KV cache blk_tks = request.all_token_ids[pos:pos + 16]