diff --git a/.github/workflows/blossom-ci.yml b/.github/workflows/blossom-ci.yml index 12b2b5405d24..f6f29edce5a1 100644 --- a/.github/workflows/blossom-ci.yml +++ b/.github/workflows/blossom-ci.yml @@ -133,6 +133,7 @@ jobs: "hijkzzz", "hlu1", "hnover-nv", + "Hudayday", "HuiGao-NV", "hvagadia", "hypdeb", diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index fb688e75cfda..22a6a4d1b456 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -393,10 +393,7 @@ class KVCacheBlock : public std::enable_shared_from_this static BlockPtr createPlaceholder(IdType blockId, SizeType32 windowSize); void detachDescendantsFromLookupTree(); - //! \brief Detach all placeholder blocks in the previous-block chain from the lookup tree. - //! \details Walks upward via getPrevBlock() and calls detachFromLookupNode() on each - //! block that is a placeholder. Stops at the root (kCachedBlocksRootId). - void detachPreviousPlaceholdersFromLookupTree() const; + void freeBlockAndAllDescendants(); //! \brief Find block matching blockKey. If allowPartial is true, the returned block may match only a prefix of @@ -630,7 +627,7 @@ class GenerationRequest TLLM_CHECK_WITH_INFO(currentPrepopulatedPromptLen <= mCurrentPrepopulatedPromptLen, "currentPrepopulatedPromptLen must be updated non-increasingly due to the " "assumption that smaller window sizes have shorter or equal" - "currentPrepopulatedPromptLen in WindowSizeManager::loadOrAllocateBlocks."); + " currentPrepopulatedPromptLen during multi-window batch allocation."); mCurrentPrepopulatedPromptLen = currentPrepopulatedPromptLen; } @@ -785,11 +782,6 @@ class WindowBlockManager void startScheduling(); - //! \brief Assign blocks for new sequence - //! \return The number of tokens that were matched/prepopulated from cache (prepopulatedPromptLen) - [[nodiscard]] SizeType32 addSequence(GenerationRequest& sequence, SizeType32 inputLength, - SizeType32 numContextBlocks, LlmRequest& llmRequest, bool isEnableBlockReuse); - //! \brief Per-request block allocation statistics from batch addSequence. struct BatchSeqStats { @@ -1158,16 +1150,6 @@ class WindowBlockManager //! \brief Add single block to all beams of sequence. void addBlockToAllBeams(BlockPtr const& block, GenerationRequest& sequence); - //! \brief Try to load blocks from cache. Allocate new blocks if necessary. - //! \param blockKeys Key of each block. - //! \param sequence Sequence to which blocks are assigned. - //! \return Number of matched tokens from loaded blocks. - SizeType32 loadOrAllocateBlocks(std::vector const& blockKeys, SizeType32 inputLength, - SizeType32 numContextBlocks, GenerationRequest& sequence, - std::vector const& perBlockRetentions, - executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "", - bool isEnableBlockReuse = false); - //! \brief Phase 1: Walk radix tree and claim matching blocks. //! \details Caller must hold mCachedBlocksRootMutex. //! Uses \p tracker to coordinate partial-match ownership across requests in @@ -1382,10 +1364,6 @@ class BlockManager void allocatePools(bool useUvm); - //! \return The number of tokens that were matched/prepopulated from cache (prepopulatedPromptLen) - [[nodiscard]] SizeType32 addSequence(GenerationRequest& sequence, SizeType32 inputLength, - SizeType32 numContextBlocks, LlmRequest& llmRequest, SizeType32 windowSize, bool isEnableBlockReuse); - //! 
\brief Batch add sequences forwarding to WindowBlockManager::addSequenceBatch. [[nodiscard]] std::vector addSequenceBatch( std::vector const& sequences, std::vector const& inputLengths, @@ -1895,16 +1873,6 @@ class BaseKVCacheManager /// LlmRequest::getNumTokens. [[nodiscard]] virtual SizeType32 getTokenCount(LlmRequest::RequestIdType requestId) const = 0; - /// @brief Add new request to the KV cache manager. - /// @param inputLength Input length for which KV cache need to be allocated. - /// @param beamWidth Beam width for which KV cache need to be allocated. - /// @param llmRequest Optional request to use for KV cache lookup. - /// @details If llmRequest is supplied and KV cache reuse is enabled, try to recover KV cache blocks for - /// inputLength - 1 tokens and populate prepopulatedPromptLen. - virtual void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, - OptionalRef llmRequest = std::nullopt) - = 0; - //! \brief Batch add sequences with two-phase claim-then-onboard strategy. //! \details For each attention window, when block reuse is enabled, Phase 1 claims all matching //! blocks across all requests (protecting them from eviction via PartialClaimTracker), @@ -2276,15 +2244,6 @@ class KVCacheManager : public BaseKVCacheManager //! the placeholder block). It should be called before every forward step, after adding new tokens. void copyLinearAttentionBlock(LlmRequest const& llmRequest); - /// @brief Add new request to the KV cache manager. - /// @param inputLength Input length for which KV cache need to be allocated. - /// @param beamWidth Beam width for which KV cache need to be allocated. - /// @param llmRequest Optional request to use for KV cache lookup. - /// @details If llmRequest is supplied and KV cache reuse is enabled, try to recover KV cache blocks for - /// inputLength - 1 tokens and populate prepopulatedPromptLen. - void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, - OptionalRef llmRequest = std::nullopt) override; - void addSequenceBatch( std::vector> const& requestInfos, std::vector> const& llmRequests) override; diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h index 8c40a46045d2..57f8f928a8be 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -63,12 +63,13 @@ class KVCacheTransferManager //! \brief Synchronize internal streams with bufferManager stream. //! \details The buffer manager uses the same stream as the prefill and decode kernels. This method ensures that the //! internal kernels used for offloading and onboarding will wait for prefill and decode kernels before performing - //! any block copies. This method must be called before the first call to KVCacheManager::addSequence in every step. + //! any block copies. This method must be called before the first call to + //! KVCacheManager::addSequenceBatch in every step. void syncWithBufferManager(); //! \brief Synchronize bufferManager stream with internal streams. This method ensures that prefill and decode //! 
kernels for next step will wait for offloading and onboarding work that has already been scheduled. This method - //! must be called after last call to KVCacheManager::addSequence in every step. + //! must be called after the last call to KVCacheManager::addSequenceBatch in every step. void syncTransfers(); //! \brief Get transfer stats accumulated since last call, and reset the counters. diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index c1ecffff026b..dae828018daa 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -2002,7 +2002,7 @@ class GenericLlmRequest // getRemainingBlocksToCompletion) so that the micro batch scheduler // can account for cached tokens when computing the token budget. // Marked mutable because it is a cache/estimate set during const - // capacity-scheduler queries. Reset to 0 after addSequence sets + // capacity-scheduler queries. Reset to 0 after addSequenceBatch sets // the authoritative mPrepopulatedPromptLen and advances context position. mutable SizeType32 mEstimatedReusableTokens{0}; diff --git a/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp b/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp index 211abe781861..dbb211de1a20 100644 --- a/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp +++ b/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,6 +19,8 @@ #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/common/nvtxUtils.h" +#include <functional> + void tensorrt_llm::batch_manager::AllocateKvCache::operator()(BaseKVCacheManager& kvCacheManager, RequestVector& contextRequests, RequestVector const& generationRequests, runtime::ModelConfig const& modelConfig, OptionalRef crossKvCacheManager) const @@ -38,7 +40,7 @@ void tensorrt_llm::batch_manager::AllocateKvCache::operator()(BaseKVCacheManager auto draftLength = llmReq->getNumDraftTokens(); // Allocate/Reuse KV cache - kvCacheManager.addSequence(requestId, promptLen, reqBeamWidth, llmReq); + kvCacheManager.addSequenceBatch({{{requestId, promptLen, reqBeamWidth}}}, {std::ref(*llmReq)}); // EagleNet will increment kv cache up to maxPathLen to account for accepted tokens. // Then up to maxDecodingDraftTokens will be used to generate next draft tokens. 
@@ -59,7 +61,8 @@ void tensorrt_llm::batch_manager::AllocateKvCache::operator()(BaseKVCacheManager if (crossKvCacheManager) { - crossKvCacheManager->addSequence(requestId, llmReq->getEncoderOutputLen(), reqBeamWidth, llmReq); + crossKvCacheManager->addSequenceBatch( + {{{requestId, llmReq->getEncoderOutputLen(), reqBeamWidth}}}, {std::ref(*llmReq)}); } } } diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index b00cf247ea69..90a567841f20 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -494,34 +494,9 @@ void KVCacheBlock::detachDescendantsFromLookupTree() } } -void KVCacheBlock::detachPreviousPlaceholdersFromLookupTree() const -{ - BlockPtr current = getPrevBlock(); - while (current != nullptr && current->getBlockId() != KVCacheBlock::kCachedBlocksRootId) - { - if (!current->isPlaceholder()) - { - return; - } - auto siblings = current->getNextBlocks(); - for (auto const& [key, block] : siblings) - { - if (!block->isPlaceholder() && block.get() != this) - { - return; - } - } - BlockPtr prev = current->getPrevBlock(); - current->detachFromLookupNode(); - current->setPrevBlockInSeq(nullptr); - current = prev; - } -} - void KVCacheBlock::freeBlockAndAllDescendants() { detachDescendantsFromLookupTree(); - detachPreviousPlaceholdersFromLookupTree(); detachFromLookupNode(); } @@ -896,23 +871,23 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind WindowBlockManager::~WindowBlockManager() { - float reusedUniqueBlocksPercentage = mReusedUniqueBlocks == 0 || mAllocTotalBlocks == 0 - ? 0 - : static_cast(mReusedUniqueBlocks) / static_cast(mAllocNewBlocks) * 100; - float cacheHitRate = mReusedBlocks == 0 - ? 0 - : static_cast(mReusedBlocks) / (static_cast(mReusedBlocks + mMissedBlocks)); - TLLM_LOG_DEBUG("%s - total allocated blocks: %lu ", mLogPrefix.c_str(), mAllocTotalBlocks); - TLLM_LOG_DEBUG("%s - allocated new blocks: %lu ", mLogPrefix.c_str(), mAllocNewBlocks); - TLLM_LOG_DEBUG("%s - missed blocks: %lu ", mLogPrefix.c_str(), mMissedBlocks); - TLLM_LOG_DEBUG("%s - reused blocks: %lu ", mLogPrefix.c_str(), mReusedBlocks); - TLLM_LOG_DEBUG("%s - reused unique blocks: %lu ", mLogPrefix.c_str(), mReusedUniqueBlocks); + double reusedUniqueBlocksPercentage = mAllocNewBlocks == 0 + ? 0.0 + : static_cast(mReusedUniqueBlocks) / static_cast(mAllocNewBlocks) * 100.0; + double cacheHitRate = (mReusedBlocks + mMissedBlocks) == 0 ? 0.0 + : static_cast(mReusedBlocks) + / (static_cast(mReusedBlocks) + static_cast(mMissedBlocks)); + TLLM_LOG_DEBUG("%s - total allocated blocks: %d ", mLogPrefix.c_str(), mAllocTotalBlocks); + TLLM_LOG_DEBUG("%s - allocated new blocks: %d ", mLogPrefix.c_str(), mAllocNewBlocks); + TLLM_LOG_DEBUG("%s - missed blocks: %d ", mLogPrefix.c_str(), mMissedBlocks); + TLLM_LOG_DEBUG("%s - reused blocks: %d ", mLogPrefix.c_str(), mReusedBlocks); + TLLM_LOG_DEBUG("%s - reused unique blocks: %d ", mLogPrefix.c_str(), mReusedUniqueBlocks); TLLM_LOG_DEBUG( "%s - reused unique blocks percentage (%%): %.2f ", mLogPrefix.c_str(), reusedUniqueBlocksPercentage); TLLM_LOG_DEBUG("%s - cache hit rate: %.2f ", mLogPrefix.c_str(), cacheHitRate); TLLM_LOG_DEBUG("%s - reused tokens: %.0f ", mLogPrefix.c_str(), mReusedTokens); TLLM_LOG_DEBUG("%s - reused tokens percentage (%%): %.2f ", mLogPrefix.c_str(), - 100.0 * mReusedTokens / mTotalInputTokens); + mTotalInputTokens == 0.0 ? 
0.0 : 100.0 * mReusedTokens / mTotalInputTokens); } bool BlockManager::verifyQueueIntegrity(SizeType32 windowSize) @@ -1410,7 +1385,7 @@ WindowBlockManager::ClaimResult WindowBlockManager::claimMatchingBlocks(Generati // Phase 1: Walk radix tree, claim matching blocks — no onboard, no getFreeBlock // NOTE: Caller must hold mCachedBlocksRootMutex. - // Compute shareLastContextBlockAmongBeams — aligned with loadOrAllocateBlocks (PR #10437). + // Compute shareLastContextBlockAmongBeams for the batch-add allocation path. auto const beamWidth = sequence.getBeamWidth(); bool const isShareLastContextBlock = mCacheType == CacheType::kCROSS || inputLength % mTokensPerBlock == 0; result.numSharedContextBlocks @@ -1922,197 +1897,6 @@ std::shared_ptr WindowBlockManager::searchReuseTree(std::vector const& blockKeys, SizeType32 inputLength, - SizeType32 numContextBlocks, GenerationRequest& sequence, - std::vector const& perBlockRetentions, executor::KvCacheTransferMode mode, - std::string const& directory, bool isEnableBlockReuse) -{ - std::lock_guard lock(mCachedBlocksRootMutex); - SizeType32 numMatchedTokens{0}; - SizeType32 latestMatchingNonPlaceholderBlockIdx{-1}; - auto searchRoot = mCachedBlocksRoot; - std::set reusedBlockIds; - - // The last block can be shared between beams if it is fully filled (won't be written to during generation) - // or if this is cross-attention KV cache (read-only). Otherwise, allocate a unique block per beam. - auto const beamWidth = sequence.getBeamWidth(); - bool const isShareLastContextBlock = mCacheType == CacheType::kCROSS || inputLength % mTokensPerBlock == 0; - SizeType32 numSharedContextBlocks - = (beamWidth > 1 && !isShareLastContextBlock) ? numContextBlocks - 1 : numContextBlocks; - - auto blockItr = blockKeys.begin(); - for (int bi = 0; bi < numSharedContextBlocks; ++bi) - { - auto [partialMatch, numMatched, matchingBlock] - = searchRoot != nullptr && blockItr != blockKeys.end() && isEnableBlockReuse - ? searchRoot->findMatchingBlock(*blockItr, mEnablePartialReuse, mCopyOnPartialReuse) - : std::make_tuple(false, 0, nullptr); - if (isRecurrentState()) - { - TLLM_CHECK(partialMatch == false); - } - if (matchingBlock != nullptr && numMatchedTokens + numMatched <= sequence.getCurrentPrepopulatedPromptLen()) - { - KVCacheBlock::IdType matchingBlockId = matchingBlock->getBlockId(); - - numMatchedTokens += numMatched > 0 ? 
numMatched : blockItr->uniqueTokens.size(); - if (!matchingBlock->isPlaceholder()) - { - latestMatchingNonPlaceholderBlockIdx = bi; - } - if (perBlockRetentions[bi].retentionPriority.has_value() - && matchingBlock->getPriority() != perBlockRetentions[bi].retentionPriority && mEventManager) - { - mEventManager->enqueueUpdatedEvent( - tle::KVCacheUpdatedData(matchingBlock->getHash()) - .priorityUpdated(matchingBlock->getPriority(), *perBlockRetentions[bi].retentionPriority), - mWindowSize); - } - if (partialMatch) - { - if (matchingBlock->hasRefs() || !matchingBlock->isLeaf()) - { - // Somebody else is using block or it is not a leaf, copy reusable tokens - auto newBlock = getFreeBlock( - sequence, matchingBlock->getPriority(), matchingBlock->getDurationMs(), mode, directory); - mTransferManager->onboard(matchingBlock, newBlock, mPools, numMatched, mode, directory); - // TODO: (optional) Send out event - matchingBlock = newBlock; - if (blockItr != blockKeys.end()) - { - matchingBlock->setBlockKey( - *blockItr, blockItr->uniqueTokens.size() == static_cast(mTokensPerBlock)); - } - matchingBlock->setHash(); - TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks for request %lu - Copied partially filled block %d", - mLogPrefix.c_str(), sequence.getRequestId(), matchingBlockId); - } - else - { - // Leaf block that nobody is using. Make block private and reuse - freeLeafBlock(matchingBlock); - mEvictionPolicy->claimBlock( - matchingBlock, perBlockRetentions[bi].retentionPriority, perBlockRetentions[bi].durationMs); - TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks for request %lu - Reused partially filled block %d", - mLogPrefix.c_str(), sequence.getRequestId(), matchingBlockId); - } - searchRoot = nullptr; // no matching needed for following blocks - } - else - { - searchRoot = matchingBlock; - // Recover block and reuse - mEvictionPolicy->claimBlock( - matchingBlock, perBlockRetentions[bi].retentionPriority, perBlockRetentions[bi].durationMs); - TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks for request %lu - Matched full block %d", mLogPrefix.c_str(), - sequence.getRequestId(), matchingBlockId); - } - onboardBlock(sequence, matchingBlock, mode, directory); - addBlockToAllBeams(matchingBlock, sequence); - if (!matchingBlock->isPlaceholder()) - { - // TODO: only add once for reused blocks - ++mReusedBlocks; - if (!reusedBlockIds.count(matchingBlockId)) - { - reusedBlockIds.insert(matchingBlockId); - ++mReusedUniqueBlocks; - } - if (partialMatch) - { - ++mPartialReusedBlocks; - } - else - { - ++mFullReusedBlocks; - } - } - ++blockItr; - } - else // matchingBlock == nullptr || numMatchedTokens + numMatched > sequence.getCurrentPrepopulatedPromptLen() - { - BlockPtr freeBlock; - bool shouldAllocate = true; - if (isRecurrentState()) - { - if (isEnableBlockReuse) - { - // loadOrAllocateBlocks is only called by addSequence, which ensures it's the first chunk, so the - // token num always starts from 0. - shouldAllocate = mLinearAttentionMetadata->shouldAllocateRecurrentStates( - /*currentBlockEndTokenIdx=*/(bi + 1) * mTokensPerBlock, inputLength, mTokensPerBlock); - } - else - { - // When block reuse is disabled, only the last context block needs real memory to store the - // current recurrent state. All other blocks are placeholders. - shouldAllocate = (bi == numContextBlocks - 1); - } - TLLM_LOG_DEBUG( - "%s::loadOrAllocateBlocks - Recurrent state block %d. 
shouldAllocate=%d for sequence %lu", - mLogPrefix.c_str(), bi, shouldAllocate, sequence.getRequestId()); - } - - // If we haven't set a priority, set it to the default priority level (low) - freeBlock = getFreeBlock(sequence, - perBlockRetentions[bi].retentionPriority.value_or( - executor::KvCacheRetentionConfig::kDefaultRetentionPriority), - perBlockRetentions[bi].durationMs, mode, directory, /*wantPlaceholder=*/!shouldAllocate); - addBlockToAllBeams(freeBlock, sequence); - TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - No match, allocated new block %d for sequence %lu", - mLogPrefix.c_str(), freeBlock->getBlockId(), sequence.getRequestId()); - // allBlockStats.emplace_back(freeBlock, "N"); - searchRoot = nullptr; // no matching needed for following blocks - if (blockItr != blockKeys.end()) - { - freeBlock->setBlockKey( - *blockItr, blockItr->uniqueTokens.size() == static_cast(mTokensPerBlock)); - ++blockItr; - } - freeBlock->setHash(); - ++mMissedBlocks; - } - } - - // Allocate new blocks that cannot be shared by multiple beams. - for (int bi = numSharedContextBlocks; bi < numContextBlocks; ++bi) - { - // TODO: Still look for match. Clone matching block or allocate fresh ones. - // This work is described in JIRA task https://jirasw.nvidia.com/browse/TRTLLM-2069. - for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx) - { - // If we haven't set a priority, set it to the default priority level (low) - auto freeBlock = getFreeBlock(sequence, - perBlockRetentions[bi].retentionPriority.value_or( - executor::KvCacheRetentionConfig::kDefaultRetentionPriority), - perBlockRetentions[bi].durationMs, mode, directory); - addBlockToBeam(freeBlock, sequence, beamIdx); - if (blockItr != blockKeys.end()) - { - freeBlock->setBlockKey( - *blockItr, blockItr->uniqueTokens.size() == static_cast(mTokensPerBlock)); - ++blockItr; - } - freeBlock->setHash(); - TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Beam %d. 
Allocated non-shared block %d for bi %d", - mLogPrefix.c_str(), beamIdx, freeBlock->getBlockId(), bi); - } - ++mMissedBlocks; - if (blockItr != blockKeys.end()) - { - ++blockItr; - } - } - - if (isRecurrentState()) - { - // purge tailing placeholder blocks - numMatchedTokens = (latestMatchingNonPlaceholderBlockIdx + 1) * mTokensPerBlock; - } - sequence.setCurrentPrepopulatedPromptLen(numMatchedTokens); - return sequence.getCurrentPrepopulatedPromptLen(); -} - void BlockManager::syncTransferManagerWithBufferManager() { for (auto& [_, manager] : mWindowBlockManagers) @@ -2140,13 +1924,6 @@ void WindowBlockManager::refreshBlocks() mTransferManager->syncTransfers(); } -SizeType32 BlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, - LlmRequest& llmRequest, SizeType32 windowSize, bool isEnableBlockReuse) -{ - return mWindowBlockManagers.at(windowSize) - .addSequence(sequence, inputLength, numContextBlocks, llmRequest, isEnableBlockReuse); -} - std::vector BlockManager::addSequenceBatch( std::vector const& sequences, std::vector const& inputLengths, std::vector const& numContextBlocksVec, @@ -2156,83 +1933,6 @@ std::vector BlockManager::addSequenceBatch( .addSequenceBatch(sequences, inputLengths, numContextBlocksVec, llmRequests, isEnableBlockReuse); } -SizeType32 WindowBlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, - SizeType32 numContextBlocks, LlmRequest& llmRequest, bool isEnableBlockReuse) -{ - TLLM_CHECK_WITH_INFO(!(isRecurrentState()) || inputLength == llmRequest.getPromptLen(), - "Recurrent state does not support CP or truncation yet."); - auto const requestId = sequence.getRequestId(); - auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector{}); - TLLM_CHECK(emplaceDone); - - auto constexpr beamIdx = 0; - bool const isSelfCache = mCacheType == CacheType::kSELF || mCacheType == CacheType::kSELFKONLY; - - // For cross KV cache without encoder tokens (e.g., encoder-decoder models with feature inputs like Whisper), - // encoder unique tokens are not available. Use empty block keys since block reuse is disabled for cross - // KV cache and unique tokens are only needed for radix tree lookup. - bool const hasUniqueTokens = isSelfCache - || (llmRequest.getEncoderUniqueTokens().has_value() && llmRequest.getEncoderUniqueTokens().value()); - std::vector blockKeys; - VecUniqueTokens const* uniqueTokensPtr = nullptr; - - if (hasUniqueTokens) - { - auto const& uniqueTokens - = isSelfCache ? 
llmRequest.getUniqueTokens(beamIdx) : *(llmRequest.getEncoderUniqueTokens().value()); - uniqueTokensPtr = &uniqueTokens; - - // Ignore last token because it can't be recovered - auto blockedUniqueTokens - = chopVectorIntoBlocks(uniqueTokens, inputLength - 1, mTokensPerBlock, true); - // Add empty block if last token is separated - if (inputLength % mTokensPerBlock == 1) - { - blockedUniqueTokens.emplace_back(); - } - - blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest); - } - - auto config = llmRequest.getKvCacheRetentionConfig(); - - auto perBlockRetentions = config.value_or(executor::KvCacheRetentionConfig()) - .getPerBlockRetentionPriorityDuration(getTokensPerBlock(), inputLength); - - auto mode = config.value_or(executor::KvCacheRetentionConfig()).getTransferMode(); - auto directory = config.value_or(executor::KvCacheRetentionConfig()).getDirectory(); - - if (mode != executor::KvCacheTransferMode::DRAM && directory.empty()) - { - TLLM_LOG_WARNING( - "Transfer mode %d specified without directory, falling back to DRAM mode", static_cast(mode)); - mode = executor::KvCacheTransferMode::DRAM; - } - - TLLM_CHECK(perBlockRetentions.size() == (size_t) numContextBlocks); - - auto const prepopulatedPromptLen = loadOrAllocateBlocks( - blockKeys, inputLength, numContextBlocks, sequence, perBlockRetentions, mode, directory, isEnableBlockReuse); - mReusedTokens += static_cast(prepopulatedPromptLen); - mTotalInputTokens += static_cast(uniqueTokensPtr ? uniqueTokensPtr->size() : inputLength); - - SizeType32 numConnectorMatchedTokens = 0; - - // If we're using a KV cache connector, check if any additional blocks can be loaded. - if (mKvCacheConnectorManager && !llmRequest.isDummyRequest()) - { - numConnectorMatchedTokens = mKvCacheConnectorManager->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen); - } - - // Return the total prepopulated length for this window (do not set on llmRequest here - - // the caller KVCacheManager::addSequence will use the minimum across all windows) - auto const totalPrepopulatedLen = prepopulatedPromptLen + numConnectorMatchedTokens; - TLLM_LOG_DEBUG( - "%s::addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numConnectorMatchedTokens %d", - mLogPrefix.c_str(), llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numConnectorMatchedTokens); - return totalPrepopulatedLen; -} - void BlockManager::adjustBlocksIfNeeded(GenerationRequest& sequence) { for (auto& [windowSize, manager] : mWindowBlockManagers) @@ -3278,7 +2978,7 @@ SizeType32 KVCacheManager::getNeededBlocksOneStep(LlmRequest const& req, bool tw = cachedSummary.has_value() ? cachedSummary.value() : analyzePrefixReuse(req.getUniqueTokens(0), req); auto const numReusableBlocks = summary.reusableBlocksAllocated; auto const promptInputLen = std::min(req.mPromptLen, windowSize + chunkSize); - // `addSequence()` ignores the last prompt token because its KV cannot be recovered. + // Sequence insertion ignores the last prompt token because its KV cannot be recovered. // When the prompt lands exactly on a block boundary, counting reusable full blocks from // all unique tokens can over-credit one extra shared block. 
TLLM_CHECK_WITH_INFO(promptInputLen > 0, "Unexpected: promptInputLen == 0"); @@ -3548,90 +3248,6 @@ PrefixReuseSummary KVCacheManager::analyzePrefixReuse( return mBlockManager.analyzePrefixReuse(uniqueTokens, llmRequest); } -void KVCacheManager::addSequence( - RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, OptionalRef llmRequest) -{ - // TODO: add streamLLM support - auto kvCacheRetentionConfig = llmRequest - ? llmRequest->getKvCacheRetentionConfig().value_or(executor::KvCacheRetentionConfig()) - : executor::KvCacheRetentionConfig(); - - auto const [seqIt, emplaceDone] = [&] - { - auto lck = std::scoped_lock(mSequencesMtx); - return mSequences.try_emplace(requestId, requestId, inputLength, beamWidth, - mBlockManager.getWindowSizesMetadata(), kvCacheRetentionConfig); - }(); - TLLM_CHECK(emplaceDone); - auto& sequence = seqIt->second; - - // Get statistics for block allocations/reuse pre request. - SizeType32 const numAllocTotalBlocksPreRequest = mBlockManager.getNumAllocTotalBlocks(); - SizeType32 const numAllocNewBlocksPreRequest = mBlockManager.getNumAllocNewBlocks(); - SizeType32 const numReusedBlocksPreRequest = mBlockManager.getNumReusedBlocks(); - SizeType32 const numMissedBlocksPreRequest = mBlockManager.getNumMissedBlocks(); - - if (!mBlockManager.isSequenceHeld(requestId)) - { - mBlockManager.holdSequence(requestId); - TLLM_LOG_DEBUG( - "[kv cache manager] Encounter new sequence %d, initialize sequence storage validity for all window sizes", - requestId); - } - else - { - TLLM_LOG_DEBUG( - "[kv cache manager] Encounter existing sequence %d, skip sequence storage validity initialization", - requestId); - } - // Track the minimum prepopulated length across all windows (for VSWA with mixed isSWA flags) - SizeType32 minPrepopulatedPromptLen = std::numeric_limits::max(); - - for (auto const [windowSize, metadata] : mBlockManager.getWindowSizesMetadata()) - { - // NOTE: Caller to KVCacheManager::addSequence should deal with the chunking - auto const maxTokenNum = metadata.maxTokenNum; - auto const temporaryAttentionWindow = metadata.temporaryAttentionWindow; - - // Consider the temporaryAttentionWindow when allocating blocks. - auto const effectiveInputLength = std::min(inputLength, maxTokenNum + temporaryAttentionWindow); - auto const numContextBlocks = tc::ceilDiv(effectiveInputLength, getTokensPerBlock()); - if (!mEnableBlockReuse && llmRequest && llmRequest->getKvCacheRetentionConfig().has_value()) - { - TLLM_LOG_WARNING( - "Request %d has a retention configuration set, but block reuse is disabled. The retention " - "config will have no effect.", - llmRequest->mRequestId); - } - auto const prepopulatedLen = mBlockManager.addSequence( - sequence, effectiveInputLength, numContextBlocks, *llmRequest, windowSize, mEnableBlockReuse); - // Use the minimum prepopulated length across all windows to ensure correctness - // when there's a mix of SWA and non-SWA windows (e.g., VSWA case) - minPrepopulatedPromptLen = std::min(minPrepopulatedPromptLen, prepopulatedLen); - mBlockManager.updateSequenceCacheBlockOffsets(sequence, windowSize); - } - - // Set the prepopulated prompt length once using the minimum across all windows - if (llmRequest && mEnableBlockReuse) - { - TLLM_LOG_DEBUG("KVCacheManager::addSequence: Setting prepopulatedPromptLen to %d", minPrepopulatedPromptLen); - llmRequest->setPrepopulatedPromptLen(minPrepopulatedPromptLen, getTokensPerBlock()); - // Clear the scheduling estimate now that the authoritative value is set. 
- // This prevents subsequent chunks from double-counting reusable tokens. - llmRequest->setEstimatedReusableTokens(0); - } - - if (llmRequest) - { - // Update statistics for block allocations/reuse per request. - llmRequest->updateAllocTotalBlocksPerRequest( - mBlockManager.getNumAllocTotalBlocks() - numAllocTotalBlocksPreRequest); - llmRequest->updateAllocNewBlocksPerRequest(mBlockManager.getNumAllocNewBlocks() - numAllocNewBlocksPreRequest); - llmRequest->updateReusedBlocksPerRequest(mBlockManager.getNumReusedBlocks() - numReusedBlocksPreRequest); - llmRequest->updateMissedBlocksPerRequest(mBlockManager.getNumMissedBlocks() - numMissedBlocksPreRequest); - } -} - void KVCacheManager::addSequenceBatch( std::vector> const& requestInfos, std::vector> const& llmRequests) @@ -3673,7 +3289,7 @@ void KVCacheManager::addSequenceBatch( } // Track the minimum prepopulated length across all windows per sequence - // (for VSWA with mixed isSWA flags, mirrors KVCacheManager::addSequence logic) + // (for VSWA with mixed isSWA flags). std::vector minPrepopulatedLen(n, std::numeric_limits::max()); // Accumulate block allocation stats across all windows per sequence std::vector totalAllocTotalDelta(n, 0); diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp index 9b5b71377b7b..5cdb2171b0e6 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -265,7 +265,7 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst, // to a block. When a new block copy is scheduled, we wait for all writes to the source // block and all reads and writes to a destination block. // -// As before, syncTransfers() must be called after last call to KVCacheManager::addSequence. +// As before, syncTransfers() must be called after the last call to KVCacheManager::addSequenceBatch. // Failing to do so will lead to corrupted blocks eventually. 
// diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index 295114b00711..6fa0337cdfad 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -121,12 +121,6 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager NB_OVERRIDE_PURE(getTokenCount, requestId); } - void addSequence(tb::LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, - tensorrt_llm::common::OptionalRef llmRequest = std::nullopt) override - { - NB_OVERRIDE_PURE(addSequence, requestId, inputLength, beamWidth, llmRequest); - } - void addSequenceBatch( std::vector> const& requestInfos, std::vector> const& llmRequests) override @@ -434,7 +428,6 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) nb::arg("window_size"), nb::arg("cached_summary") = std::nullopt, nb::call_guard()) .def("add_token", &BaseKVCacheManager::addToken, nb::call_guard()) .def("get_token_count", &BaseKVCacheManager::getTokenCount, nb::arg("request_id")) - .def("add_sequence", &BaseKVCacheManager::addSequence, nb::call_guard()) .def( "add_sequence_batch", [](tbk::BaseKVCacheManager& self, nb::list requestInfosList, nb::list llmRequestsList) diff --git a/cpp/tests/unit_tests/batch_manager/capacitySchedulerTest.cpp b/cpp/tests/unit_tests/batch_manager/capacitySchedulerTest.cpp index 5082c2d7e219..4e8b58ae6fda 100644 --- a/cpp/tests/unit_tests/batch_manager/capacitySchedulerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/capacitySchedulerTest.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -361,8 +361,9 @@ int runTest(CapacityScheduler& capacityScheduler, { if (llmReq->isDisaggGenerationInitState()) { - kvCacheManager->addSequence( - llmReq->mRequestId, llmReq->mPromptLen, llmReq->mSamplingConfig.beamWidth, llmReq); + kvCacheManager->addSequenceBatch( + {{{llmReq->mRequestId, llmReq->mPromptLen, llmReq->mSamplingConfig.beamWidth}}}, + {std::ref(*llmReq)}); llmReq->setState(LlmRequestState::kGENERATION_IN_PROGRESS); llmReq->setContextCurrentPosition(llmReq->mPromptLen); llmReq->setDecodingIter(1); @@ -385,12 +386,13 @@ int runTest(CapacityScheduler& capacityScheduler, if (llmReq->isFirstContextChunk()) { // We need to perform initialization work for the first context chunk. 
- kvCacheManager->addSequence( - llmReq->mRequestId, promptLen, llmReq->mSamplingConfig.beamWidth, llmReq); + kvCacheManager->addSequenceBatch( + {{{llmReq->mRequestId, promptLen, llmReq->mSamplingConfig.beamWidth}}}, {std::ref(*llmReq)}); if (crossKvCacheManager) { - crossKvCacheManager->addSequence(llmReq->mRequestId, llmReq->getEncoderOutputLen(), - llmReq->mSamplingConfig.beamWidth, llmReq); + crossKvCacheManager->addSequenceBatch( + {{{llmReq->mRequestId, llmReq->getEncoderOutputLen(), llmReq->mSamplingConfig.beamWidth}}}, + {std::ref(*llmReq)}); } } auto preContextLength = llmReq->getContextChunkSize(); diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp index d1cdeb3c16bb..24a648be5178 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp @@ -170,8 +170,8 @@ TEST_F(KVCacheManagerTest, BlockManagerTest) LlmRequest::RequestIdType{requestId}, maxNewTokens, inputTokensNotAligned, samplingConfig, isStreaming); GenerationRequest seq0{requestId, numTokensNotAligned, beamWidth, blockManager.getWindowSizesMetadata()}; blockManager.holdSequence(seq0.getRequestId()); - (void) blockManager.addSequence( - seq0, numTokensNotAligned, numBlocksPerBeam, *llmReq0, maxAttentionWindow, /*isEnableBlockReuse=*/false); + (void) blockManager.addSequenceBatch({&seq0}, {numTokensNotAligned}, {numBlocksPerBeam}, {std::ref(*llmReq0)}, + maxAttentionWindow, /*isEnableBlockReuse=*/false); auto constexpr occupiedBlocks = (numBlocksPerBeam - 1) + beamWidth; EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - occupiedBlocks); auto const& ids = seq0.getCacheBlockIds(maxAttentionWindow); @@ -192,8 +192,8 @@ TEST_F(KVCacheManagerTest, BlockManagerTest) LlmRequest::RequestIdType{requestId}, maxNewTokens, inputTokensAligned, samplingConfig, isStreaming); GenerationRequest seq0b{requestId, numTokens, beamWidth, blockManager.getWindowSizesMetadata()}; blockManager.holdSequence(seq0b.getRequestId()); - (void) blockManager.addSequence( - seq0b, numTokens, numBlocksPerBeam, *llmReq1, maxAttentionWindow, /*isEnableBlockReuse=*/false); + (void) blockManager.addSequenceBatch({&seq0b}, {numTokens}, {numBlocksPerBeam}, {std::ref(*llmReq1)}, + maxAttentionWindow, /*isEnableBlockReuse=*/false); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocksPerBeam); auto const& idsShared = seq0b.getCacheBlockIds(maxAttentionWindow); EXPECT_EQ(idsShared.size(), beamWidth); @@ -212,29 +212,32 @@ TEST_F(KVCacheManagerTest, BlockManagerTest) LlmRequest::RequestIdType{requestId}, maxNewTokens, inputTokensNotAligned, samplingConfig, isStreaming); GenerationRequest seq0c{requestId, numTokensNotAligned, beamWidth, blockManager.getWindowSizesMetadata()}; blockManager.holdSequence(seq0c.getRequestId()); - EXPECT_NO_THROW((void) blockManager.addSequence( - seq0c, numTokensNotAligned, numBlocksPerBeam, *llmReq2, maxAttentionWindow, /*isEnableBlockReuse=*/false)); + EXPECT_NO_THROW((void) blockManager.addSequenceBatch({&seq0c}, {numTokensNotAligned}, {numBlocksPerBeam}, + {std::ref(*llmReq2)}, maxAttentionWindow, /*isEnableBlockReuse=*/false)); auto llmReq3 = std::make_shared( LlmRequest::RequestIdType{requestId + 1}, maxNewTokens, inputTokensNotAligned, samplingConfig, isStreaming); GenerationRequest seq1{requestId + 1, numTokensNotAligned, beamWidth, blockManager.getWindowSizesMetadata()}; blockManager.holdSequence(seq1.getRequestId()); - EXPECT_NO_THROW((void) 
blockManager.addSequence( - seq1, numTokensNotAligned, numBlocksPerBeam, *llmReq3, maxAttentionWindow, /*isEnableBlockReuse=*/false)); + EXPECT_NO_THROW((void) blockManager.addSequenceBatch({&seq1}, {numTokensNotAligned}, {numBlocksPerBeam}, + {std::ref(*llmReq3)}, maxAttentionWindow, + /*isEnableBlockReuse=*/false)); // same requestId not allowed auto llmReq4 = std::make_shared( LlmRequest::RequestIdType{requestId}, maxNewTokens, inputTokensNotAligned, samplingConfig, isStreaming); GenerationRequest seq2{requestId, numTokensNotAligned, beamWidth, blockManager.getWindowSizesMetadata()}; blockManager.holdSequence(seq2.getRequestId()); - EXPECT_THROW((void) blockManager.addSequence(seq2, numTokensNotAligned, numBlocksPerBeam, *llmReq4, - maxAttentionWindow, /*isEnableBlockReuse=*/false), + EXPECT_THROW((void) blockManager.addSequenceBatch({&seq2}, {numTokensNotAligned}, {numBlocksPerBeam}, + {std::ref(*llmReq4)}, maxAttentionWindow, + /*isEnableBlockReuse=*/false), std::runtime_error); // no more blocks auto llmReq5 = std::make_shared( LlmRequest::RequestIdType{requestId + 2}, maxNewTokens, inputTokensNotAligned, samplingConfig, isStreaming); GenerationRequest seq3{requestId + 2, numTokensNotAligned, beamWidth, blockManager.getWindowSizesMetadata()}; blockManager.holdSequence(seq3.getRequestId()); - EXPECT_THROW((void) blockManager.addSequence(seq3, numTokensNotAligned, numBlocksPerBeam, *llmReq5, - maxAttentionWindow, /*isEnableBlockReuse=*/false), + EXPECT_THROW((void) blockManager.addSequenceBatch({&seq3}, {numTokensNotAligned}, {numBlocksPerBeam}, + {std::ref(*llmReq5)}, maxAttentionWindow, + /*isEnableBlockReuse=*/false), std::runtime_error); } @@ -346,8 +349,11 @@ void runPartialCopyTest() auto promptLen0 = llmRequest0->getNumTokens(beamIdx); auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq0.getRequestId()); - auto prepopulatedPromptLen0 = blockManager.addSequence( - seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen0 = blockManager + .addSequenceBatch({&seq0}, {promptLen0}, {numContextBlocks0}, + {std::ref(*llmRequest0)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); auto cacheBlockIds = seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx); @@ -397,8 +403,11 @@ void runPartialCopyTest() auto promptLen1 = llmRequest1->getNumTokens(beamIdx); auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq1.getRequestId()); - auto prepopulatedPromptLen1 = blockManager.addSequence( - seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen1 = blockManager + .addSequenceBatch({&seq1}, {promptLen1}, {numContextBlocks1}, + {std::ref(*llmRequest1)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 16); auto cacheBlockIds1 = seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx); @@ -425,8 +434,11 @@ void runPartialCopyTest() auto promptLen2 = llmRequest2->getNumTokens(beamIdx); auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock()); 
blockManager.holdSequence(seq2.getRequestId()); - auto prepopulatedPromptLen2 = blockManager.addSequence( - seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen2 = blockManager + .addSequenceBatch({&seq2}, {promptLen2}, {numContextBlocks2}, + {std::ref(*llmRequest2)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 11); auto cacheBlockIds2 = seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx); @@ -1021,8 +1033,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) // add request with 11 tokens again and make sure no discarded tokens reuse happens // input tokens [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 12] // reuse blocks 0, 1, 2(p) ([0, 1, 2, 3], [4, 5, 6, 7], [8]) - // nb! LlmRequest retains state calculated during addSequence, this state affects result. - // Calling addSequence a second time with same LlmRequest object will produce incorrect state. + // nb! LlmRequest retains state calculated during addSequenceBatch, this state affects result. + // Calling addSequenceBatch a second time with same LlmRequest object will produce incorrect state. // Create new llmRequest4 instance to avoid this issue. GenerationRequest seq4_dup{14, numTokens, beamWidth, blockManager.getWindowSizesMetadata()}; llmRequest4 = std::make_shared( @@ -1166,8 +1178,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) auto promptLen0 = llmRequest0->getNumTokens(beamIdx); auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq0.getRequestId()); - auto prepopulatedPromptLen0 = blockManager.addSequence( - seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen0 = blockManager + .addSequenceBatch({&seq0}, {promptLen0}, {numContextBlocks0}, + {std::ref(*llmRequest0)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); @@ -1201,8 +1216,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) auto promptLen1 = llmRequest1->getNumTokens(beamIdx); auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq1.getRequestId()); - auto prepopulatedPromptLen1 = blockManager.addSequence( - seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen1 = blockManager + .addSequenceBatch({&seq1}, {promptLen1}, {numContextBlocks1}, + {std::ref(*llmRequest1)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock); EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3})); @@ -1231,8 +1249,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) promptLen0 = llmRequest0->getNumTokens(beamIdx); numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); 
blockManager.holdSequence(seq0_dup.getRequestId()); - prepopulatedPromptLen0 = blockManager.addSequence( - seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow, /*isEnableBlockReuse=*/true); + prepopulatedPromptLen0 = blockManager + .addSequenceBatch({&seq0_dup}, {promptLen0}, {numContextBlocks0}, + {std::ref(*llmRequest0)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock()); llmRequest0->addNewToken(3, beamIdx); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 2 * tokensPerBlock); @@ -1255,8 +1276,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) promptLen1 = llmRequest1->getNumTokens(beamIdx); numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq1_dup.getRequestId()); - prepopulatedPromptLen1 = blockManager.addSequence( - seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow, /*isEnableBlockReuse=*/true); + prepopulatedPromptLen1 = blockManager + .addSequenceBatch({&seq1_dup}, {promptLen1}, {numContextBlocks1}, + {std::ref(*llmRequest1)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), llmRequest1->getNumTokens(beamIdx) - 1); EXPECT_THAT(seq1_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); @@ -1293,8 +1317,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) auto promptLen2 = llmRequest2->getNumTokens(beamIdx); auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq2.getRequestId()); - auto prepopulatedPromptLen2 = blockManager.addSequence( - seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen2 = blockManager + .addSequenceBatch({&seq2}, {promptLen2}, {numContextBlocks2}, + {std::ref(*llmRequest2)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 0); EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({5, 6, 7})); @@ -1321,8 +1348,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) auto promptLen3 = llmRequest3->getNumTokens(beamIdx); auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq3.getRequestId()); - auto prepopulatedPromptLen3 = blockManager.addSequence( - seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen3 = blockManager + .addSequenceBatch({&seq3}, {promptLen3}, {numContextBlocks3}, + {std::ref(*llmRequest3)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest3->getContextCurrentPosition(), tokensPerBlock); EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 8, 9})); @@ -1402,8 +1432,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest) auto promptLen0 = llmRequest0->getNumTokens(beamIdx); 
auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq0.getRequestId()); - auto prepopulatedPromptLen0 = blockManager.addSequence( - seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen0 = blockManager + .addSequenceBatch({&seq0}, {promptLen0}, {numContextBlocks0}, + {std::ref(*llmRequest0)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); @@ -1443,8 +1476,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest) auto promptLen1 = llmRequest1->getNumTokens(beamIdx); auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq1.getRequestId()); - auto prepopulatedPromptLen1 = blockManager.addSequence( - seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen1 = blockManager + .addSequenceBatch({&seq1}, {promptLen1}, {numContextBlocks1}, + {std::ref(*llmRequest1)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock); EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3})); @@ -1482,8 +1518,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest) auto promptLen2 = llmRequest2->getNumTokens(beamIdx); auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq2.getRequestId()); - auto prepopulatedPromptLen2 = blockManager.addSequence( - seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen2 = blockManager + .addSequenceBatch({&seq2}, {promptLen2}, {numContextBlocks2}, + {std::ref(*llmRequest2)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 0); EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({4, 5, 6})); @@ -1518,8 +1557,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest) auto promptLen3 = llmRequest3->getNumTokens(beamIdx); auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq3.getRequestId()); - auto prepopulatedPromptLen3 = blockManager.addSequence( - seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen3 = blockManager + .addSequenceBatch({&seq3}, {promptLen3}, {numContextBlocks3}, + {std::ref(*llmRequest3)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest3->getContextCurrentPosition(), tokensPerBlock); // only reuse block 0 [100, 101, 102, 103] with same hash/offset @@ -1592,8 +1634,11 @@ 
TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); // get new blocks 0, 1, 2 ([0,1,2,3], [4,5,6,7], [8]) blockManager.holdSequence(seq0.getRequestId()); - auto prepopulatedPromptLen0 = blockManager.addSequence( - seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen0 = blockManager + .addSequenceBatch({&seq0}, {promptLen0}, {numContextBlocks0}, + {std::ref(*llmRequest0)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); @@ -1625,8 +1670,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) auto promptLen1 = llmRequest1->getNumTokens(beamIdx); auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq1.getRequestId()); - auto prepopulatedPromptLen1 = blockManager.addSequence( - seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen1 = blockManager + .addSequenceBatch({&seq1}, {promptLen1}, {numContextBlocks1}, + {std::ref(*llmRequest1)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock); EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3})); @@ -1654,8 +1702,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) promptLen0 = llmRequest0->getNumTokens(beamIdx); numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq0_dup.getRequestId()); - prepopulatedPromptLen0 = blockManager.addSequence( - seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow, /*isEnableBlockReuse=*/true); + prepopulatedPromptLen0 = blockManager + .addSequenceBatch({&seq0_dup}, {promptLen0}, {numContextBlocks0}, + {std::ref(*llmRequest0)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock()); // nb! addNewToken adds new generated token, number of input tokens stay the same. 
// calling addNewToken before addSequence potentially triggers this error message: @@ -1679,8 +1730,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); // reuse 0, 1, 2(p) ([0,1,2,3], [4,5,6,7], [8]) blockManager.holdSequence(seq1_dup.getRequestId()); - prepopulatedPromptLen1 = blockManager.addSequence( - seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow, /*isEnableBlockReuse=*/true); + prepopulatedPromptLen1 = blockManager + .addSequenceBatch({&seq1_dup}, {promptLen1}, {numContextBlocks1}, + {std::ref(*llmRequest1)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), llmRequest1->getNumTokens(beamIdx) - 1); EXPECT_THAT(seq1_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); @@ -1715,8 +1769,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) auto promptLen2 = llmRequest2->getNumTokens(beamIdx); auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq2.getRequestId()); - auto prepopulatedPromptLen2 = blockManager.addSequence( - seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen2 = blockManager + .addSequenceBatch({&seq2}, {promptLen2}, {numContextBlocks2}, + {std::ref(*llmRequest2)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock()); // no reuse expected. Input tokens match blocks 0 and 1, but lora task id differs. 
EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 0); @@ -1747,8 +1804,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) auto promptLen3 = llmRequest3->getNumTokens(beamIdx); auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq3.getRequestId()); - auto prepopulatedPromptLen3 = blockManager.addSequence( - seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen3 = blockManager + .addSequenceBatch({&seq3}, {promptLen3}, {numContextBlocks3}, + {std::ref(*llmRequest3)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest3->getContextCurrentPosition(), promptLen3 - 2); EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({5, 6, 7})); @@ -1780,8 +1840,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) auto promptLen4 = llmRequest4->getNumTokens(beamIdx); auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq4.getRequestId()); - auto prepopulatedPromptLen4 = blockManager.addSequence( - seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen4 = blockManager + .addSequenceBatch({&seq4}, {promptLen4}, {numContextBlocks4}, + {std::ref(*llmRequest4)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest4->getContextCurrentPosition(), tokensPerBlock); EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 8})); @@ -1808,8 +1871,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) auto promptLen5 = llmRequest5->getNumTokens(beamIdx); auto numContextBlocks5 = tc::ceilDiv(promptLen5, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq5.getRequestId()); - auto prepopulatedPromptLen5 = blockManager.addSequence( - seq5, promptLen5, numContextBlocks5, *llmRequest5, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen5 = blockManager + .addSequenceBatch({&seq5}, {promptLen5}, {numContextBlocks5}, + {std::ref(*llmRequest5)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest5->setPrepopulatedPromptLen(prepopulatedPromptLen5, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest5->getContextCurrentPosition(), 0); EXPECT_THAT(seq5.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({9, 10, 11})); @@ -1881,8 +1947,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) auto promptLen0 = llmRequest0->getNumTokens(beamIdx); auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq0.getRequestId()); - auto prepopulatedPromptLen0 = blockManager.addSequence( - seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen0 = blockManager + .addSequenceBatch({&seq0}, {promptLen0}, {numContextBlocks0}, + {std::ref(*llmRequest0)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, 
blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); @@ -1917,8 +1986,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) auto promptLen1 = llmRequest1->getNumTokens(beamIdx); auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq1.getRequestId()); - auto prepopulatedPromptLen1 = blockManager.addSequence( - seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen1 = blockManager + .addSequenceBatch({&seq1}, {promptLen1}, {numContextBlocks1}, + {std::ref(*llmRequest1)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 0); EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 5})); @@ -1947,8 +2019,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); // reuse blocks 0, 1 and get new block 6 blockManager.holdSequence(seq0_dup.getRequestId()); - prepopulatedPromptLen0 = blockManager.addSequence( - seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow, /*isEnableBlockReuse=*/true); + prepopulatedPromptLen0 = blockManager + .addSequenceBatch({&seq0_dup}, {promptLen0}, {numContextBlocks0}, + {std::ref(*llmRequest0)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock()); llmRequest0->addNewToken(3, beamIdx); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 2 * tokensPerBlock); @@ -1971,8 +2046,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) promptLen1 = llmRequest1->getNumTokens(beamIdx); numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq1_dup.getRequestId()); - prepopulatedPromptLen1 = blockManager.addSequence( - seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow, /*isEnableBlockReuse=*/true); + prepopulatedPromptLen1 = blockManager + .addSequenceBatch({&seq1_dup}, {promptLen1}, {numContextBlocks1}, + {std::ref(*llmRequest1)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), llmRequest1->getNumTokens(beamIdx) - 1); EXPECT_THAT(seq1_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 5})); @@ -2008,8 +2086,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) auto promptLen2 = llmRequest2->getNumTokens(beamIdx); auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq2.getRequestId()); - auto prepopulatedPromptLen2 = blockManager.addSequence( - seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen2 = blockManager + .addSequenceBatch({&seq2}, {promptLen2}, {numContextBlocks2}, + {std::ref(*llmRequest2)}, maxAttentionWindow, 
/*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 0); EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({7, 8, 9})); @@ -2036,8 +2117,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) auto promptLen3 = llmRequest3->getNumTokens(beamIdx); auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq3.getRequestId()); - auto prepopulatedPromptLen3 = blockManager.addSequence( - seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen3 = blockManager + .addSequenceBatch({&seq3}, {promptLen3}, {numContextBlocks3}, + {std::ref(*llmRequest3)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest3->getContextCurrentPosition(), tokensPerBlock); EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 10, 11})); @@ -2063,8 +2147,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) auto promptLen4 = llmRequest4->getNumTokens(beamIdx); auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq4.getRequestId()); - auto prepopulatedPromptLen4 = blockManager.addSequence( - seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen4 = blockManager + .addSequenceBatch({&seq4}, {promptLen4}, {numContextBlocks4}, + {std::ref(*llmRequest4)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest4->getContextCurrentPosition(), tokensPerBlock); EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 12, 13})); @@ -2146,8 +2233,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) auto promptLen0 = llmRequest0->getNumTokens(beamIdx); auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq0.getRequestId()); - auto prepopulatedPromptLen0 = blockManager.addSequence( - seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen0 = blockManager + .addSequenceBatch({&seq0}, {promptLen0}, {numContextBlocks0}, + {std::ref(*llmRequest0)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); @@ -2187,8 +2277,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) auto promptLen1 = llmRequest1->getNumTokens(beamIdx); auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq1.getRequestId()); - auto prepopulatedPromptLen1 = blockManager.addSequence( - seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto 
prepopulatedPromptLen1 = blockManager + .addSequenceBatch({&seq1}, {promptLen1}, {numContextBlocks1}, + {std::ref(*llmRequest1)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 0); // No reuse, starts from scratch EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 5})); @@ -2223,8 +2316,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) auto promptLen2 = llmRequest2->getNumTokens(beamIdx); auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq2.getRequestId()); - auto prepopulatedPromptLen2 = blockManager.addSequence( - seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen2 = blockManager + .addSequenceBatch({&seq2}, {promptLen2}, {numContextBlocks2}, + {std::ref(*llmRequest2)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 2 * tokensPerBlock); // Reuse blocks 3,4 EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 6})); @@ -2260,8 +2356,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) auto promptLen3 = llmRequest3->getNumTokens(beamIdx); auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq3.getRequestId()); - auto prepopulatedPromptLen3 = blockManager.addSequence( - seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen3 = blockManager + .addSequenceBatch({&seq3}, {promptLen3}, {numContextBlocks3}, + {std::ref(*llmRequest3)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest3->getContextCurrentPosition(), 0); // No reuse EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({7, 8, 9})); @@ -2289,8 +2388,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) auto promptLen4 = llmRequest4->getNumTokens(beamIdx); auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq4.getRequestId()); - auto prepopulatedPromptLen4 = blockManager.addSequence( - seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen4 = blockManager + .addSequenceBatch({&seq4}, {promptLen4}, {numContextBlocks4}, + {std::ref(*llmRequest4)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest4->getContextCurrentPosition(), 2 * tokensPerBlock); // Reuse blocks 0,1 EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 10})); @@ -2415,8 +2517,11 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest) GenerationRequest seq0{0, inputLength0, beamWidth, blockManager.getWindowSizesMetadata()}; auto numContextBlocks0 = tc::ceilDiv(inputLength0, 
blockManager.getTokensPerBlock()); blockManager.holdSequence(seq0.getRequestId()); - auto prepopulatedPromptLen0 = blockManager.addSequence(seq0, llmRequest0->getNumTokens(0), numContextBlocks0, - *llmRequest0, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen0 = blockManager + .addSequenceBatch({&seq0}, {llmRequest0->getNumTokens(0)}, {numContextBlocks0}, + {std::ref(*llmRequest0)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock()); // Add another sequence with different tokens, at a low priority @@ -2426,8 +2531,11 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest) GenerationRequest seq1{1, inputLength1, beamWidth, blockManager.getWindowSizesMetadata()}; auto numContextBlocks1 = tc::ceilDiv(inputLength1, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq1.getRequestId()); - auto prepopulatedPromptLen1 = blockManager.addSequence(seq1, llmRequest1->getNumTokens(0), numContextBlocks1, - *llmRequest1, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen1 = blockManager + .addSequenceBatch({&seq1}, {llmRequest1->getNumTokens(0)}, {numContextBlocks1}, + {std::ref(*llmRequest1)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock()); // Release both sequences @@ -2447,8 +2555,11 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest) GenerationRequest seq2{2, inputLength2, beamWidth, blockManager.getWindowSizesMetadata()}; auto numContextBlocks2 = tc::ceilDiv(inputLength2, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq2.getRequestId()); - auto prepopulatedPromptLen2 = blockManager.addSequence(seq2, llmRequest2->getNumTokens(0), numContextBlocks2, - *llmRequest2, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen2 = blockManager + .addSequenceBatch({&seq2}, {llmRequest2->getNumTokens(0)}, {numContextBlocks2}, + {std::ref(*llmRequest2)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock()); tensorrt_llm::testing::KvCacheManagerTestUtil::simulatePrefillCompletion(*llmRequest2); blockManager.releaseBlocks(seq2, llmRequest2); @@ -2461,8 +2572,11 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest) GenerationRequest seq3{3, inputLength3, beamWidth, blockManager.getWindowSizesMetadata()}; auto numContextBlocks3 = tc::ceilDiv(inputLength3, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq3.getRequestId()); - auto prepopulatedPromptLen3 = blockManager.addSequence(seq3, llmRequest3->getNumTokens(0), numContextBlocks3, - *llmRequest3, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen3 = blockManager + .addSequenceBatch({&seq3}, {llmRequest3->getNumTokens(0)}, {numContextBlocks3}, + {std::ref(*llmRequest3)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest3->getContextCurrentPosition(), 4); @@ -2479,8 +2593,11 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest) GenerationRequest seq4{4, inputLength3, beamWidth, blockManager.getWindowSizesMetadata()}; auto numContextBlocks4 = tc::ceilDiv(inputLength4, 
blockManager.getTokensPerBlock()); blockManager.holdSequence(seq4.getRequestId()); - auto prepopulatedPromptLen4 = blockManager.addSequence(seq4, llmRequest4->getNumTokens(0), numContextBlocks4, - *llmRequest4, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen4 = blockManager + .addSequenceBatch({&seq4}, {llmRequest4->getNumTokens(0)}, {numContextBlocks4}, + {std::ref(*llmRequest4)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest4->getContextCurrentPosition(), 4); @@ -2492,8 +2609,11 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest) GenerationRequest seq5{5, inputLength5, beamWidth, blockManager.getWindowSizesMetadata()}; auto numContextBlocks5 = tc::ceilDiv(inputLength5, blockManager.getTokensPerBlock()); blockManager.holdSequence(seq5.getRequestId()); - auto prepopulatedPromptLen5 = blockManager.addSequence(seq5, llmRequest5->getNumTokens(0), numContextBlocks5, - *llmRequest5, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto prepopulatedPromptLen5 = blockManager + .addSequenceBatch({&seq5}, {llmRequest5->getNumTokens(0)}, {numContextBlocks5}, + {std::ref(*llmRequest5)}, maxAttentionWindow, /*isEnableBlockReuse=*/true) + .front() + .prepopulatedLen; llmRequest5->setPrepopulatedPromptLen(prepopulatedPromptLen5, blockManager.getTokensPerBlock()); EXPECT_EQ(llmRequest5->getContextCurrentPosition(), 0); @@ -6250,11 +6370,11 @@ TEST(KVCacheManagerReuseAccountingTest, ReuseAwareBlockEstimatesStayConsistentAf // Verify estimatedReusableTokens is still set after getRemainingBlocksToCompletion EXPECT_EQ(req1.getEstimatedReusableTokens(), expectedReusableBlocks * tokensPerBlock); - // After addSequence, context blocks are allocated (reuse already applied during allocation) + // After addSequenceBatch, context blocks are allocated (reuse already applied during allocation) // Only generation blocks remain to be allocated kvCacheManager->addSequenceBatch({{{req1.mRequestId, req1.getPromptLen(), maxBeamWidth}}}, {std::ref(req1)}); - // Verify estimatedReusableTokens is cleared to 0 after addSequence + // Verify estimatedReusableTokens is cleared to 0 after addSequenceBatch EXPECT_EQ(req1.getEstimatedReusableTokens(), 0); auto const remainingAfterContextAlloc = kvCacheManager->getRemainingBlocksToCompletion(req1, onlyWindowSize); @@ -6849,7 +6969,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventRemovedOrderedBeforeStore) (void) getEvents(kvCacheManager); // drain // Seq1 with different tokens. - // addSequence: evicts seq0's block0 (and its descendant block1) — removes buffered, not yet emitted. + // addSequenceBatch: evicts seq0's block0 (and its descendant block1) — removes buffered, not yet emitted. // storeContextBlocks: calls flushRemovedEvents(W) first, committing the buffered removes, // then appends the Stored event for seq1's new blocks. auto inputTokens1 = std::make_shared(VecTokens{100, 101, 102, 103, 104, 105, 106, 107, 108}); @@ -6944,7 +7064,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStoreForDifferentWindowDoesNotFlus (void) getEvents(kvCacheManager); // drain // Seq1 with different tokens (9 tokens → 3 blocks per window). 
- // addSequence for each window: gets block3 (fresh, no event), block2 (not in tree, no event), + // addSequenceBatch for each window: gets block3 (fresh, no event), block2 (not in tree, no event), // then block1 (in tree as leaf) → freeChildren(block1) → Removed(block1) buffered for that window. // storeContextBlocks: // wSWA: skipped (SWA) — wSWA removes stay buffered @@ -7057,10 +7177,10 @@ void testBlockManagerLinearAttention_ContextNoReuse(int beamWidth, int numTokens auto llmReq0 = std::make_shared( LlmRequest::RequestIdType{requestId}, maxNewTokens, inputTokens0, samplingConfig, isStreaming); GenerationRequest seq0{requestId, numTokens, beamWidth, blockManager.getWindowSizesMetadata()}; - (void) blockManager.addSequence( - seq0, numTokens, numBlocksPerBeam, *llmReq0, linearWindowSizeCode, /*isEnableBlockReuse=*/false); - (void) blockManager.addSequence( - seq0, numTokens, numBlocksPerBeam, *llmReq0, maxAttentionWindow, /*isEnableBlockReuse=*/false); + (void) blockManager.addSequenceBatch({&seq0}, {numTokens}, {numBlocksPerBeam}, {std::ref(*llmReq0)}, + linearWindowSizeCode, /*isEnableBlockReuse=*/false); + (void) blockManager.addSequenceBatch({&seq0}, {numTokens}, {numBlocksPerBeam}, {std::ref(*llmReq0)}, + maxAttentionWindow, /*isEnableBlockReuse=*/false); blockManager.holdSequence(seq0.getRequestId()); // When block reuse is disabled, only the last context block has real memory. // Whether the last block is shared depends on whether inputLength is aligned to tokensPerBlock. @@ -7098,10 +7218,10 @@ void testBlockManagerLinearAttention_ContextNoReuse(int beamWidth, int numTokens TLLM_LOG_DEBUG("=========================================================="); // reuse disabled: re-add after release, verify block sharing and count - (void) blockManager.addSequence( - seq0, numTokens, numBlocksPerBeam, *llmReq0, linearWindowSizeCode, /*isEnableBlockReuse=*/false); - (void) blockManager.addSequence( - seq0, numTokens, numBlocksPerBeam, *llmReq0, maxAttentionWindow, /*isEnableBlockReuse=*/false); + (void) blockManager.addSequenceBatch({&seq0}, {numTokens}, {numBlocksPerBeam}, {std::ref(*llmReq0)}, + linearWindowSizeCode, /*isEnableBlockReuse=*/false); + (void) blockManager.addSequenceBatch({&seq0}, {numTokens}, {numBlocksPerBeam}, {std::ref(*llmReq0)}, + maxAttentionWindow, /*isEnableBlockReuse=*/false); ASSERT_EQ( blocksInPrimaryPool - blockManager.getNumFreeBlocksPerWindowSize()[linearWindowSizeCode], occupiedBlocksLinear); auto const& ids2 = seq0.getCacheBlockIds(linearWindowSizeCode); @@ -7129,15 +7249,17 @@ void testBlockManagerLinearAttention_ContextNoReuse(int beamWidth, int numTokens auto llmReqLoop = std::make_shared(LlmRequest::RequestIdType{static_cast(requestId + 1 + i)}, maxNewTokens, inputTokens0, samplingConfig, isStreaming); - ASSERT_NO_THROW((void) blockManager.addSequence( - seq, numTokens, numBlocksPerBeam, *llmReqLoop, linearWindowSizeCode, /*isEnableBlockReuse=*/false)); + ASSERT_NO_THROW((void) blockManager.addSequenceBatch({&seq}, {numTokens}, {numBlocksPerBeam}, + {std::ref(*llmReqLoop)}, linearWindowSizeCode, + /*isEnableBlockReuse=*/false)); } // no more blocks GenerationRequest seq3{requestId + 1 + i, numTokens, beamWidth, blockManager.getWindowSizesMetadata()}; auto llmReq3 = std::make_shared(LlmRequest::RequestIdType{static_cast(requestId + 1 + i)}, maxNewTokens, inputTokens0, samplingConfig, isStreaming); - ASSERT_THROW(blockManager.addSequence( - seq3, numTokens, numBlocksPerBeam, *llmReq3, linearWindowSizeCode, /*isEnableBlockReuse=*/false), + 
ASSERT_THROW(blockManager.addSequenceBatch({&seq3}, {numTokens}, {numBlocksPerBeam}, {std::ref(*llmReq3)}, + linearWindowSizeCode, + /*isEnableBlockReuse=*/false), std::runtime_error); } @@ -7186,12 +7308,10 @@ void testBlockManagerLinearAttention_ContextReuse(int beamWidth, int numTokens0, // reuse enabled: basic allocation GenerationRequest seq0{requestId, numTokens0, beamWidth, blockManager.getWindowSizesMetadata()}; - (void) blockManager.addSequence(seq0, numTokens0, tc::ceilDiv(numTokens0, tokensPerBlock), *llmRequest0, - linearWindowSizeCode, - /*isEnableBlockReuse=*/true); - (void) blockManager.addSequence(seq0, numTokens0, tc::ceilDiv(numTokens0, tokensPerBlock), *llmRequest0, - maxAttentionWindow, - /*isEnableBlockReuse=*/true); + (void) blockManager.addSequenceBatch({&seq0}, {numTokens0}, {tc::ceilDiv(numTokens0, tokensPerBlock)}, + {std::ref(*llmRequest0)}, linearWindowSizeCode, /*isEnableBlockReuse=*/true); + (void) blockManager.addSequenceBatch({&seq0}, {numTokens0}, {tc::ceilDiv(numTokens0, tokensPerBlock)}, + {std::ref(*llmRequest0)}, maxAttentionWindow, /*isEnableBlockReuse=*/true); blockManager.holdSequence(seq0.getRequestId()); ASSERT_EQ(llmRequest0->getContextCurrentPosition(), 0); int regularSnapshots = numTokens0 / linearAttentionMetadata.statesSnapshotInterval; @@ -7247,12 +7367,10 @@ void testBlockManagerLinearAttention_ContextReuse(int beamWidth, int numTokens0, auto llmRequestNoise = std::make_shared(9999, numTokens1, inputTokensNoise, samplingConfig, isStreaming); GenerationRequest seqNoise{9999, numTokens1, beamWidth, blockManager.getWindowSizesMetadata()}; - (void) blockManager.addSequence(seqNoise, numTokens1, tc::ceilDiv(numTokens1, tokensPerBlock), *llmRequestNoise, - linearWindowSizeCode, - /*isEnableBlockReuse=*/true); - (void) blockManager.addSequence(seqNoise, numTokens1, tc::ceilDiv(numTokens1, tokensPerBlock), *llmRequestNoise, - maxAttentionWindow, - /*isEnableBlockReuse=*/true); + (void) blockManager.addSequenceBatch({&seqNoise}, {numTokens1}, {tc::ceilDiv(numTokens1, tokensPerBlock)}, + {std::ref(*llmRequestNoise)}, linearWindowSizeCode, /*isEnableBlockReuse=*/true); + (void) blockManager.addSequenceBatch({&seqNoise}, {numTokens1}, {tc::ceilDiv(numTokens1, tokensPerBlock)}, + {std::ref(*llmRequestNoise)}, maxAttentionWindow, /*isEnableBlockReuse=*/true); blockManager.holdSequence(seqNoise.getRequestId()); auto inputTokens1 = std::make_shared(); @@ -7267,12 +7385,10 @@ void testBlockManagerLinearAttention_ContextReuse(int beamWidth, int numTokens0, auto llmRequest1 = std::make_shared(1, numTokens1, inputTokens1, samplingConfig, isStreaming); GenerationRequest seq1{1, numTokens1, beamWidth, blockManager.getWindowSizesMetadata()}; - (void) blockManager.addSequence(seq1, numTokens1, tc::ceilDiv(numTokens1, tokensPerBlock), *llmRequest1, - linearWindowSizeCode, - /*isEnableBlockReuse=*/true); - (void) blockManager.addSequence(seq1, numTokens1, tc::ceilDiv(numTokens1, tokensPerBlock), *llmRequest1, - maxAttentionWindow, - /*isEnableBlockReuse=*/true); + (void) blockManager.addSequenceBatch({&seq1}, {numTokens1}, {tc::ceilDiv(numTokens1, tokensPerBlock)}, + {std::ref(*llmRequest1)}, linearWindowSizeCode, /*isEnableBlockReuse=*/true); + (void) blockManager.addSequenceBatch({&seq1}, {numTokens1}, {tc::ceilDiv(numTokens1, tokensPerBlock)}, + {std::ref(*llmRequest1)}, maxAttentionWindow, /*isEnableBlockReuse=*/true); blockManager.holdSequence(seq1.getRequestId()); @@ -7793,7 +7909,7 @@ INSTANTIATE_TEST_SUITE_P(BlockManagerLinearAttention, 
LinearAttentionBlockCopyin )); /////////////////////////////////////////////////////////////////////////////// -// Batch addSequenceBatch corner-case tests +// addSequenceBatch corner-case tests // // These tests verify the two-phase claim-then-onboard strategy when multiple // requests in a single addSequenceBatch call compete for the same radix tree @@ -8165,9 +8281,9 @@ TEST_F(KVCacheManagerTest, BatchAddSequence_NonLeafCopySourceTightPool) // Request that partially matches block0 (non-leaf) and needs ALL remaining blocks. // Tokens: [0,1,50,...] → partial match on block0 (2 tokens), then needs many fresh blocks. - // Total: 29 tokens = 8 blocks (ceil(29/4)=8). All 8 pool blocks needed. + // Total: 32 tokens = 8 blocks (32 / 4 = 8). All 8 pool blocks are needed. auto bigTokens = std::make_shared(VecTokens{0, 1, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, - 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80}); + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}); auto req = std::make_shared( LlmRequest::RequestIdType{10}, SizeType32{0}, bigTokens, tr::SamplingConfig{beamWidth}, false); auto inputLen = static_cast(bigTokens->size()); @@ -8175,12 +8291,10 @@ TEST_F(KVCacheManagerTest, BatchAddSequence_NonLeafCopySourceTightPool) // Without the shouldReleaseCopySource fix, this would throw "No free block found" // because the claimed non-leaf copy source would not be released. - if (numBlocks <= blocksInPrimaryPool) - { - EXPECT_NO_THROW(kvCacheManager.addSequenceBatch({{{10, inputLen, beamWidth}}}, {std::ref(*req)})); - tensorrt_llm::testing::KvCacheManagerTestUtil::simulatePrefillCompletion(*req); - (void) kvCacheManager.removeSequence(10, req); - } + ASSERT_LE(numBlocks, blocksInPrimaryPool); + EXPECT_NO_THROW(kvCacheManager.addSequenceBatch({{{10, inputLen, beamWidth}}}, {std::ref(*req)})); + tensorrt_llm::testing::KvCacheManagerTestUtil::simulatePrefillCompletion(*req); + (void) kvCacheManager.removeSequence(10, req); } // Test 8: Mixed batch — one request fully matches, another has no match at all. diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheUtilsTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheUtilsTest.cpp index 6c4dfff913fd..b8574262d49a 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheUtilsTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/kvCacheUtilsTest.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -121,8 +121,9 @@ TEST_F(BlockIteratorTest, CacheManagerTest) auto constexpr beamIdx = 0; auto promptLen0 = llmRequest0->getNumTokens(beamIdx); auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); - blockManager.addSequence( - seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow, /*isEnableBlockReuse=*/true); + auto const batchSeqStats = blockManager.addSequenceBatch({&seq0}, {promptLen0}, {numContextBlocks0}, + {std::ref(*llmRequest0)}, maxAttentionWindow, /*isEnableBlockReuse=*/true); + ASSERT_THAT(batchSeqStats, ::testing::SizeIs(1)); auto const blockIds = seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx); EXPECT_THAT(blockIds, ::testing::ElementsAreArray({0, 1, 2})); diff --git a/cpp/tests/unit_tests/batch_manager/microBatchSchedulerTest.cpp b/cpp/tests/unit_tests/batch_manager/microBatchSchedulerTest.cpp index 881e2c706f37..20ac06eaacd5 100644 --- a/cpp/tests/unit_tests/batch_manager/microBatchSchedulerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/microBatchSchedulerTest.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -1301,8 +1301,8 @@ TEST_F(CombinedSchedulerTest, CapacitySchedulerSetsReusableTokensForMicroBatch) // Request 0 should be scheduled ASSERT_GE(scheduled0.size(), 1u); - // Process request 0: addSequence → complete context → store blocks - kvCacheManager->addSequence(req0->mRequestId, promptLen, /*beamWidth=*/1, req0); + // Process request 0: addSequenceBatch → complete context → store blocks + kvCacheManager->addSequenceBatch({{{req0->mRequestId, promptLen, /*beamWidth=*/1}}}, {std::ref(*req0)}); req0->moveToNextContextChunk(); tensorrt_llm::testing::KvCacheManagerTestUtil::simulatePrefillCompletion(*req0); kvCacheManager->storeContextBlocks(*req0); @@ -1403,7 +1403,7 @@ TEST_F(CombinedSchedulerTest, CapacitySchedulerReusableTokensWithChunkedMicroBat auto [scheduled0, disaggInit0, paused0] = capacityScheduler(activeList, *kvCacheManager, /*peftCacheManager=*/std::nullopt); - kvCacheManager->addSequence(req0->mRequestId, promptLen, /*beamWidth=*/1, req0); + kvCacheManager->addSequenceBatch({{{req0->mRequestId, promptLen, /*beamWidth=*/1}}}, {std::ref(*req0)}); req0->moveToNextContextChunk(); tensorrt_llm::testing::KvCacheManagerTestUtil::simulatePrefillCompletion(*req0); kvCacheManager->storeContextBlocks(*req0); diff --git a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp index fae2e7d3a365..3f9fd5cc707b 100644 --- a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp +++ b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -333,7 +333,8 @@ class SymmetricalCacheTest : public ::testing::Test // NOLINT(cppcoreguidelines- { auto constexpr beamIdx{0}; auto constexpr beamWidth{1}; - mManager->addSequence(llmRequest->mRequestId, llmRequest->getNumTokens(beamIdx), beamWidth, llmRequest); + mManager->addSequenceBatch( + {{{llmRequest->mRequestId, llmRequest->getNumTokens(beamIdx), beamWidth}}}, {std::ref(*llmRequest)}); if (isSender) { auto blockRange = BlockRange::fromAllBlockIds(*mManager, llmRequest->mRequestId); @@ -927,7 +928,8 @@ class AsymmetricalCacheTest : public ::testing::TestWithParammLlmRequest; - mManager->addSequence(llmRequest->mRequestId, llmRequest->getNumTokens(beamIdx), beamWidth, llmRequest); + mManager->addSequenceBatch( + {{{llmRequest->mRequestId, llmRequest->getNumTokens(beamIdx), beamWidth}}}, {std::ref(*llmRequest)}); auto blockRange = BlockRange::fromAllBlockIds(*mManager, llmRequest->mRequestId); int const numPools = mManager->getBlockManager().getNumPools( @@ -979,7 +981,8 @@ class AsymmetricalCacheTest : public ::testing::TestWithParammLlmRequest; - mManager->addSequence(llmRequest->mRequestId, llmRequest->getNumTokens(beamIdx), beamWidth, llmRequest); + mManager->addSequenceBatch( + {{{llmRequest->mRequestId, llmRequest->getNumTokens(beamIdx), beamWidth}}}, {std::ref(*llmRequest)}); return mRequester->receiveAsync(*llmRequest); } diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py index af049007ecce..d6122d573d02 100644 --- a/tensorrt_llm/_torch/attention_backend/interface.py +++ b/tensorrt_llm/_torch/attention_backend/interface.py @@ -22,7 +22,7 @@ from ..memory_buffer_utils import Buffers from ..metadata import KVCacheParams -from ..pyexecutor.mamba_cache_manager import MambaCacheManager +from ..pyexecutor.mamba_cache_manager import BaseMambaCacheManager from ..pyexecutor.resource_manager import KVCacheManager, KVCacheManagerV2 from ..utils import get_model_extra_attrs @@ -305,8 +305,7 @@ def _prepare_mamba_metadata(self): return if self.mamba_metadata is None: - if (self.kv_cache_manager is not None - and isinstance(self.kv_cache_manager, MambaCacheManager)): + if isinstance(self.kv_cache_manager, BaseMambaCacheManager): from ..modules.mamba.mamba2_metadata import Mamba2Metadata self.mamba_metadata = Mamba2Metadata(self.max_num_requests, self.mamba_chunk_size) diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index e020bb6a48ac..82896a888a68 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -51,7 +51,7 @@ from ...._utils import get_free_port, mpi_rank, mpi_world_size from ....mapping import Mapping from ...distributed import Distributed -from ...pyexecutor.mamba_cache_manager import MambaHybridCacheManager +from ...pyexecutor.mamba_cache_manager import BaseMambaCacheManager from ...pyexecutor.model_engine import ModelEngine, PyTorchModelEngine from ...pyexecutor.py_executor import PyExecutor from ...pyexecutor.resource_manager import ( @@ -293,7 +293,7 @@ def _generate_dummy_request( ) # check if it's a hybrid kv-cache manager - is_hybrid_cache = isinstance(kv_cache_manager, MambaHybridCacheManager) + is_hybrid_cache = isinstance(kv_cache_manager, BaseMambaCacheManager) # check if we have a free page and free state available if not 
kv_cache_manager.get_num_free_blocks(): diff --git a/tensorrt_llm/_torch/auto_deploy/shim/interface.py b/tensorrt_llm/_torch/auto_deploy/shim/interface.py index 4ce155c601f2..027e977eb22c 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/interface.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/interface.py @@ -13,7 +13,10 @@ from tensorrt_llm.mapping import Mapping from ...._utils import torch_dtype_to_binding - from ...pyexecutor.mamba_cache_manager import MambaHybridCacheManager + from ...pyexecutor.mamba_cache_manager import ( + MambaHybridCacheManager, + MixedMambaHybridCacheManager, + ) from ...pyexecutor.resource_manager import KVCacheManager CacheTypeCpp = tensorrt_llm.bindings.internal.batch_manager.CacheType @@ -523,7 +526,7 @@ def _create_and_assign_state_views( num_managed_mamba_layers = mamba_params["mamba_num_layers"] # Create the hybrid cache manager - manager = MambaHybridCacheManager( + manager = MixedMambaHybridCacheManager( **mamba_params, **kv_cache_kwargs, ) diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index f9988fc9442d..ad6d44289951 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -4,27 +4,32 @@ import tempfile from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, Generic, List, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar import filelock import torch import transformers from transformers.utils import HF_MODULES_CACHE -from tensorrt_llm import logger from tensorrt_llm._torch.pyexecutor.config_utils import ( - get_qwen3_hybrid_num_attention_layers, is_nemotron_hybrid, is_qwen3_hybrid, - load_pretrained_config) + get_qwen3_hybrid_num_attention_layers, is_hybrid_linear, is_nemotron_hybrid, + is_qwen3_hybrid, load_pretrained_config) from tensorrt_llm._utils import get_sm_version, torch_dtype_to_binding from tensorrt_llm.bindings import LayerType as LayerTypeCpp from tensorrt_llm.functional import AllReduceStrategy from tensorrt_llm.llmapi.llm_args import (DeepSeekSparseAttentionConfig, - MoeLoadBalancerConfig) + KvCacheConfig, MoeLoadBalancerConfig) from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig from tensorrt_llm.quantization.mode import QuantAlgo +if TYPE_CHECKING: + from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp + from tensorrt_llm.llmapi.llm_args import (DecodingBaseConfig, LoraConfig, + SparseAttentionConfig, + SpeculativeConfig) + TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig) @@ -636,9 +641,13 @@ def _recursive_update_config(config: transformers.PretrainedConfig, model_config._frozen = True return model_config - def get_bindings_model_config(self, - tokens_per_block: Optional[int] = None - ) -> "ModelConfigCpp": + def get_bindings_model_config( + self, + is_disagg: bool = False, + tokens_per_block: Optional[int] = None, + kv_cache_config: Optional[KvCacheConfig] = None, + spec_config: Optional['SpeculativeConfig'] = None, + ) -> "ModelConfigCpp": """ This method is used to construct the bindings config for the model. 
Currently it adheres to gptJsonConfig.cpp::createModelConfig, which assumes @@ -667,7 +676,8 @@ def ceil_div(a, b): hidden_size = ceil_div(self.pretrained_config.hidden_size, attn_tp_size) num_layers = self.pretrained_config.num_hidden_layers - num_attention_layers = self.get_num_attention_layers() + num_attention_layers = self.get_num_attention_layers( + is_disagg, kv_cache_config, spec_config) if (self.spec_config is not None and self.spec_config.spec_dec_mode.is_mtp_one_model()): num_layers += self.spec_config.num_nextn_predict_layers @@ -693,6 +703,7 @@ def ceil_div(a, b): num_key_value_heads = getattr(self.pretrained_config, "num_key_value_heads", num_heads) + if isinstance(num_key_value_heads, (list, tuple)): # Per-layer KV heads (e.g., Nemotron-NAS, variable GQA models) num_kv_heads_per_layer = [ @@ -796,10 +807,35 @@ def get_layer_types(self) -> Optional[List[LayerTypeCpp]]: else: return None - def get_num_attention_layers(self): - if is_nemotron_hybrid(self.pretrained_config): + def get_num_attention_layers( + self, + is_disagg: bool, + kv_cache_config: Optional[KvCacheConfig] = None, + spec_config: Optional['SpeculativeConfig'] = None): + """Return the number of layers that need KV cache blocks. + + For hybrid models using the MixedMambaHybridCacheManager path + (disaggregated serving, TRTLLM_USE_CPP_MAMBA=1, or speculative + decoding), only attention layers need KV cache blocks, so we return + the attention-only count. + + For the default CppMambaHybridCacheManager path, both attention and + mamba layers are managed in the unified KV cache pool, so we return + num_hidden_layers (all layers). + """ + use_disagg = is_disagg or os.environ.get('TRTLLM_USE_CPP_MAMBA', '0') == '1' + use_reuse = kv_cache_config is not None and kv_cache_config.enable_block_reuse + use_spec = spec_config is not None + + use_v1_mamba_manager = use_disagg or use_spec + if is_hybrid_linear( + self.pretrained_config) and use_v1_mamba_manager and use_reuse: + logger.warning( + "Block reuse does not work with MTP or disagg for hybrid linear models" + ) + if is_nemotron_hybrid(self.pretrained_config) and use_v1_mamba_manager: return self.pretrained_config.hybrid_override_pattern.count("*") - elif is_qwen3_hybrid(self.pretrained_config): + elif is_qwen3_hybrid(self.pretrained_config) and use_v1_mamba_manager: return get_qwen3_hybrid_num_attention_layers(self.pretrained_config) else: return self.pretrained_config.num_hidden_layers diff --git a/tensorrt_llm/_torch/modules/fla/chunk.py b/tensorrt_llm/_torch/modules/fla/chunk.py index e70f25d6e18a..ced64e46b0f7 100644 --- a/tensorrt_llm/_torch/modules/fla/chunk.py +++ b/tensorrt_llm/_torch/modules/fla/chunk.py @@ -79,7 +79,7 @@ def chunk_gated_delta_rule_fwd( class ChunkGatedDeltaRuleFunction(torch.autograd.Function): @staticmethod - @input_guard + @input_guard(exclude_args="initial_state") @autocast_custom_fwd def forward( ctx, diff --git a/tensorrt_llm/_torch/modules/fla/chunk_delta_h.py b/tensorrt_llm/_torch/modules/fla/chunk_delta_h.py index 2f8d9dbcf980..d1092ed7bda9 100644 --- a/tensorrt_llm/_torch/modules/fla/chunk_delta_h.py +++ b/tensorrt_llm/_torch/modules/fla/chunk_delta_h.py @@ -48,6 +48,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( cu_seqlens, chunk_offsets, T, + stride_h0, H: tl.constexpr, Hg: tl.constexpr, K: tl.constexpr, @@ -96,7 +97,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( stride_w = H * K if USE_INDEXED_STATE: state_index = tl.load(h0_i + i_n).to(tl.int64) - h0 = h0 + state_index * stride_h + h0 = h0 + state_index *
stride_h0 ht = h0 if USE_INITIAL_STATE: h0 = h0 + ((i_h if USE_INDEXED_STATE else i_nh) * K * V) @@ -300,6 +301,7 @@ def grid(meta): cu_seqlens=cu_seqlens, chunk_offsets=chunk_offsets, T=T, + stride_h0=initial_state.stride(0) if initial_state is not None else 0, H=H, Hg=Hg, K=K, diff --git a/tensorrt_llm/_torch/modules/mamba/gdn_mixer.py b/tensorrt_llm/_torch/modules/mamba/gdn_mixer.py index bcc9ce835838..eeec20d30f05 100644 --- a/tensorrt_llm/_torch/modules/mamba/gdn_mixer.py +++ b/tensorrt_llm/_torch/modules/mamba/gdn_mixer.py @@ -765,9 +765,14 @@ def forward( state_indices_p, state_indices_d = torch.split(state_indices, batch_split_size) if num_prefills > 0: - ssm_states[state_indices_p] = torch.zeros( + # PyExecutor guarantees prefill requests are placed before decode requests + has_initial_states_p = has_initial_states[:num_prefills] + ssm_states[state_indices_p[~has_initial_states_p]] = torch.zeros( (), dtype=ssm_states.dtype, device=ssm_states.device ) + conv_states[state_indices_p[~has_initial_states_p]] = torch.zeros( + (), dtype=conv_states.dtype, device=conv_states.device + ) is_target_verify = ( num_decodes > 0 diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 7688ae72771c..e0254ee7783b 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -29,14 +29,14 @@ from ..model_config import ModelConfig from ..speculative import (get_num_extra_kv_tokens, get_num_spec_layers, get_spec_decoder, should_use_separate_draft_kv_cache) -from .config_utils import (get_qwen3_hybrid_layer_masks, is_mla, - is_nemotron_hybrid, is_qwen3_hybrid) +from .config_utils import (get_qwen3_hybrid_layer_masks, is_hybrid_linear, + is_mla, is_nemotron_hybrid, is_qwen3_hybrid) from .connectors.kv_cache_connector import KvCacheConnectorManager from .dwdp import DwdpManager from .guided_decoder import GuidedDecoder from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver from .llm_request import ExecutorResponse -from .mamba_cache_manager import MambaHybridCacheManager +from .mamba_cache_manager import BaseMambaCacheManager, MambaHybridCacheManager from .model_engine import PyTorchModelEngine from .py_executor import PyExecutor from .resource_manager import (KVCacheManager, KVCacheManagerV2, @@ -62,7 +62,7 @@ def get_kv_cache_manager_cls(model_config: ModelConfig, sparse_attn_config = model_config.sparse_attention_config if sparse_attn_config is not None: return get_sparse_attn_kv_cache_manager(sparse_attn_config) - elif is_nemotron_hybrid(config) or is_qwen3_hybrid(config): + elif is_hybrid_linear(config): return MambaHybridCacheManager else: return KVCacheManagerV2 if kv_cache_config.use_kv_cache_manager_v2 else KVCacheManager @@ -95,6 +95,7 @@ def __init__( speculative_config: SpeculativeConfig, sparse_attention_config: SparseAttentionConfig, profiling_stage_data: Optional[dict], + is_disagg: bool, execution_stream: Optional[torch.cuda.Stream] = None, draft_config: Optional[ModelConfig] = None, skip_est: bool = False, @@ -118,6 +119,7 @@ def __init__( self._net_max_seq_len = net_max_seq_len self._dummy_reqs = None self._profiling_stage_data = profiling_stage_data + self._is_disagg = is_disagg self._execution_stream = execution_stream self._kv_cache_manager_cls = self._get_model_kv_cache_manager_cls( model_engine) @@ -580,6 +582,7 @@ def _create_kv_cache_manager( estimating_kv_cache=estimating_kv_cache, execution_stream=self._execution_stream, layer_mask=spec_dec_layer_mask, + 
is_disagg=self._is_disagg, ) if not self._skip_est: @@ -722,6 +725,7 @@ def _create_one_model_draft_kv_cache_manager( is_draft=True, layer_mask=spec_dec_layer_mask, num_layers=num_draft_layers, + is_disagg=self._is_disagg, ) def _split_kv_cache_budget_for_draft(self) -> Optional[KvCacheConfig]: @@ -885,6 +889,7 @@ def _create_kv_cache_manager( max_num_tokens: int, max_beam_width: int, kv_connector_manager: Optional[KvCacheConnectorManager], + is_disagg: bool = False, estimating_kv_cache: bool = False, execution_stream: Optional[torch.cuda.Stream] = None, # Optional overrides for one-model draft case (when model_engine is None) @@ -988,7 +993,7 @@ def _create_kv_cache_manager( # - If layer_mask[i] is True, include layer i # - For layers beyond hybrid_override_pattern, treat them as attention layers pattern_len = len(config.hybrid_override_pattern) - hybrid_layer_mask = [] + full_attention_layer_mask = [] mamba_layer_mask = [] for i, include in enumerate(layer_mask): if i < pattern_len: @@ -999,13 +1004,14 @@ def _create_kv_cache_manager( # Beyond the pattern (e.g., MTP/draft layers), treat as attention-only is_attention = True is_mamba = False - hybrid_layer_mask.append(is_attention and include) + full_attention_layer_mask.append(is_attention and include) mamba_layer_mask.append(is_mamba and include) - num_layers = sum(hybrid_layer_mask) + num_full_attention_layers = sum(full_attention_layer_mask) mamba_num_layers = sum(mamba_layer_mask) else: - num_layers = config.hybrid_override_pattern.count("*") - hybrid_layer_mask = [ + num_full_attention_layers = config.hybrid_override_pattern.count( + "*") + full_attention_layer_mask = [ char == "*" for char in config.hybrid_override_pattern ] mamba_num_layers = config.hybrid_override_pattern.count("M") @@ -1021,9 +1027,9 @@ def _create_kv_cache_manager( from ..speculative.utils import get_num_spec_layers num_spec_layers = get_num_spec_layers(spec_config) if num_spec_layers > 0: - hybrid_layer_mask.extend([True] * num_spec_layers) + full_attention_layer_mask.extend([True] * num_spec_layers) mamba_layer_mask.extend([False] * num_spec_layers) - num_layers += num_spec_layers + num_full_attention_layers += num_spec_layers kv_cache_manager = kv_cache_manager_cls( # mamba cache parameters config.ssm_state_size, @@ -1036,11 +1042,12 @@ def _create_kv_cache_manager( config.torch_dtype, quant_config.mamba_ssm_cache_dtype if quant_config is not None else None, + is_disagg, # kv cache parameters kv_cache_config, tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, - num_layers=num_layers, - layer_mask=hybrid_layer_mask, + num_layers=num_full_attention_layers, + layer_mask=full_attention_layer_mask, num_kv_heads=per_layer_num_kv_heads, head_dim=head_dim, tokens_per_block=tokens_per_block, @@ -1062,9 +1069,9 @@ def _create_kv_cache_manager( raise NotImplementedError( "Connector manager is not supported for MambaHybridCacheManager." ) - hybrid_layer_mask, mamba_layer_mask = get_qwen3_hybrid_layer_masks( + full_attention_layer_mask, mamba_layer_mask = get_qwen3_hybrid_layer_masks( config) - # For hybrid models, hybrid_layer_mask is always passed as + # For hybrid models, full_attention_layer_mask is always passed as # layer_mask to KVCacheManager, which means get_pp_layers # sees a non-None layer_mask and won't auto-add spec layers. 
# Extend the masks here to include MTP spec layers (full @@ -1073,9 +1080,9 @@ def _create_kv_cache_manager( from ..speculative.utils import get_num_spec_layers num_spec_layers = get_num_spec_layers(spec_config) if num_spec_layers > 0: - hybrid_layer_mask.extend([True] * num_spec_layers) + full_attention_layer_mask.extend([True] * num_spec_layers) mamba_layer_mask.extend([False] * num_spec_layers) - num_layers = sum(hybrid_layer_mask) + num_full_attention_layers = sum(full_attention_layer_mask) num_mamba_layers = sum(mamba_layer_mask) kv_cache_manager = kv_cache_manager_cls( # mamba cache parameters @@ -1089,11 +1096,12 @@ def _create_kv_cache_manager( config.torch_dtype, quant_config.mamba_ssm_cache_dtype if quant_config is not None else None, + is_disagg, # kv cache parameters kv_cache_config, tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, - num_layers=num_layers, - layer_mask=hybrid_layer_mask, + num_layers=num_full_attention_layers, + layer_mask=full_attention_layer_mask, num_kv_heads=per_layer_num_kv_heads, head_dim=head_dim, tokens_per_block=tokens_per_block, @@ -1110,7 +1118,9 @@ def _create_kv_cache_manager( # NOTE: this is a workaround for VSWA to switch to calculate_max_num_blocks_for_vswa in KVCahceManager is_vswa = is_vswa_enabled(kv_cache_config) binding_model_config = _model_config.get_bindings_model_config( - tokens_per_block=tokens_per_block) if is_vswa else None + tokens_per_block=tokens_per_block, + kv_cache_config=kv_cache_config, + spec_config=spec_config) if is_vswa else None kv_cache_manager = kv_cache_manager_cls( kv_cache_config, @@ -1174,7 +1184,8 @@ def create_py_executor_instance( logger.info( f"max_seq_len={max_seq_len}, max_num_requests={max_num_sequences}, max_num_tokens={max_num_tokens}, max_batch_size={max_batch_size}" ) - + is_disagg = (cache_transceiver_config is not None + and cache_transceiver_config.backend is not None) for key, value in llm_args.extra_resource_managers.items(): if key in resources: raise ValueError( @@ -1198,7 +1209,7 @@ def create_py_executor_instance( ) model_binding_config = model_engine.model.model_config.get_bindings_model_config( - ) + is_disagg=is_disagg) num_experts = _try_infer_num_experts(model_engine.model.model_config) @@ -1394,7 +1405,7 @@ def create_py_executor_instance( # For hybrid models, this has both impl and mamba_impl mamba_cache_manager = None - if isinstance(kv_cache_manager, MambaHybridCacheManager): + if isinstance(kv_cache_manager, BaseMambaCacheManager): mamba_cache_manager = kv_cache_manager kv_cache_transceiver = create_kv_cache_transceiver( diff --git a/tensorrt_llm/_torch/pyexecutor/config_utils.py b/tensorrt_llm/_torch/pyexecutor/config_utils.py index 9c3b4c37560f..11b7a6160c7c 100644 --- a/tensorrt_llm/_torch/pyexecutor/config_utils.py +++ b/tensorrt_llm/_torch/pyexecutor/config_utils.py @@ -5,6 +5,10 @@ import transformers +def is_hybrid_linear(config): + return is_nemotron_hybrid(config) or is_qwen3_hybrid(config) + + def is_nemotron_hybrid(config): if hasattr(config, "hybrid_override_pattern" ) and config.hybrid_override_pattern is not None and len( diff --git a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py index 375a3d350dbe..615994c1d81e 100644 --- a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py @@ -13,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import math import os +from abc import ABC, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Union @@ -24,12 +26,16 @@ if TYPE_CHECKING: from tensorrt_llm._torch.attention_backend.interface import \ AttentionMetadata + from tensorrt_llm.llmapi.llm_args import DecodingBaseConfig from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest from tensorrt_llm._torch.pyexecutor.resource_manager import ( BaseResourceManager, CacheTypeCpp, DataType, KVCacheManager, get_pp_layers) from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests -from tensorrt_llm._utils import prefer_pinned, torch_dtype_to_binding +from tensorrt_llm._utils import (nvtx_range, prefer_pinned, + torch_dtype_to_binding) +from tensorrt_llm.bindings.internal.batch_manager import ( + KvCacheConnectorManager, LinearAttentionMetadata, LinearCacheType) from tensorrt_llm.llmapi.llm_args import KvCacheConfig from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping @@ -58,8 +64,55 @@ def use_cpp_mamba_cache_manager() -> bool: return os.environ.get('TRTLLM_USE_CPP_MAMBA', '0') == '1' +class BaseMambaCacheManager(ABC): + """Abstract interface for accessing mamba/recurrent state caches.""" + + @abstractmethod + def get_state_indices(self, *args, **kwargs) -> torch.Tensor: + """Return slot indices of each request. + + Shape: [max_batch_size] + """ + ... + + @abstractmethod + def get_conv_states(self, layer_idx: int) -> torch.Tensor: + """Return conv states for specific layer. + + Shape: [slot_size, conv_dim, d_conv - 1] + """ + ... + + @abstractmethod + def get_ssm_states(self, layer_idx: int) -> torch.Tensor: + """Return SSM states for specific layer. + + Shape: [slot_size, num_heads, head_dim, d_state] + """ + ... + + @abstractmethod + def get_mamba_ssm_cache_dtype(self) -> torch.dtype: + ... + + @abstractmethod + def is_speculative(self) -> bool: + ... + + @abstractmethod + def mamba_layer_cache( + self, layer_idx: int + ) -> Union['PythonMambaCacheManager.State', + 'PythonMambaCacheManager.SpeculativeState', None]: + ... + + class CppMambaCacheManager(BaseResourceManager): - """C++ backed Mamba cache manager using RnnStateManager bindings.""" + """Mamba state manager backed by the C++ RnnStateManager bindings. + + Manages only mamba states (conv + SSM). Used when TRTLLM_USE_CPP_MAMBA=1. + Supports disaggregated serving. + """ def __init__( self, @@ -165,6 +218,11 @@ def shutdown(self): class PythonMambaCacheManager(BaseResourceManager): + """Pure-Python mamba state manager with speculative decoding support. + + Manages only mamba states (conv + SSM) using PyTorch tensors on GPU. + Supports speculative decoding and disaggregated serving. + """ @dataclass(frozen=True, kw_only=True) class State: @@ -499,7 +557,11 @@ def update_mamba_states(self, attn_metadata: "AttentionMetadata", conv_states[:, state_indices_d, :] = accepted_conv_state -class MambaCacheManager(BaseResourceManager): +class MambaCacheManager(BaseResourceManager, BaseMambaCacheManager): + """Facade for standalone mamba state management (no KV cache). + + Delegates to CppMambaCacheManager (when TRTLLM_USE_CPP_MAMBA=1) or PythonMambaCacheManager. 
+ """ def __init__( self, @@ -627,7 +689,13 @@ def update_mamba_states(self, attn_metadata: "AttentionMetadata", self._impl.update_mamba_states(attn_metadata, num_accepted_tokens) -class MambaHybridCacheManager(KVCacheManager, MambaCacheManager): +class MixedMambaHybridCacheManager(KVCacheManager, MambaCacheManager): + """Hybrid cache manager combining separate KVCacheManager and MambaCacheManager. + + Manages KV cache and mamba states in independent pools, with support of + speculative decoding and disaggregated serving. + Does not support block reuse / prefix caching for mamba states. + """ def __init__( self, @@ -709,9 +777,9 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): MambaCacheManager.prepare_resources(self, scheduled_batch) KVCacheManager.prepare_resources(self, scheduled_batch) - def free_resources(self, request: LlmRequest): + def free_resources(self, request: LlmRequest, pin_on_release: bool = False): MambaCacheManager.free_resources(self, request) - KVCacheManager.free_resources(self, request) + KVCacheManager.free_resources(self, request, pin_on_release) def add_dummy_requests(self, request_ids: List[int], **kwargs): MambaCacheManager.add_dummy_requests(self, request_ids) @@ -732,3 +800,437 @@ def update_mamba_states(self, attn_metadata: "AttentionMetadata", num_accepted_tokens: torch.Tensor): MambaCacheManager.update_mamba_states(self, attn_metadata, num_accepted_tokens) + + +def calc_context_stop_positions(prompt_len: int, + tokens_per_block: int, + mamba_state_cache_interval: int, + save_last_snapshot: bool = False) -> list[int]: + """Compute token positions at which mamba state snapshots should be saved. + + Returns positions spaced by ``mamba_state_cache_interval`` plus the final + prompt length (and optionally the last block-aligned position). + """ + stop_positions = list( + range(mamba_state_cache_interval, prompt_len, + mamba_state_cache_interval)) + last_ckpt = prompt_len // tokens_per_block * tokens_per_block + if save_last_snapshot and (last_ckpt not in stop_positions): + stop_positions.append(last_ckpt) + if prompt_len not in stop_positions: + stop_positions.append(prompt_len) + return stop_positions + + +class CppMambaHybridCacheManager(KVCacheManager, BaseMambaCacheManager): + """Hybrid cache manager storing mamba states inside the KVCacheManager pool. + + Both KV cache blocks and recurrent state blocks are managed by the unified + C++ KVCacheManager, enabling block reuse / prefix caching across attention + and mamba layers. This is the default hybrid manager. + + Disaggregated serving and speculative decoding are not supported yet. + """ + + def __init__( + self, + # mamba cache parameters + mamba_d_state: int, + mamba_d_conv: int, + mamba_num_heads: int, + mamba_n_groups: int, + mamba_head_dim: int, + mamba_num_layers: int, + mamba_layer_mask: List[bool], + mamba_cache_dtype: torch.dtype, + mamba_ssm_cache_dtype: torch.dtype, + kv_cache_config: KvCacheConfig, + kv_cache_type: CacheTypeCpp, + *, + num_layers: int, + num_kv_heads: Union[int, List[Optional[int]]], + head_dim: int, + tokens_per_block: int, + # Note that max_seq_len is not necessarily equal to kv_cache_config.num_tokens. + # It's derived from the model's BuildConfig for consistency with the C++ backend. 
+ max_seq_len: int, + max_batch_size: int, + mapping: Mapping, + dtype: DataType = DataType.HALF, + spec_config: Optional["DecodingBaseConfig"] = None, + layer_mask: Optional[List[bool]] = None, + max_num_tokens: int = 8192, + max_beam_width: int = 1, + is_draft: bool = False, + kv_connector_manager: Optional[KvCacheConnectorManager] = None, + enable_indexer_k_cache: bool = False, + indexer_k_cache_quant_block_size: int = 128, + indexer_k_cache_index_head_dim: int = 0, + is_estimating_kv_cache: bool = False, + **kwargs, + ) -> None: + # Derive ssm_state_shape and conv_state_shape from mamba params (same as MambaCacheManager) + tp_size = mapping.tp_size if not mapping.enable_attention_dp else 1 + d_inner = mamba_head_dim * mamba_num_heads + conv_dim = d_inner + 2 * mamba_n_groups * mamba_d_state + nheads = mamba_num_heads + assert nheads % tp_size == 0, "mamba_num_heads must be divisible by tp_size" + assert conv_dim % tp_size == 0, "conv_dim must be divisible by tp_size" + conv_dim = conv_dim // tp_size + nheads = nheads // tp_size + self.conv_state_shape = [conv_dim, mamba_d_conv - 1] + self.ssm_state_shape = [nheads, mamba_head_dim, mamba_d_state] + self.ssm_state_dtype = mamba_ssm_cache_dtype + self.conv_state_dtype = mamba_cache_dtype + self.ssm_count = math.prod(self.ssm_state_shape) + self.conv_count = math.prod(self.conv_state_shape) + self.ssm_bytes = self.ssm_count * self.ssm_state_dtype.itemsize + self.conv_bytes = self.conv_count * self.conv_state_dtype.itemsize + + total_bytes = self.ssm_bytes + self.conv_bytes + if total_bytes % self.ssm_state_dtype.itemsize != 0: + raise RuntimeError( + f"Total state bytes ({total_bytes}) not divisible by " + f"ssm_state_dtype size ({self.ssm_state_dtype.itemsize})") + if total_bytes % self.conv_state_dtype.itemsize != 0: + raise RuntimeError( + f"Total state bytes ({total_bytes}) not divisible by " + f"conv_state_dtype size ({self.conv_state_dtype.itemsize})") + if self.ssm_bytes % self.conv_state_dtype.itemsize != 0: + raise RuntimeError( + f"SSM state bytes ({self.ssm_bytes}) not divisible by " + f"conv_state_dtype size ({self.conv_state_dtype.itemsize})") + + self.linear_attention_metadata = LinearAttentionMetadata() + self.linear_attention_metadata.cache_type = LinearCacheType.RECURRENT_STATES.value + self.linear_attention_metadata.all_recurrent_states_bytes = self.ssm_bytes + self.conv_bytes + self.linear_attention_metadata.states_snapshot_interval = kv_cache_config.mamba_state_cache_interval + + if kv_cache_config.enable_partial_reuse: + logger.warning( + "Partial reuse is not supported for mamba hybrid models, disabling partial reuse" + ) + kv_cache_config.enable_partial_reuse = False + + full_attention_layer_mask = layer_mask.copy() + + kv_cache_config.max_attention_window = [] + layer_mask = [ + mamba_layer_mask[i] or full_attention_layer_mask[i] + for i in range(len(mamba_layer_mask)) + ] + for i in range(len(layer_mask)): + if layer_mask[i]: + kv_cache_config.max_attention_window.append( + LinearCacheType.RECURRENT_STATES. 
+ value if mamba_layer_mask[i] else max_seq_len) + # pass remaining arguments to super class + super().__init__( + kv_cache_config, + kv_cache_type, + num_layers=mamba_num_layers + num_layers, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + tokens_per_block=tokens_per_block, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + mapping=mapping, + dtype=dtype, + spec_config=spec_config, + layer_mask=layer_mask, + max_num_tokens=max_num_tokens, + max_beam_width=max_beam_width, + is_draft=is_draft, + kv_connector_manager=kv_connector_manager, + enable_indexer_k_cache=enable_indexer_k_cache, + indexer_k_cache_quant_block_size=indexer_k_cache_quant_block_size, + indexer_k_cache_index_head_dim=indexer_k_cache_index_head_dim, + is_estimating_kv_cache=is_estimating_kv_cache, + linear_attention_metadata=self.linear_attention_metadata, + ) + # respect base class's pp sharding + self.mamba_pp_layers = [] + for layer_idx in self.pp_layers: + if mamba_layer_mask[layer_idx]: + self.mamba_pp_layers.append(layer_idx) + self.local_num_mamba_layers = len(self.mamba_pp_layers) + + assert self.local_num_mamba_layers > 0, "At least one mamba layer is required" + self.mamba_layer_offsets = {} + for idx, layer_id in enumerate(self.mamba_pp_layers): + self.mamba_layer_offsets[layer_id] = idx + + self.host_block_offsets = torch.zeros([ + self.impl.num_pools, self.max_batch_size, 2, self.max_blocks_per_seq + ], + dtype=torch.int32, + device="cpu") + self.requests = [] + self.recurrent_states_pool_index = self.kv_cache_pool_mapping[ + self.layer_offsets[self.mamba_pp_layers[0]]][0] + self.cuda_state_indices = torch.zeros([self.max_batch_size], + dtype=torch.int32, + device="cuda") + self.kv_cache_config = kv_cache_config + + self._setup_states_views() + + self.is_estimating_kv_cache = is_estimating_kv_cache + + def shutdown(self): + # Release tensor views into the pool before the pool memory is freed, + # so their deleters don't see stale pointers. + self.all_ssm_states = None + self.all_conv_states = None + super().shutdown() + + def add_dummy_requests( + self, + request_ids: List[int], + # Note that token_nums should be past_kv_len + input_len (without + # spec decoding). The draft tokens will be added in this function, + # so we don't need to take care of it in the caller. When preparing + # token_nums, we should not take the draft tokens into account, so + # don't use the kv_cache_manager.max_seq_len, which includes both + # extra tokens and draft tokens. + token_nums: Optional[List[int]] = None, + is_gen: bool = False, + prepare_resource: bool = True, + max_num_draft_tokens: int = 0, + use_mrope: bool = False, + max_beam_width: int = 1, + # For capturable drafting loops. During normal inference, the draft model always + # has enough KV cache space to fit all of our draft tokens. During warmup, however, + # we need to make the KV cache manager aware that multiple autoregressive steps will + # occur. 
+ num_extra_decoding_steps: int = 0, + draft_kv_cache_manager: Optional[KVCacheManager] = None, + ) -> List[LlmRequest]: + requests = super().add_dummy_requests(request_ids, token_nums, is_gen, + prepare_resource, + max_num_draft_tokens, use_mrope, + max_beam_width, + num_extra_decoding_steps, + draft_kv_cache_manager) + self.requests.extend(requests) + self._setup_state_indices() + return requests + + def update_resources(self, + scheduled_batch: ScheduledRequests, + attn_metadata: "AttentionMetadata" = None, + kv_cache_dtype_byte_size: float = None): + super().update_resources(scheduled_batch, attn_metadata, + kv_cache_dtype_byte_size) + + @nvtx_range("hybrid_prepare_resources") + def _prepare_resources(self, scheduled_batch: ScheduledRequests): + self.requests = scheduled_batch.context_requests + \ + scheduled_batch.generation_requests + for req in self.requests: + self.impl.copy_linear_attention_block(req) + self.impl.refresh_blocks() + self._setup_state_indices() + + def prepare_resources(self, scheduled_batch: ScheduledRequests): + super().prepare_resources(scheduled_batch) + self._prepare_resources(scheduled_batch) + + def is_speculative(self) -> bool: + # Not implemented yet. + return False + + def update_mamba_states(self, attn_metadata: "AttentionMetadata", + num_accepted_tokens: torch.Tensor): + raise NotImplementedError( + "CppMambaHybridCacheManager does not support speculative decoding. " + "Use MixedMambaHybridCacheManager (spec_config or TRTLLM_USE_CPP_MAMBA=1) instead." + ) + + def get_ssm_states(self, layer_idx: int) -> torch.Tensor: + return self.all_ssm_states[self.mamba_layer_offsets[layer_idx]] + + def get_conv_states(self, layer_idx: int) -> torch.Tensor: + return self.all_conv_states[self.mamba_layer_offsets[layer_idx]] + + def mamba_layer_cache( + self, layer_idx: int) -> Union[PythonMambaCacheManager.State, None]: + ret = PythonMambaCacheManager.State( + conv=self.get_conv_states(layer_idx), + temporal=self.get_ssm_states(layer_idx)) + return ret + + def free_resources(self, request: LlmRequest, pin_on_release: bool = False): + if request in self.requests: + self.requests.remove(request) + super().free_resources(request, pin_on_release) + + def _setup_state_indices(self) -> None: + block_indices = [] + for req in self.requests: + if req.is_context_finished: + next_step = self.get_num_tokens(req) - 1 + elif self.kv_cache_config.enable_block_reuse: + next_step = (req.context_current_position - 1 + + req.context_chunk_size) + else: + next_step = req.prompt_len - 1 + block_indices.append(next_step // self.tokens_per_block) + self.impl.copy_batch_block_offsets( + self.host_block_offsets, + [req.py_request_id for req in self.requests], 1, 0) + host_block_offsets = torch.zeros([len(self.requests)], + dtype=torch.int32, + device="cpu") + for i in range(len(self.requests)): + # With layer-first pool layout, setOffsets produces the block index directly + # (no longer multiplied by num_local_mamba_layers) + value = self.host_block_offsets[self.recurrent_states_pool_index, i, + 0, block_indices[i]] + max_blocks = self.blocks_per_window[ + LinearCacheType.RECURRENT_STATES.value][0] + if value < 0 or value >= max_blocks: + raise RuntimeError( + f"Invalid recurrent state block index {value} " + f"(expected 0 <= index < {max_blocks}) for request {i}") + host_block_offsets[i] = value + + torch.fill_(self.cuda_state_indices, 0) + self.cuda_state_indices[:len(self.requests)] = host_block_offsets.cuda() + self._host_state_indices = host_block_offsets.clone() + + def 
get_state_indices( + self, + request_ids: Optional[List[int]] = None, + is_padding: Optional[List[bool]] = None) -> torch.Tensor: + return self.cuda_state_indices + + def calc_next_context_chunk_size(self, request: LlmRequest) -> int: + """Compute the next prefill chunk size for a context request when block reuse is enabled. + + When kv_cache_config.enable_block_reuse is True, context prefill must stop exactly at + the positions returned by calc_context_stop_positions (mamba_state_cache_interval boundaries + and block boundaries). This returns the chunk_size to use for the next prefill step so + that the next stop position is not exceeded. + + Args: + request: Context request with prompt_len and context_current_position set. + + Returns: + Number of tokens to prefill in the next step (0 if context is already complete). + """ + prompt_len = request.prompt_len + current = request.context_current_position + if current >= prompt_len: + return 0 + if not self.kv_cache_config.enable_block_reuse: + assert current == 0, f"Expected context_current_position to be 0 when block reuse is disabled, but got {current}" + return prompt_len - current + step = self.linear_attention_metadata.states_snapshot_interval + stop_positions = calc_context_stop_positions(prompt_len, + self.tokens_per_block, + step) + stop_positions = sorted(set(stop_positions)) + for pos in stop_positions: + if pos > current: + return pos - current + return prompt_len - current + + def _setup_states_views(self) -> None: + # Pool layout: {numLocalLayers, numBlocks, ssm_bytes + conv_bytes} (as uint8) + pool: torch.Tensor = self.impl.get_recurrent_states_pool().view( + torch.uint8).reshape(self.local_num_mamba_layers, -1, + self.ssm_bytes + self.conv_bytes) + num_blocks_in_pool = pool.shape[1] + self.all_ssm_states = pool[:, :, :self.ssm_bytes].view( + self.ssm_state_dtype).view( + [self.local_num_mamba_layers, num_blocks_in_pool] + + self.ssm_state_shape) + self.all_conv_states = pool[:, :, self.ssm_bytes:self.ssm_bytes + + self.conv_bytes].view( + self.conv_state_dtype).view([ + self.local_num_mamba_layers, + num_blocks_in_pool + ] + self.conv_state_shape) + + def get_mamba_ssm_cache_dtype(self) -> torch.dtype: + return self.ssm_state_dtype + + +class _MambaHybridCacheManagerMeta(type): + """Metaclass that enables isinstance/issubclass checks against + MambaHybridCacheManager for both Mixed and Cpp implementations.""" + + def __instancecheck__(cls, instance): + if cls is MambaHybridCacheManager: + return isinstance( + instance, + (MixedMambaHybridCacheManager, CppMambaHybridCacheManager)) + return super().__instancecheck__(instance) + + def __subclasscheck__(cls, subclass): + if cls is MambaHybridCacheManager: + return issubclass( + subclass, + (MixedMambaHybridCacheManager, CppMambaHybridCacheManager)) + return super().__subclasscheck__(subclass) + + def __getattr__(cls, name): + """Forward class-level attribute access (e.g. static methods) to + KVCacheManager. Add attributes here as needed.""" + return getattr(KVCacheManager, name) + + +class MambaHybridCacheManager(metaclass=_MambaHybridCacheManagerMeta): + """Factory that selects the appropriate hybrid cache manager. 
+ + Selection logic: + - Speculative decoding or TRTLLM_USE_CPP_MAMBA=1 -> MixedMambaHybridCacheManager + - Otherwise (default) -> CppMambaHybridCacheManager + """ + + def __new__( + cls, + # mamba cache parameters + mamba_d_state: int, + mamba_d_conv: int, + mamba_num_heads: int, + mamba_n_groups: int, + mamba_head_dim: int, + mamba_num_layers: int, + mamba_layer_mask: List[bool], + mamba_cache_dtype: torch.dtype, + mamba_ssm_cache_dtype: torch.dtype, + is_disagg: bool, + # kv cache parameters + kv_cache_config: KvCacheConfig, + kv_cache_type: CacheTypeCpp, + **kwargs, + ): + positional_args = ( + mamba_d_state, + mamba_d_conv, + mamba_num_heads, + mamba_n_groups, + mamba_head_dim, + mamba_num_layers, + mamba_layer_mask, + mamba_cache_dtype, + mamba_ssm_cache_dtype, + kv_cache_config, + kv_cache_type, + ) + + spec_config = kwargs.get('spec_config', None) + use_v1 = (is_disagg or use_cpp_mamba_cache_manager() + or spec_config is not None) + + if use_v1: + logger.info( + "Using MixedMambaHybridCacheManager for hybrid cache management" + ) + return MixedMambaHybridCacheManager(*positional_args, **kwargs) + else: + logger.info( + "Using CppMambaHybridCacheManager for hybrid cache management") + return CppMambaHybridCacheManager(*positional_args, **kwargs) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 79bebc92ff79..938044bd7676 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -36,7 +36,7 @@ from ._util import (KvCacheCreator, _adjust_torch_mem_fraction, create_py_executor_instance, instantiate_sampler, is_mla, validate_feature_combination) -from .config_utils import is_nemotron_hybrid, is_qwen3_hybrid +from .config_utils import is_hybrid_linear from .connectors.kv_cache_connector import KvCacheConnectorManager from .dwdp import DwdpManager from .guided_decoder import CapturableGuidedDecoder, GuidedDecoder @@ -550,10 +550,12 @@ def drafting_loop_wrapper(model): cache_transceiver_config.max_tokens_in_buffer = net_max_seq_len config = model_engine.model.model_config.pretrained_config - if (is_nemotron_hybrid(config) - or is_qwen3_hybrid(config)) and kv_cache_config.enable_block_reuse: + if is_hybrid_linear(config) and kv_cache_config.enable_block_reuse and ( + spec_config is not None or cache_transceiver_config is not None + and cache_transceiver_config.backend is not None): logger.warning( - "Disabling block reuse for MambaHybridCacheManager-based models") + "Disabling block reuse for MambaHybridCacheManager-based models when MTP or disagg is enabled" + ) kv_cache_config.enable_block_reuse = False _set_model_engines_cache_reuse([model_engine, draft_model_engine], False) @@ -610,6 +612,10 @@ def drafting_loop_wrapper(model): else: ctx_chunk_config = None + if kv_cache_config.enable_block_reuse and is_hybrid_linear(config): + ctx_chunk_config = (ContextChunkingPolicy.FORCE_CHUNK, + kv_cache_config.mamba_state_cache_interval) + guided_decoder: Optional[GuidedDecoder] = None if guided_decoding_config is not None: with allocation_scope(ExecutorMemoryType.GUIDED_DECODER): @@ -731,7 +737,7 @@ def drafting_loop_wrapper(model): is_disagg = (cache_transceiver_config is not None and cache_transceiver_config.backend is not None) - is_hybrid = is_nemotron_hybrid(config) or is_qwen3_hybrid(config) + is_hybrid = is_hybrid_linear(config) if is_disagg and is_hybrid: if cache_transceiver_config.transceiver_runtime != "PYTHON" or os.environ.get( @@ 
-765,6 +771,7 @@ def drafting_loop_wrapper(model): execution_stream=execution_stream, draft_config=draft_config, skip_est=skip_est, + is_disagg=is_disagg, ) estimating_kv_cache = kv_cache_creator.try_prepare_estimation() diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 3c483a638c8c..d7b744c059df 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -16,11 +16,13 @@ from tensorrt_llm._utils import (TensorWrapper, convert_to_torch_tensor, get_size_in_bytes, mpi_comm, mpi_disabled, prefer_pinned, torch_comm) -from tensorrt_llm.bindings.internal.batch_manager import KvCacheStats +from tensorrt_llm.bindings.internal.batch_manager import ( + KvCacheStats, LinearAttentionMetadata, LinearCacheType) from tensorrt_llm.bindings.internal.batch_manager.kv_cache_manager_v2_utils import ( IndexMapper, copy_batch_block_offsets_to_device) from tensorrt_llm.bindings.internal.runtime import TaskLayerModuleConfig -from tensorrt_llm.llmapi.llm_args import KvCacheConfig, PeftCacheConfig +from tensorrt_llm.llmapi.llm_args import (KvCacheConfig, PeftCacheConfig, + PybindMirror) from tensorrt_llm.lora_helper import LoraConfig from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig from tensorrt_llm.runtime import ModelConfig as ModelConfigPython @@ -287,6 +289,7 @@ def __init__( indexer_k_cache_index_head_dim: int = 0, is_estimating_kv_cache: bool = False, execution_stream: Optional[torch.cuda.Stream] = None, + linear_attention_metadata: Optional[LinearAttentionMetadata] = None, **kwargs, ) -> None: self.mapping = mapping @@ -361,11 +364,30 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], self.max_draft_len = spec_config.max_draft_len if spec_config is not None else 0 self.max_total_draft_tokens = (spec_config.tokens_per_gen_step - 1) if spec_config is not None else 0 + self.linear_attention_metadata = linear_attention_metadata # Determine max_attention_window_vec if kv_cache_config.max_attention_window is None: # Use max_seq_len as default max_attention_window self.max_attention_window_vec = [max_seq_len] + elif len(kv_cache_config.max_attention_window) == num_layers: + # we need to shard the max_attention_window according to the layer_mask and pp_layers + self.max_attention_window_vec = [] + if layer_mask is not None: + global_enabled_layers = [ + layer_idx for layer_idx in range(len(layer_mask)) + if layer_mask[layer_idx] + ] + else: + global_enabled_layers = list(range(num_layers)) + pp_rank_offset = global_enabled_layers.index(self.pp_layers[0]) + for layer_idx in self.pp_layers: + if layer_mask is not None and not layer_mask[layer_idx]: + continue + window_size = kv_cache_config.max_attention_window[ + pp_rank_offset + self.layer_offsets[layer_idx]] + window_size = min(window_size, max_seq_len) + self.max_attention_window_vec.append(window_size) else: self.max_attention_window_vec = kv_cache_config.max_attention_window.copy( ) # Make a copy to avoid modifying original @@ -380,8 +402,12 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], if kv_cache_config.sink_token_length is not None else 0) - # Determine if this is VSWA (Variable Sliding Window Attention) - self.is_vswa = len(set(self.max_attention_window_vec)) > 1 + # Determine if this is VSWA (Variable Sliding Window Attention). + # The `w > 0` check excludes LinearCacheType.RECURRENT_STATES sentinel + # values (negative) used by hybrid linear attention models. 
+ self.is_vswa = len(set(self.max_attention_window_vec)) > 1 and all( + w > 0 for w in self.max_attention_window_vec) + self.is_linear_attention = linear_attention_metadata is not None # Calculate kv cache blocks for each window size # FIXME: flashinfer.py accesses kv_cache_manager.blocks_in_primary_pool @@ -406,16 +432,21 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], (self.blocks_in_primary_pool, self.blocks_in_secondary_pool) for window_size in set(self.max_attention_window_vec) } + if self.is_linear_attention: + if kv_cache_config.enable_block_reuse: + max_snapshots = max( + kv_cache_config.max_tokens // + linear_attention_metadata.states_snapshot_interval, + self.max_batch_size) + else: + max_snapshots = self.max_batch_size + blocks_per_window[LinearCacheType.RECURRENT_STATES.value] = ( + int(max_snapshots), 0) logger.info( f"[kv cache manager] Primary/secondary blocks for window sizes set to {blocks_per_window} for estimation dry run" ) else: - if self.is_vswa: - # VSWA case: use C++ implementation for variable window sizes - if model_config is None: - raise ValueError( - "model_config is required for VSWA (Variable Sliding Window Attention)" - ) + if self.is_vswa or self.is_linear_attention: assert isinstance( kv_cache_config, KvCacheConfig ), "calculate_max_num_blocks_for_vswa only accepts KvCacheConfig" @@ -514,7 +545,7 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], 'dtype': dtype, 'sink_token_length': sink_token_length, 'stream': self._stream.cuda_stream, # Pass to BufferManager - 'max_sequence_length': max_seq_len, + 'max_sequence_length': self.max_seq_len, 'enable_block_reuse': kv_cache_config.enable_block_reuse, 'cache_type': kv_cache_type, 'enable_partial_reuse': kv_cache_config.enable_partial_reuse, @@ -523,7 +554,8 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], 'enable_indexer_k_cache': enable_indexer_k_cache, 'indexer_k_cache_quant_block_size': indexer_k_cache_quant_block_size, - 'indexer_k_cache_index_head_dim': indexer_k_cache_index_head_dim + 'indexer_k_cache_index_head_dim': indexer_k_cache_index_head_dim, + 'linear_attention_metadata': linear_attention_metadata } if self.event_buffer_max_size > 0: @@ -567,6 +599,7 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], dtype=torch.int32, pin_memory=prefer_pinned(), device='cpu') + self.blocks_per_window = blocks_per_window def probe_prefix_match_length(self, input_tokens, lora_task_id=None): """Probe the KV cache radix tree for prefix match length. @@ -604,6 +637,10 @@ def shutdown(self): def get_max_resource_count(self) -> int: return self.impl.max_num_blocks + def get_num_tokens(self, request: LlmRequest) -> int: + # LlmRequest.get_num_tokens is out of sync with GenerationRequest when overlap scheduler is enabled. + return self.impl.get_token_count(request.py_request_id) + def get_needed_resource_to_completion(self, request: LlmRequest) -> int: # TODO: the C++ implementation of this method can be used, but the # Python and C++ schedulers currently do not agree on what "needed @@ -624,7 +661,7 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): # wait for all pending work to finish before launching offload/onboarding/partial copy self.impl.sync_transfer_manager_with_buffer_manager() - # Collect first-chunk requests eligible for batch add_sequence. + # Collect first-chunk requests eligible for add_sequence_batch. 
# When block reuse is enabled, addSequenceBatch uses a two-phase # claim-then-onboard strategy that prevents host offloading from # evicting reusable blocks in the radix tree. @@ -672,7 +709,8 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): self.kv_connector_manager.update_state_after_alloc( req, block_ids) - # A request may change from `context_requests_chunking` to `context_requests_last_chunk` in `add_sequence` due to KV cache reuse, so we rebuild the context request lists here. + # A request may change from `context_requests_chunking` to `context_requests_last_chunk` in + # `add_sequence_batch` due to KV cache reuse, so we rebuild the context request lists here. scheduled_batch.reset_context_requests() for req in scheduled_batch.generation_requests: @@ -790,7 +828,7 @@ def add_dummy_requests( "mrope_position_deltas"] = dummy_mrope_position_deltas requests.append(req) - # Batch add_sequence for all dummy requests, then add extra tokens. + # Use add_sequence_batch for all dummy requests, then add extra tokens. # This must happen before is_gen state modifications below, which may # set prompt_len to 0 and trigger assertion in setPrepopulatedPromptLen. if batch_request_infos: @@ -809,7 +847,7 @@ def add_dummy_requests( for _ in range(self.num_extra_kv_tokens): draft_kv_cache_manager.impl.add_token(req_id) - # Set is_gen state after batch add_sequence to avoid modifying + # Set is_gen state after add_sequence_batch to avoid modifying # prompt_len before the C++ side reads it. if is_gen: for i, req in enumerate(requests): @@ -906,7 +944,11 @@ def _resolve_num_attention_layers( return max(num_layers, 1) # provide at least 1 layer to prevent division by zero cache size return max( - len(mapping.pp_layers(model_config.get_num_attention_layers())), 1) + # when is_disagg=True, for hybrid models it returns the number of full attention layers. + len( + mapping.pp_layers( + model_config.get_num_attention_layers(is_disagg=True))), + 1) # TODO: refactor get_cache_size_per_token and get_cache_bytes_per_token to use the same logic @staticmethod @@ -1102,9 +1144,9 @@ def get_batch_cache_indices( return result def get_num_free_blocks(self) -> int: - if self.is_vswa: + if self.is_vswa or self.is_linear_attention: logger.info( - f"For VSWA case, we return the minimum of the number of free blocks for each window size: {self.impl.get_kv_cache_stats().num_free_blocks_per_window_size}" + f"For {'linear attention' if self.is_linear_attention else 'VSWA'} case, we return the minimum of the number of free blocks for each window size: {self.impl.get_kv_cache_stats().num_free_blocks_per_window_size}" ) return min(self.impl.get_kv_cache_stats(). num_free_blocks_per_window_size.values()) @@ -1399,7 +1441,7 @@ def calculate_cache_size_per_token(layers: Set[int]) -> int: def calculate_max_num_blocks_for_vswa( self, kv_cache_config: KvCacheConfig, - model_config: ModelConfigCpp, + model_config: Optional[ModelConfigCpp], extra_cost_memory: int = 0) -> dict[int, tuple[int, int]]: """ Currently, this function is added to support *ONLY* VSWA. @@ -1425,7 +1467,6 @@ def calculate_max_num_blocks_for_vswa( # VSWA on Torch backend has not supported the cross attention. 
is_cross_attention = False # check model config - assert model_config.layer_types is not None, "layer_types have to be set correctly for VSWA" # Construct WorldConfig from self.mapping world_config_cpp = WorldConfig( @@ -1449,19 +1490,45 @@ def calculate_max_num_blocks_for_vswa( f"secondary_pool_memory_bytes is set to {self._secondary_pool_memory_bytes/1024**3}GB" ) - # Adjust the window sizes to fit the memory if even a single sequence - # cannot fit in the memory. - window_size_to_layers, max_attention_window_vec = self.adjust_window_sizes_for_vswa( - window_size_to_layers=window_size_to_layers, - max_attention_window_vec=self.max_attention_window_vec, - model_config=model_config, - kv_cache_config=kv_cache_config, - pool_memory_bytes=self._primary_pool_memory_bytes, - kv_factor=self.kv_factor, - dtype=self.dtype, - is_cross_attention=is_cross_attention, - ) - self.max_attention_window_vec = max_attention_window_vec + if self.is_linear_attention: + blocks_per_window = KVCacheManagerCpp.calculate_max_num_blocks( + config=PybindMirror.maybe_to_pybind(kv_cache_config), + dtype=self.dtype, + num_kv_heads_per_layer=list(self.num_kv_heads_per_layer), + size_per_head=self.head_dim, + tokens_per_block=self.tokens_per_block, + world_config=world_config_cpp, + window_size_to_layers=window_size_to_layers, + allotted_primary_mem_bytes=self._primary_pool_memory_bytes, + allotted_secondary_mem_bytes=self._secondary_pool_memory_bytes, + extra_cost_memory=extra_cost_memory, + kv_factor=self.kv_factor, + max_batch_size=self.max_batch_size, + linear_attention_metadata=PybindMirror.maybe_to_pybind( + self.linear_attention_metadata), + ) + return blocks_per_window + + # VSWA case: use C++ implementation for variable window sizes + if model_config is None: + raise ValueError( + "model_config is required for VSWA (Variable Sliding Window Attention)" + ) + assert model_config.layer_types is not None, "layer_types have to be set correctly for VSWA" + if self.is_vswa: + # Adjust the window sizes to fit the memory if even a single sequence + # cannot fit in the memory. + window_size_to_layers, max_attention_window_vec = self.adjust_window_sizes_for_vswa( + window_size_to_layers=window_size_to_layers, + max_attention_window_vec=self.max_attention_window_vec, + model_config=model_config, + kv_cache_config=kv_cache_config, + pool_memory_bytes=self._primary_pool_memory_bytes, + kv_factor=self.kv_factor, + dtype=self.dtype, + is_cross_attention=is_cross_attention, + ) + self.max_attention_window_vec = max_attention_window_vec def calculate_cache_size_per_token(layers: Set[int]) -> int: # Same as BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index a9317efa4695..f2afdad53f12 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -27,6 +27,7 @@ except ImportError: PlacementGroup = None +from tensorrt_llm.bindings.internal.batch_manager import LinearCacheType from tensorrt_llm.lora_helper import (LoraConfig, get_default_trtllm_modules_to_hf_modules) @@ -2340,7 +2341,7 @@ class KvCacheConfig(StrictBaseModel, PybindMirror): description= "The maximum number of tokens that should be stored in the KV cache. If both `max_tokens` and `free_gpu_memory_fraction` are specified, memory corresponding to the minimum will be used." 
) - max_attention_window: Optional[List[PositiveInt]] = Field( + max_attention_window: Optional[List[int]] = Field( default=None, min_length=1, description= @@ -2442,6 +2443,12 @@ class KvCacheConfig(StrictBaseModel, PybindMirror): tokens_per_block: int = Field(default=32, description="The number of tokens per block.") + # This is a pure python field, not a pybind field. It is only for the Pytorch backend. + mamba_state_cache_interval: PositiveInt = Field( + default=256, + description= + "The number of tokens between cache steps in the Mamba prefix cache.") + use_kv_cache_manager_v2: bool = Field( default=False, status="prototype", @@ -2520,9 +2527,9 @@ def validate_max_attention_window(cls, v: Optional[List[int]]): raise ValueError( "kv_cache_config.max_attention_window must contain only integers" ) - if i <= 0: + if i <= 0 and i not in [LinearCacheType.RECURRENT_STATES.value]: raise ValueError( - "kv_cache_config.max_attention_window values must be positive" + "kv_cache_config.max_attention_window values must be positive or LinearCacheType.RECURRENT_STATES.value" ) return v diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/runner.py b/tensorrt_llm/tools/layer_wise_benchmarks/runner.py index f2c4beb1fdc7..db6bbbce7a6d 100644 --- a/tensorrt_llm/tools/layer_wise_benchmarks/runner.py +++ b/tensorrt_llm/tools/layer_wise_benchmarks/runner.py @@ -828,6 +828,7 @@ def create_kv_cache_manager( mamba_layer_mask, config.torch_dtype, model_config.quant_config.mamba_ssm_cache_dtype, + False, # is_disagg # kv cache parameters kv_cache_config, tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, @@ -864,6 +865,7 @@ def create_kv_cache_manager( mamba_layer_mask, config.torch_dtype, model_config.quant_config.mamba_ssm_cache_dtype, + False, # is_disagg # kv cache parameters kv_cache_config, tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 920bcc0dfed5..5e4bfb0c49e1 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -5638,27 +5638,35 @@ def test_bf16_4gpu(self, tp_size, pp_size, ep_size, cuda_graph, task.evaluate(llm) @skip_pre_blackwell - @pytest.mark.skip_less_device(4) @pytest.mark.parametrize("moe_backend", ["CUTLASS", "TRTLLM"], ids=["cutlass", "trtllm"]) @pytest.mark.parametrize( - "tp_size,pp_size,ep_size,cuda_graph,overlap_scheduler,attention_dp", [ - (1, 1, 1, True, True, False), - (4, 1, 1, True, True, False), - (4, 1, 4, True, True, True), - (4, 1, 4, True, True, False), - (4, 1, 4, False, False, False), + "tp_size,pp_size,ep_size,cuda_graph,overlap_scheduler,attention_dp,enable_block_reuse", + [ + (1, 1, 1, True, True, False, True), + (1, 1, 1, True, True, False, False), + (4, 1, 1, True, True, False, False), + (4, 1, 4, True, True, True, False), + (4, 1, 4, True, True, False, False), + (4, 1, 4, False, False, False, False), ], ids=[ - "tp1", "tp4ep1", "tp4ep4_adp_on", "tp4ep4_adp_off", - "no_cuda_graph_overlap" + "tp1_block_reuse", "tp1", "tp4ep1", "tp4ep4_adp_on", + "tp4ep4_adp_off", "no_cuda_graph_overlap" ]) def test_nvfp4(self, moe_backend, tp_size, pp_size, ep_size, cuda_graph, - overlap_scheduler, attention_dp, mocker): + overlap_scheduler, attention_dp, enable_block_reuse, mocker): + gpu_needed = max(tp_size, ep_size) * pp_size + if get_device_count() < gpu_needed: + pytest.skip( + f"Device count {get_device_count()} is less than required {gpu_needed}" 
+ ) model_path = f"{self.MODEL_PATH}/qwen3-next-80b-instruct-nvfp4-ptq-fp8kv" kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6, - enable_block_reuse=False) + enable_block_reuse=enable_block_reuse) + if enable_block_reuse: + kv_cache_config.mamba_state_cache_interval = 256 pytorch_config = dict(disable_overlap_scheduler=not overlap_scheduler, cuda_graph_config=CudaGraphConfig( max_batch_size=512, enable_padding=False) @@ -5772,7 +5780,8 @@ def test_bf16(self): task.evaluate(llm, extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS) - def test_fp8(self): + @parametrize_with_ids("enable_block_reuse", [False, True]) + def test_fp8(self, enable_block_reuse): model_dir = f"{self.MODEL_PATH}-FP8" # Model is being added to CI. Skip at the moment. if not os.path.exists(model_dir): @@ -5780,7 +5789,7 @@ def test_fp8(self): world_size = 1 kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8, - enable_block_reuse=False) + enable_block_reuse=enable_block_reuse) moe_config = MoeConfig(backend='DEEPGEMM') with LLM(model_dir, @@ -5809,29 +5818,31 @@ class TestQwen3_5_397B_A17B(LlmapiAccuracyTestHarness): @pytest.mark.skip_less_device(4) @pytest.mark.parametrize( - "tp_size,ep_size,cuda_graph,overlap_scheduler,attention_dp, moe_backend", + "tp_size,ep_size,cuda_graph,overlap_scheduler,attention_dp,moe_backend,enable_block_reuse", [ - (4, 4, True, True, False, "CUTEDSL"), - (4, 4, True, True, True, "CUTEDSL"), - (4, 4, True, True, False, "TRTLLM"), - (4, 4, True, True, True, "TRTLLM"), + (4, 4, True, True, False, "CUTEDSL", False), + (4, 4, True, True, True, "CUTEDSL", False), + (4, 4, True, True, False, "TRTLLM", False), + (4, 4, True, True, True, "TRTLLM", False), + (4, 4, True, True, False, "CUTEDSL", True), ], ids=[ "tep4_cutedsl", "adp4_cutedsl", "tep4_trtllm", "adp4_trtllm", + "tep4_block_reuse", ], ) def test_nvfp4(self, tp_size, ep_size, cuda_graph, overlap_scheduler, - attention_dp, moe_backend, mocker): + attention_dp, moe_backend, enable_block_reuse, mocker): model_path = f"{llm_models_root()}/Qwen3.5-397B-A17B-NVFP4" if not os.path.exists(model_path): pytest.skip(f"Model directory {model_path} does not exist") kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9, - enable_block_reuse=False) + enable_block_reuse=enable_block_reuse) pytorch_config = dict(disable_overlap_scheduler=not overlap_scheduler, cuda_graph_config=CudaGraphConfig( max_batch_size=32, enable_padding=False) @@ -6432,6 +6443,49 @@ def test_nvfp4_8gpus(self, attention_dp, moe_backend): task.evaluate(llm, extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS) + @skip_pre_blackwell + @pytest.mark.parametrize( + "tp_size, ep_size, mamba_state_cache_interval, attention_dp", + [ + (1, 1, 256, False), + (4, 1, 256, False), + (4, 4, 512, False), + (4, 4, 256, True), + ], + ids=["TP1", "TP4", "TEP4", "TEP4_ADP"], + ) + def test_nvfp4_4gpus_block_reuse(self, tp_size, ep_size, + mamba_state_cache_interval, attention_dp): + gpu_needed = max(tp_size, ep_size) + if get_device_count() < gpu_needed: + pytest.skip( + f"Device count {get_device_count()} is less than required {gpu_needed}" + ) + with LLM( + f"{llm_models_root()}/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4", + kv_cache_config=KvCacheConfig( + enable_block_reuse=True, + mamba_ssm_cache_dtype="float16", + mamba_state_cache_interval=mamba_state_cache_interval, + free_gpu_memory_fraction=0.8, + ), + max_batch_size=32, + tensor_parallel_size=tp_size, + moe_expert_parallel_size=ep_size, + pipeline_parallel_size=1, + enable_attention_dp=attention_dp, + 
cuda_graph_config=CudaGraphConfig(max_batch_size=32, + enable_padding=True), + disable_overlap_scheduler=False, + moe_config=MoeConfig(backend="TRTLLM"), + ) as llm: + task = MMLU(self.MODEL_NAME) + task.evaluate(llm, + extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm, + extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS) + @skip_pre_blackwell @pytest.mark.skip_less_mpi_world_size(8) @pytest.mark.parametrize( diff --git a/tests/integration/defs/perf/host_perf/test_module_resource_manager.py b/tests/integration/defs/perf/host_perf/test_module_resource_manager.py index 434994c6469e..fd51ef6f42f6 100644 --- a/tests/integration/defs/perf/host_perf/test_module_resource_manager.py +++ b/tests/integration/defs/perf/host_perf/test_module_resource_manager.py @@ -129,7 +129,7 @@ def kv_cache_manager(): """Create a KVCacheManager with a real C++ backend. Uses small dimensions to minimize GPU memory (~1 GB) while still - exercising the real add_sequence/add_token/refresh_blocks code path. + exercising the real add_sequence_batch/add_token/refresh_blocks code path. """ kv_cache_config = KvCacheConfig( max_tokens=MAX_TOKENS, @@ -182,7 +182,7 @@ def test_kv_cache_prepare_generation(kv_cache_manager): req_id = base_id + i req = make_request(request_id=req_id, seq_slot=i, prompt_len=PROMPT_LEN) # Add sequence to KV cache (context phase) - kv_cache_manager.impl.add_sequence(req_id, PROMPT_LEN, 1, req) + kv_cache_manager.impl.add_sequence_batch([(req_id, PROMPT_LEN, 1)], [req]) req.state = LlmRequestState.GENERATION_IN_PROGRESS requests.append(req) @@ -204,7 +204,7 @@ def test_kv_cache_prepare_generation(kv_cache_manager): def test_kv_cache_prepare_context(kv_cache_manager): """Benchmark the context (new request) allocation path at BS=8. - When a new request enters the system, add_sequence is called to + When a new request enters the system, add_sequence_batch is called to allocate initial KV cache blocks for the prompt. This is more expensive than add_token but happens only once per request. 
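For reference, the per-request `add_sequence` call is replaced throughout this patch by `add_sequence_batch`, which takes parallel lists of `(request_id, input_length, beam_width)` tuples and `LlmRequest` objects. A minimal sketch of the migration; `kv_cache_manager`, `reqs`, and `PROMPT_LEN` are assumptions borrowed from the benchmark code below:

```python
# Illustrative sketch of the add_sequence -> add_sequence_batch migration
# (not part of the patch); kv_cache_manager, reqs, and PROMPT_LEN are assumed
# to exist as in the benchmark fixture below.

# Old path: one C++ call per request.
#   for req in reqs:
#       kv_cache_manager.impl.add_sequence(req.request_id, PROMPT_LEN, 1, req)

# New path: a single batched call over all first-chunk requests.
request_infos = [(req.request_id, PROMPT_LEN, 1) for req in reqs]
kv_cache_manager.impl.add_sequence_batch(request_infos, reqs)
kv_cache_manager.impl.refresh_blocks()
```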
@@ -222,11 +222,13 @@ def test_kv_cache_prepare_context(kv_cache_manager): # Warmup for warmup_iter in range(20): reqs = [] + request_infos = [] for i in range(batch_size): req_id = base_id + warmup_iter * batch_size + i req = make_request(request_id=req_id, seq_slot=i, prompt_len=PROMPT_LEN) - kv_cache_manager.impl.add_sequence(req_id, PROMPT_LEN, 1, req) reqs.append(req) + request_infos.append((req_id, PROMPT_LEN, 1)) + kv_cache_manager.impl.add_sequence_batch(request_infos, reqs) kv_cache_manager.impl.refresh_blocks() for req in reqs: kv_cache_manager.impl.remove_sequence(req.request_id, req, False) @@ -234,14 +236,15 @@ def test_kv_cache_prepare_context(kv_cache_manager): # Benchmark for bench_iter in range(num_iterations): reqs = [] + request_infos = [] for i in range(batch_size): req_id = base_id + (20 + bench_iter) * batch_size + i req = make_request(request_id=req_id, seq_slot=i, prompt_len=PROMPT_LEN) reqs.append(req) + request_infos.append((req.request_id, PROMPT_LEN, 1)) start = time.perf_counter() - for req in reqs: - kv_cache_manager.impl.add_sequence(req.request_id, PROMPT_LEN, 1, req) + kv_cache_manager.impl.add_sequence_batch(request_infos, reqs) kv_cache_manager.impl.refresh_blocks() end = time.perf_counter() diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index ded385f71c37..7cdbe6b3c558 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -726,7 +726,8 @@ accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_sof accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention_4gpus[target_sparsity_0.9-fp8kv=True] accuracy/test_llm_api_pytorch.py::TestQwen3_4B::test_eagle3 accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16 -accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8 +accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[enable_block_reuse=False] +accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[enable_block_reuse=True] accuracy/test_llm_api_pytorch.py::TestQwen3_5_397B_A17B::test_nvfp4[adp4_cutedsl] accuracy/test_llm_api_pytorch.py::TestQwen3_5_397B_A17B::test_nvfp4[adp4_trtllm] accuracy/test_llm_api_pytorch.py::TestQwen3_5_397B_A17B::test_nvfp4[tep4_cutedsl] diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index 1a4d2df77398..096546a89feb 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -69,8 +69,9 @@ l0_b200: - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-CUTLASS] - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a16_mxfp4[latency-TRTLLM] - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9-fp8kv=True] - - accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8 - - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp1-cutlass] + - accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[enable_block_reuse=False] + - accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[enable_block_reuse=True] + - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp1_block_reuse-cutlass] - disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0] # nvbugs 5300551 - 
test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B] - test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-FP8-llama-3.1-model/Llama-3.1-8B-Instruct-FP8] diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index f6fae26ac3e1..05dfe49c9ef2 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -22,6 +22,7 @@ l0_dgx_b200: - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_4gpu_mtp_ar_custom_op TIMEOUT (60) - accuracy/test_llm_api_pytorch.py::TestQwen3_5_397B_A17B::test_nvfp4[adp4_cutedsl] - accuracy/test_llm_api_pytorch.py::TestQwen3_5_397B_A17B::test_nvfp4[adp4_trtllm] + - accuracy/test_llm_api_pytorch.py::TestQwen3_5_397B_A17B::test_nvfp4[tep4_block_reuse] - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_4gpus_static_eplb[moe_backend=CUTLASS] - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_4gpus_static_eplb[moe_backend=TRTLLM] - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_trtllm-torch_compile=True] diff --git a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml index 26b61bea4a5e..da97b085094b 100644 --- a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml +++ b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml @@ -55,6 +55,8 @@ l0_gb200_multi_gpus: - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2-moe_backend=CUTLASS] - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_4gpus_online_eplb[moe_backend=TRTLLM] - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_4gpus_online_eplb[moe_backend=CUTLASS] + - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_4gpus_block_reuse[TEP4] + - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_4gpus_block_reuse[TEP4_ADP] - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4[torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4[torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch_compile=False] diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 33ab21e5fdde..cdcbb22b88d1 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -268,7 +268,8 @@ test_e2e.py::test_draft_token_tree_quickstart_advanced_eagle3[Llama-3.1-8b-Instr unittest/_torch/modules/moe/test_moe_backend.py::test_moe_backend[act=Relu2-e60_k4_h2048_i1408-seq=8-dtype=torch.bfloat16-backend=TRTLLM-quant=NVFP4-routing=Renormalize] SKIP (https://nvbugs/5989912) accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_piecewise_cuda_graph[mtp3_fp8kv_chunked] SKIP (https://nvbugs/5989920) accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[use_temperature=False-attn_backend=TRTLLM] SKIP (https://nvbugs/5997547) -accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8 SKIP (https://nvbugs/6004530) +accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[enable_block_reuse=False] SKIP (https://nvbugs/6004530) +accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[enable_block_reuse=True] SKIP (https://nvbugs/6004530) 
accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_bf16[latency] SKIP (https://nvbugs/6012526) accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_pp4_mtp] SKIP (https://nvbugs/6018046) test_fmha.py::test_fmha SKIP (https://nvbugs/6018058) @@ -428,3 +429,4 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[v2_kv_cache-trtl disaggregated/test_disaggregated.py::test_disaggregated_gpt_oss_120b_harmony[gpt_oss/gpt-oss-120b] SKIP (https://nvbugs/6011317) accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-cutlass-auto] SKIP (https://nvbugs/5596343) unittest/_torch/speculative/test_eagle3.py::test_llama_eagle3_dynamic_tree[True-False] SKIP (https://nvbugs/6113021) +perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp4_v2_grace_blackwell-r1_fp4_v2_tep4_mtp3_1k1k] SKIP (https://nvbugs/6114727) diff --git a/tests/unittest/_torch/attention/test_attention_mla.py b/tests/unittest/_torch/attention/test_attention_mla.py index 0fde886345df..a634175edc6e 100644 --- a/tests/unittest/_torch/attention/test_attention_mla.py +++ b/tests/unittest/_torch/attention/test_attention_mla.py @@ -639,7 +639,8 @@ def yarn_get_mscale(scale=1, mscale=1): assert success, f"Failed to resume KV cache for request {req_id}" kv_cache.capacity = ctx_len else: - kv_cache_manager.impl.add_sequence(req_id, ctx_len, beam_width, req) + kv_cache_manager.impl.add_sequence_batch( + [(req_id, ctx_len, beam_width)], [req]) attn_metadata = AttentionCls.Metadata( seq_lens=torch.tensor(context_sequence_lengths, dtype=torch.int), request_ids=list(range(len(context_sequence_lengths))), diff --git a/tests/unittest/_torch/executor/test_resource_manager.py b/tests/unittest/_torch/executor/test_resource_manager.py index 716eef9f0441..9dca6497110b 100644 --- a/tests/unittest/_torch/executor/test_resource_manager.py +++ b/tests/unittest/_torch/executor/test_resource_manager.py @@ -751,8 +751,8 @@ def test_kv_cache_reset_reuse_state(self): # First request: Add sequence and store blocks for reuse req1 = self.create_llm_request(0, [1, 2, 3, 4, 5]) - kv_cache_manager.impl.add_sequence(req1.py_request_id, req1.prompt_len, - 1, req1) + kv_cache_manager.impl.add_sequence_batch( + [(req1.py_request_id, req1.prompt_len, 1)], [req1]) stats_initial = kv_cache_manager.get_kv_cache_stats() initial_reused_blocks = stats_initial.reused_blocks @@ -762,8 +762,8 @@ def test_kv_cache_reset_reuse_state(self): # Second request with same tokens - should reuse blocks from the reuse tree req2 = self.create_llm_request(1, [1, 2, 3, 4, 5]) - kv_cache_manager.impl.add_sequence(req2.py_request_id, req2.prompt_len, - 1, req2) + kv_cache_manager.impl.add_sequence_batch( + [(req2.py_request_id, req2.prompt_len, 1)], [req2]) stats_after_reuse = kv_cache_manager.get_kv_cache_stats() self.assertGreater( @@ -782,8 +782,8 @@ def test_kv_cache_reset_reuse_state(self): # Third request with same tokens - should NOT reuse blocks after reset req3 = self.create_llm_request(2, [1, 2, 3, 4, 5]) - kv_cache_manager.impl.add_sequence(req3.py_request_id, req3.prompt_len, - 1, req3) + kv_cache_manager.impl.add_sequence_batch( + [(req3.py_request_id, req3.prompt_len, 1)], [req3]) stats_after_third = kv_cache_manager.get_kv_cache_stats() self.assertEqual( diff --git a/tests/unittest/_torch/modules/helix_test_utils.py b/tests/unittest/_torch/modules/helix_test_utils.py index 5ee0ecd25999..558bb2dd590a 100644 --- a/tests/unittest/_torch/modules/helix_test_utils.py +++ b/tests/unittest/_torch/modules/helix_test_utils.py 
@@ -189,6 +189,8 @@ def setup_kv_and_metadata( mapping=mapping, dtype=str_dtype_to_binding(torch_dtype_to_str(scenario.kv_cache_dtype)), ) + requests = [] + request_infos = [] for req_id in range(scenario.batch): req = LlmRequest( request_id=req_id, @@ -199,8 +201,10 @@ def setup_kv_and_metadata( ) req.is_dummy_request = True req.paged_kv_block_ids = [] - beam_width = 1 - kv_cache_manager.impl.add_sequence(req_id, ctx_len_per_gpu, beam_width, req) + requests.append(req) + request_infos.append((req_id, ctx_len_per_gpu, 1)) + kv_cache_manager.impl.add_sequence_batch(request_infos, requests) + for req in requests: req.state = LlmRequestState.GENERATION_IN_PROGRESS req.prompt_len = ctx_len_per_gpu req.py_prompt_len = req.prompt_len diff --git a/tests/unittest/disaggregated/test_kv_transfer.py b/tests/unittest/disaggregated/test_kv_transfer.py index ee3e1171e8a8..d7808688bc24 100644 --- a/tests/unittest/disaggregated/test_kv_transfer.py +++ b/tests/unittest/disaggregated/test_kv_transfer.py @@ -625,8 +625,8 @@ def add_and_verify_request( ctx_kv_caches.append(kv_cache) else: for ctx_kv_cache_manager in valid_ctx_kv_cache_managers: - ctx_kv_cache_manager.impl.add_sequence( - ctx_request.py_request_id, ctx_request.prompt_len, 1, ctx_request + ctx_kv_cache_manager.impl.add_sequence_batch( + [(ctx_request.py_request_id, ctx_request.prompt_len, 1)], [ctx_request] ) gen_request = LlmRequest( @@ -656,8 +656,8 @@ def add_and_verify_request( gen_kv_caches.append(kv_cache) else: for gen_kv_cache_manager in valid_gen_kv_cache_managers: - gen_kv_cache_manager.impl.add_sequence( - gen_request.py_request_id, gen_request.prompt_len, 1, gen_request + gen_kv_cache_manager.impl.add_sequence_batch( + [(gen_request.py_request_id, gen_request.prompt_len, 1)], [gen_request] ) # Get block_ids per layer_group with window_size filtering diff --git a/tests/unittest/disaggregated/test_kv_transfer_mp.py b/tests/unittest/disaggregated/test_kv_transfer_mp.py index 70fe25dc11c3..40b55ac719ff 100644 --- a/tests/unittest/disaggregated/test_kv_transfer_mp.py +++ b/tests/unittest/disaggregated/test_kv_transfer_mp.py @@ -319,8 +319,8 @@ def process_and_verify_request( ctx_request.py_disaggregated_params = DisaggregatedParams(disagg_request_id=unique_rid) # Add sequence to KVCacheManager - kv_cache_manager.impl.add_sequence( - ctx_request.py_request_id, ctx_request.prompt_len, 1, ctx_request + kv_cache_manager.impl.add_sequence_batch( + [(ctx_request.py_request_id, ctx_request.prompt_len, 1)], [ctx_request] ) # Create sender session @@ -361,8 +361,8 @@ def process_and_verify_request( ) # Add sequence to KVCacheManager - kv_cache_manager.impl.add_sequence( - gen_request.py_request_id, gen_request.prompt_len, 1, gen_request + kv_cache_manager.impl.add_sequence_batch( + [(gen_request.py_request_id, gen_request.prompt_len, 1)], [gen_request] ) # Create receiver session diff --git a/tests/unittest/disaggregated/test_mamba_transfer.py b/tests/unittest/disaggregated/test_mamba_transfer.py index f637b0e6f146..5be3e43e7d4f 100644 --- a/tests/unittest/disaggregated/test_mamba_transfer.py +++ b/tests/unittest/disaggregated/test_mamba_transfer.py @@ -25,7 +25,7 @@ from tensorrt_llm import DisaggregatedParams, Mapping, SamplingParams from tensorrt_llm._torch.disaggregation.transceiver import KvCacheTransceiverV2 from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest, LlmRequestType -from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import MambaHybridCacheManager +from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import 
MixedMambaHybridCacheManager from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests from tensorrt_llm.bindings import DataType from tensorrt_llm.bindings.internal.batch_manager import CacheType as CacheTypeCpp @@ -180,7 +180,7 @@ def _init(rank): def _create_managers(tp): - """Create MambaHybridCacheManagers for all TP ranks (PP=1). + """Create MixedMambaHybridCacheManagers for all TP ranks (PP=1). Layer 0 is a dummy attention layer required by page table infrastructure. Layers 1..NUM_MAMBA_LAYERS are mamba layers under test. @@ -188,7 +188,7 @@ def _create_managers(tp): managers = [] for rank in range(tp): mapping = Mapping(world_size=tp, rank=rank, tp_size=tp, pp_size=1) - mgr = MambaHybridCacheManager( + mgr = MixedMambaHybridCacheManager( mamba_d_state=MAMBA_D_STATE, mamba_d_conv=MAMBA_D_CONV, mamba_num_heads=MAMBA_NUM_HEADS, diff --git a/tests/unittest/disaggregated/test_py_cache_transceiver_mp.py b/tests/unittest/disaggregated/test_py_cache_transceiver_mp.py index 68081984d277..4144dc391119 100644 --- a/tests/unittest/disaggregated/test_py_cache_transceiver_mp.py +++ b/tests/unittest/disaggregated/test_py_cache_transceiver_mp.py @@ -588,7 +588,7 @@ def gather_and_verify_request( ) -> tuple[torch.Tensor | None, torch.Tensor | None]: """Gather block data from all ranks to local_rank 0, then verify on world rank 0. - All ranks have all requests' block data (via add_sequence), so gather is simple. + All ranks have all requests' block data (via add_sequence_batch), so gather is simple. In DP mode, merge_block_data knows which ranks have valid (transferred) data. """ blocks = kv_cache_manager.get_batch_cache_indices([request.py_request_id])[0] @@ -682,9 +682,11 @@ def gather_and_verify_request( # Generation DP: only handle if request_index % gen_tp == tp_rank should_handle = i % gen_tp == tp_rank - # All ranks add_sequence so they have block data for verification + # All ranks add_sequence_batch so they have block data for verification # But only ranks that should_handle will submit to transceiver - kv_cache_manager.impl.add_sequence(request.py_request_id, request.prompt_len, 1, request) + kv_cache_manager.impl.add_sequence_batch( + [(request.py_request_id, request.prompt_len, 1)], [request] + ) if should_handle: my_requests.append((i, request)) # Store index and request for transfer diff --git a/tests/unittest/llmapi/test_llm_kv_cache_events.py b/tests/unittest/llmapi/test_llm_kv_cache_events.py index 0ca0f067611a..a2402e6b6695 100644 --- a/tests/unittest/llmapi/test_llm_kv_cache_events.py +++ b/tests/unittest/llmapi/test_llm_kv_cache_events.py @@ -82,8 +82,8 @@ def test_kv_cache_event_data_serialization(): assert len(serialized_event[0]["data"]["num_blocks_per_cache_level"]) == 2 req = create_llm_request(0, [1, 2, 3, 4, 5]) - kv_cache_manager.impl.add_sequence(req.py_request_id, req.prompt_len, 1, - req) + kv_cache_manager.impl.add_sequence_batch( + [(req.py_request_id, req.prompt_len, 1)], [req]) simulate_prefill_completion_only_use_for_testing(req) kv_cache_manager.free_resources(req) @@ -100,8 +100,8 @@ def test_kv_cache_event_data_serialization(): assert serialized_event[0]["data"]["blocks"][0]["mm_keys"] == [] req2 = create_llm_request(1, [1, 2, 3, 4, 5]) - kv_cache_manager.impl.add_sequence(req2.py_request_id, req2.prompt_len, 1, - req2) + kv_cache_manager.impl.add_sequence_batch( + [(req2.py_request_id, req2.prompt_len, 1)], [req2]) simulate_prefill_completion_only_use_for_testing(req2) kv_cache_manager.free_resources(req2) diff --git 
a/tests/unittest/others/test_kv_cache_transceiver.py b/tests/unittest/others/test_kv_cache_transceiver.py index 6960b1e28d07..a16581d83c20 100644 --- a/tests/unittest/others/test_kv_cache_transceiver.py +++ b/tests/unittest/others/test_kv_cache_transceiver.py @@ -117,9 +117,8 @@ def test_kv_cache_transceiver_single_process(ctx_gen_kv_cache_dtype, disagg_request_id=uuid.uuid4().int & 0x7FFFFFFFFFFFFFFF) ctx_request.py_disaggregated_params = disaggregated_params - kv_cache_manager_ctx.impl.add_sequence(ctx_request.py_request_id, - ctx_request.prompt_len, 1, - ctx_request) + kv_cache_manager_ctx.impl.add_sequence_batch( + [(ctx_request.py_request_id, ctx_request.prompt_len, 1)], [ctx_request]) # send ctx request kv_cache_transceiver_ctx.respond_and_send_async(ctx_request) @@ -148,9 +147,8 @@ def test_kv_cache_transceiver_single_process(ctx_gen_kv_cache_dtype, gen_request.py_disaggregated_params = disaggregated_params - kv_cache_manager_gen.impl.add_sequence(gen_request.py_request_id, - gen_request.prompt_len, 1, - gen_request) + kv_cache_manager_gen.impl.add_sequence_batch( + [(gen_request.py_request_id, gen_request.prompt_len, 1)], [gen_request]) # send gen request kv_cache_transceiver_gen.request_and_receive_async(gen_request) @@ -198,9 +196,8 @@ def test_cancel_request_in_transmission(attention_type): is_streaming=False, llm_request_type=LlmRequestType.LLMREQUEST_TYPE_CONTEXT_ONLY) - kv_cache_manager_ctx.impl.add_sequence(ctx_request.py_request_id, - ctx_request.prompt_len, 1, - ctx_request) + kv_cache_manager_ctx.impl.add_sequence_batch( + [(ctx_request.py_request_id, ctx_request.prompt_len, 1)], [ctx_request]) # send ctx request kv_cache_transceiver_ctx.respond_and_send_async(ctx_request) @@ -222,9 +219,8 @@ def test_cancel_request_in_transmission(attention_type): llm_request_type=LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY, context_phase_params=ctx_request.context_phase_params) - kv_cache_manager_gen.impl.add_sequence(gen_request.py_request_id, - gen_request.prompt_len, 1, - gen_request) + kv_cache_manager_gen.impl.add_sequence_batch( + [(gen_request.py_request_id, gen_request.prompt_len, 1)], [gen_request]) # send gen request kv_cache_transceiver_gen.request_and_receive_async(gen_request)
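For illustration, the snapshot spacing behind the mamba prefix-cache reuse added above: `calc_context_stop_positions` places state snapshots every `mamba_state_cache_interval` tokens during context prefill, and `calc_next_context_chunk_size` forces prefill chunks to stop exactly at those positions. A self-contained worked example of the helper, with assumed values `prompt_len=1000`, `tokens_per_block=32`, and the default interval of 256:

```python
# Mirror of calc_context_stop_positions from mamba_cache_manager.py above,
# shown with an assumed example; values here are for illustration only.
from typing import List


def calc_context_stop_positions(prompt_len: int,
                                tokens_per_block: int,
                                mamba_state_cache_interval: int,
                                save_last_snapshot: bool = False) -> List[int]:
    # Snapshot positions every `mamba_state_cache_interval` tokens within the prompt.
    stop_positions = list(
        range(mamba_state_cache_interval, prompt_len,
              mamba_state_cache_interval))
    # Optionally snapshot at the last block-aligned position as well.
    last_ckpt = prompt_len // tokens_per_block * tokens_per_block
    if save_last_snapshot and last_ckpt not in stop_positions:
        stop_positions.append(last_ckpt)
    # Always stop at the end of the prompt.
    if prompt_len not in stop_positions:
        stop_positions.append(prompt_len)
    return stop_positions


# Context prefill stops at 256, 512, 768 and finally at 1000; with
# save_last_snapshot, the last block boundary (992 = 1000 // 32 * 32) is added too.
assert calc_context_stop_positions(1000, 32, 256) == [256, 512, 768, 1000]
assert calc_context_stop_positions(1000, 32, 256, save_last_snapshot=True) == \
    [256, 512, 768, 992, 1000]
```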