Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/blossom-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ jobs:
"hijkzzz",
"hlu1",
"hnover-nv",
"Hudayday",
"HuiGao-NV",
"hvagadia",
"hypdeb",
Expand Down
45 changes: 2 additions & 43 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -393,10 +393,7 @@ class KVCacheBlock : public std::enable_shared_from_this<KVCacheBlock>
static BlockPtr createPlaceholder(IdType blockId, SizeType32 windowSize);

void detachDescendantsFromLookupTree();
//! \brief Detach all placeholder blocks in the previous-block chain from the lookup tree.
//! \details Walks upward via getPrevBlock() and calls detachFromLookupNode() on each
//! block that is a placeholder. Stops at the root (kCachedBlocksRootId).
void detachPreviousPlaceholdersFromLookupTree() const;

void freeBlockAndAllDescendants();

//! \brief Find block matching blockKey. If allowPartial is true, the returned block may match only a prefix of
Expand Down Expand Up @@ -630,7 +627,7 @@ class GenerationRequest
TLLM_CHECK_WITH_INFO(currentPrepopulatedPromptLen <= mCurrentPrepopulatedPromptLen,
"currentPrepopulatedPromptLen must be updated non-increasingly due to the "
"assumption that smaller window sizes have shorter or equal"
"currentPrepopulatedPromptLen in WindowSizeManager::loadOrAllocateBlocks.");
" currentPrepopulatedPromptLen during multi-window batch allocation.");
mCurrentPrepopulatedPromptLen = currentPrepopulatedPromptLen;
}

Expand Down Expand Up @@ -785,11 +782,6 @@ class WindowBlockManager

void startScheduling();

//! \brief Assign blocks for new sequence
//! \return The number of tokens that were matched/prepopulated from cache (prepopulatedPromptLen)
[[nodiscard]] SizeType32 addSequence(GenerationRequest& sequence, SizeType32 inputLength,
SizeType32 numContextBlocks, LlmRequest& llmRequest, bool isEnableBlockReuse);

//! \brief Per-request block allocation statistics from batch addSequence.
struct BatchSeqStats
{
Expand Down Expand Up @@ -1158,16 +1150,6 @@ class WindowBlockManager
//! \brief Add single block to all beams of sequence.
void addBlockToAllBeams(BlockPtr const& block, GenerationRequest& sequence);

//! \brief Try to load blocks from cache. Allocate new blocks if necessary.
//! \param blockKeys Key of each block.
//! \param sequence Sequence to which blocks are assigned.
//! \return Number of matched tokens from loaded blocks.
SizeType32 loadOrAllocateBlocks(std::vector<BlockKey> const& blockKeys, SizeType32 inputLength,
SizeType32 numContextBlocks, GenerationRequest& sequence,
std::vector<executor::RetentionPriorityAndDuration> const& perBlockRetentions,
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "",
bool isEnableBlockReuse = false);

//! \brief Phase 1: Walk radix tree and claim matching blocks.
//! \details Caller must hold mCachedBlocksRootMutex.
//! Uses \p tracker to coordinate partial-match ownership across requests in
Expand Down Expand Up @@ -1382,10 +1364,6 @@ class BlockManager

void allocatePools(bool useUvm);

//! \return The number of tokens that were matched/prepopulated from cache (prepopulatedPromptLen)
[[nodiscard]] SizeType32 addSequence(GenerationRequest& sequence, SizeType32 inputLength,
SizeType32 numContextBlocks, LlmRequest& llmRequest, SizeType32 windowSize, bool isEnableBlockReuse);

//! \brief Batch add sequences forwarding to WindowBlockManager::addSequenceBatch.
[[nodiscard]] std::vector<WindowBlockManager::BatchSeqStats> addSequenceBatch(
std::vector<GenerationRequest*> const& sequences, std::vector<SizeType32> const& inputLengths,
Expand Down Expand Up @@ -1895,16 +1873,6 @@ class BaseKVCacheManager
/// LlmRequest::getNumTokens.
[[nodiscard]] virtual SizeType32 getTokenCount(LlmRequest::RequestIdType requestId) const = 0;

/// @brief Add new request to the KV cache manager.
/// @param inputLength Input length for which KV cache needs to be allocated.
/// @param beamWidth Beam width for which KV cache needs to be allocated.
/// @param llmRequest Optional request to use for KV cache lookup.
/// @details If llmRequest is supplied and KV cache reuse is enabled, try to recover KV cache blocks for
/// inputLength - 1 tokens and populate prepopulatedPromptLen.
virtual void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth,
OptionalRef<LlmRequest> llmRequest = std::nullopt)
= 0;

//! \brief Batch add sequences with two-phase claim-then-onboard strategy.
//! \details For each attention window, when block reuse is enabled, Phase 1 claims all matching
//! blocks across all requests (protecting them from eviction via PartialClaimTracker),
Expand Down Expand Up @@ -2276,15 +2244,6 @@ class KVCacheManager : public BaseKVCacheManager
//! the placeholder block). It should be called before every forward step, after adding new tokens.
void copyLinearAttentionBlock(LlmRequest const& llmRequest);

/// @brief Add new request to the KV cache manager.
/// @param inputLength Input length for which KV cache needs to be allocated.
/// @param beamWidth Beam width for which KV cache needs to be allocated.
/// @param llmRequest Optional request to use for KV cache lookup.
/// @details If llmRequest is supplied and KV cache reuse is enabled, try to recover KV cache blocks for
/// inputLength - 1 tokens and populate prepopulatedPromptLen.
void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth,
OptionalRef<LlmRequest> llmRequest = std::nullopt) override;

void addSequenceBatch(
std::vector<std::tuple<LlmRequest::RequestIdType, SizeType32, SizeType32>> const& requestInfos,
std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests) override;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -63,12 +63,13 @@ class KVCacheTransferManager
//! \brief Synchronize internal streams with bufferManager stream.
//! \details The buffer manager uses the same stream as the prefill and decode kernels. This method ensures that the
//! internal kernels used for offloading and onboarding will wait for prefill and decode kernels before performing
//! any block copies. This method must be called before the first call to KVCacheManager::addSequence in every step.
//! any block copies. This method must be called before the first call to
//! KVCacheManager::addSequenceBatch in every step.
void syncWithBufferManager();

//! \brief Synchronize bufferManager stream with internal streams. This method ensures that prefill and decode
//! kernels for the next step will wait for offloading and onboarding work that has already been scheduled. This method
//! must be called after last call to KVCacheManager::addSequence in every step.
//! must be called after the last call to KVCacheManager::addSequenceBatch in every step.
void syncTransfers();

//! \brief Get transfer stats accumulated since last call, and reset the counters.
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/tensorrt_llm/batch_manager/llmRequest.h
Original file line number Diff line number Diff line change
Expand Up @@ -2002,7 +2002,7 @@ class GenericLlmRequest
// getRemainingBlocksToCompletion) so that the micro batch scheduler
// can account for cached tokens when computing the token budget.
// Marked mutable because it is a cache/estimate set during const
// capacity-scheduler queries. Reset to 0 after addSequence sets
// capacity-scheduler queries. Reset to 0 after addSequenceBatch sets
// the authoritative mPrepopulatedPromptLen and advances context position.
mutable SizeType32 mEstimatedReusableTokens{0};

Expand Down
9 changes: 6 additions & 3 deletions cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -19,6 +19,8 @@
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/nvtxUtils.h"

#include <functional>

void tensorrt_llm::batch_manager::AllocateKvCache::operator()(BaseKVCacheManager& kvCacheManager,
RequestVector& contextRequests, RequestVector const& generationRequests, runtime::ModelConfig const& modelConfig,
OptionalRef<BaseKVCacheManager> crossKvCacheManager) const
Expand All @@ -38,7 +40,7 @@ void tensorrt_llm::batch_manager::AllocateKvCache::operator()(BaseKVCacheManager
auto draftLength = llmReq->getNumDraftTokens();

// Allocate/Reuse KV cache
kvCacheManager.addSequence(requestId, promptLen, reqBeamWidth, llmReq);
kvCacheManager.addSequenceBatch({{{requestId, promptLen, reqBeamWidth}}}, {std::ref(*llmReq)});

// EagleNet will increment kv cache up to maxPathLen to account for accepted tokens.
// Then up to maxDecodingDraftTokens will be used to generate next draft tokens.
Expand All @@ -59,7 +61,8 @@ void tensorrt_llm::batch_manager::AllocateKvCache::operator()(BaseKVCacheManager

if (crossKvCacheManager)
{
crossKvCacheManager->addSequence(requestId, llmReq->getEncoderOutputLen(), reqBeamWidth, llmReq);
crossKvCacheManager->addSequenceBatch(
{{{requestId, llmReq->getEncoderOutputLen(), reqBeamWidth}}}, {std::ref(*llmReq)});
}
}
}
Expand Down
Loading
Loading