Commit 97bc680
feat: support kv cache reuse for MLA (#3571)
* support kv cache reuse for MLA: load compressed_kv and k_pe and do up-projection; use 192/128 head size MLA context kernel; support Blackwell and Hopper now
  Signed-off-by: Zhen Huang <[email protected]>
* add CI test
  Signed-off-by: Zhen Huang <[email protected]>
* fix: set k_pe head_num to 1 for kernel 2 and kernel 2V2
  Signed-off-by: Mingyang Jiang <[email protected]>
* resolve comments
  Signed-off-by: Zhen Huang <[email protected]>
* use GPTJ style RoPE for MLA
  Signed-off-by: Zhen Huang <[email protected]>
* fix rebase error and some docs
  Signed-off-by: Zhen Huang <[email protected]>
* fix kv_lens
  Signed-off-by: Zhen Huang <[email protected]>
* tiny fix
  Signed-off-by: Zhen Huang <[email protected]>
* fix torch compile
  Signed-off-by: Zhen Huang <[email protected]>
* fix: use normal device memory instead of pinned memory for unit test
  Signed-off-by: Mingyang Jiang <[email protected]>
* fix L0 tests
  Signed-off-by: Zhen Huang <[email protected]>
* fix torch compile after rebase
  Signed-off-by: Zhen Huang <[email protected]>
* resolve comments
  Signed-off-by: Zhen Huang <[email protected]>
* resolve comments again
  Signed-off-by: Zhen Huang <[email protected]>

---------

Signed-off-by: Zhen Huang <[email protected]>
Signed-off-by: Mingyang Jiang <[email protected]>
Signed-off-by: zhhuang-nv <[email protected]>
Co-authored-by: Mingyang Jiang <[email protected]>
1 parent b4e5df0 commit 97bc680

32 files changed: +14638 -9067 lines
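
At a glance, the change teaches the KV cache manager and the attention op to keep MLA's compressed latent in the paged KV cache, so blocks can be reused across requests, while the context kernels attend over up-projected 192/128 heads. Below is a minimal standalone sketch of the head-size arithmetic behind the "192x128" and "576x512" kernel names that appear in this commit; the DeepSeek-style dimension values are assumptions for illustration, not values read from the change itself.

    // Hypothetical MLA dimensions (DeepSeek-style); only the derived 192/128 and 576/512
    // figures are visible in this commit's kernel names.
    #include <cstdio>

    int main()
    {
        int const qkNopeHeadDim = 128; // no-RoPE part of each Q/K head after up-projection (assumed)
        int const qkRopeHeadDim = 64;  // RoPE part of each Q/K head, i.e. k_pe (assumed)
        int const vHeadDim = 128;      // V head size after up-projection (assumed)
        int const kvLoraRank = 512;    // size of the compressed latent, compressed_kv (assumed)

        // Context phase: compressed_kv and k_pe are loaded and up-projected, so FMHA sees
        // Q/K of 128 + 64 = 192 and V of 128, matching the "192x128" kernels added here.
        std::printf("context QK = %d, V = %d\n", qkNopeHeadDim + qkRopeHeadDim, vHeadDim);

        // The compressed cache itself holds 512 + 64 = 576 values per token, matching the
        // "576x512" paged-KV kernel referenced in the cubin table below.
        std::printf("cached latent + rope = %d, latent only = %d\n", kvLoraRank + qkRopeHeadDim, kvLoraRank);
        return 0;
    }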

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 4 additions & 2 deletions
@@ -456,6 +456,7 @@ class KVCacheBlockPool
 {
 public:
     SizeType32 numLayers;
+    SizeType32 kvFactor;
     SizeType32 numKvHeads;
     SizeType32 sizePerHead;
     SizeType32 tokensPerBlock;
@@ -469,10 +470,11 @@ class KVCacheBlockPool
     // FP4 KV caches have extra pools that contain second level scales for dequantization.
     bool containsBlockScales;

-    KVCacheBlockPool(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
-        SizeType32 quantSize, runtime::ITensor::SharedPtr primaryPtr = nullptr,
+    KVCacheBlockPool(SizeType32 numLayers, SizeType32 kvFactor, SizeType32 numKvHeads, SizeType32 sizePerHead,
+        SizeType32 tokensPerBlock, SizeType32 quantSize, runtime::ITensor::SharedPtr primaryPtr = nullptr,
         runtime::ITensor::SharedPtr secondaryPtr = nullptr, bool containsBlockScales = false)
         : numLayers(numLayers)
+        , kvFactor(kvFactor)
        , numKvHeads(numKvHeads)
        , sizePerHead(sizePerHead)
        , tokensPerBlock(tokensPerBlock)
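
To see what the new kvFactor member buys: with separate K and V planes a block holds numLayers * 2 * numKvHeads * tokensPerBlock * sizePerHead elements, whereas a cache that stores only MLA's compressed latent needs a single plane. A minimal sketch under that assumption follows; kvFactor = 1 for MLA is inferred from the shape check further down, and the concrete numbers are illustrative, not taken from this change.

    // Per-block element count as a function of kvFactor: 2 for separate K/V planes,
    // 1 assumed for an MLA-style compressed cache.
    #include <cstdint>
    #include <cstdio>

    using SizeType32 = std::int32_t;

    static std::int64_t blockElems(SizeType32 numLayers, SizeType32 kvFactor, SizeType32 numKvHeads,
        SizeType32 sizePerHead, SizeType32 tokensPerBlock)
    {
        return static_cast<std::int64_t>(numLayers) * kvFactor * numKvHeads * sizePerHead * tokensPerBlock;
    }

    int main()
    {
        // Illustrative configurations only.
        std::printf("standard KV: %lld elems/block\n", (long long) blockElems(32, /*kvFactor=*/2, 8, 128, 64));
        std::printf("MLA latent:  %lld elems/block\n", (long long) blockElems(61, /*kvFactor=*/1, 1, 576, 64));
        return 0;
    }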

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 3 additions & 3 deletions
@@ -557,7 +557,7 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
                 mLayerToPoolIndex[layerIdx] = poolIndex;
             }
         }
-        mPools.emplace_back(numLayers, numKvHeads, sizePerHead, tokensPerBlock, 1);
+        mPools.emplace_back(numLayers, mKVFactor, numKvHeads, sizePerHead, tokensPerBlock, 1);
         ++poolIndex;
     }

@@ -649,8 +649,8 @@ void WindowBlockManager::createBlockScalePools(SizeType32 quantBlockSize)
         TLLM_CHECK_WITH_INFO(kv_pool.blockSize % quantBlockSize == 0,
             "Cannot use FP4 quantization since kv_pool.blockSize is not divisible by FP4 quantBlockSize.");

-        mPools.emplace_back(kv_pool.numLayers, kv_pool.numKvHeads, kv_pool.sizePerHead, kv_pool.tokensPerBlock,
-            quantBlockSize,
+        mPools.emplace_back(kv_pool.numLayers, kv_pool.kvFactor, kv_pool.numKvHeads, kv_pool.sizePerHead,
+            kv_pool.tokensPerBlock, quantBlockSize,
             /*primaryPool=*/nullptr,
             /*secondaryPool=*/nullptr,
             /*containsBlockScales=*/true);

cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp

Lines changed: 2 additions & 1 deletion
@@ -78,13 +78,14 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst,
     {
         auto stream = (isOffload ? mOffloadManager : mOnboardManager).getStream().get();
         int const numLayers = pools[poolIdx].numLayers;
+        int const kvFactor = pools[poolIdx].kvFactor;
         int const numHeads = pools[poolIdx].numKvHeads;
         int const sizePerHead = pools[poolIdx].sizePerHead;
         auto shape = srcPtr->getShape();
         TLLM_LOG_DEBUG("block.Shape = %s", srcPtr->toString(shape).c_str());
         TLLM_CHECK_WITH_INFO(
             shape.nbDims == 4, "Expected KVCache block to have 4 dimensions, but it has %d", shape.nbDims);
-        TLLM_CHECK_WITH_INFO((shape.d[0] == 1) && (shape.d[1] == numLayers) && (shape.d[2] == 2)
+        TLLM_CHECK_WITH_INFO((shape.d[0] == 1) && (shape.d[1] == numLayers) && (shape.d[2] == kvFactor)
                 && (shape.d[3] == numHeads * tokensPerBlock * sizePerHead),
             "Block shape is incorrect");
         TLLM_CHECK_WITH_INFO(numTokensToCopy <= tokensPerBlock,
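
The check above now expects each block tensor to be [1, numLayers, kvFactor, numKvHeads * tokensPerBlock * sizePerHead], so the same transfer path accepts both the classic two-plane layout and a one-plane MLA layout. A small self-contained sketch of that invariant; the dimension values are illustrative assumptions.

    // Mirror of the block-shape check in copyBlock, on hypothetical dimensions.
    #include <array>
    #include <cstdint>
    #include <cstdio>

    int main()
    {
        int const numLayers = 61, kvFactor = 1, numKvHeads = 1, sizePerHead = 576, tokensPerBlock = 64;
        std::array<std::int64_t, 4> shape = {1, numLayers, kvFactor,
            static_cast<std::int64_t>(numKvHeads) * tokensPerBlock * sizePerHead};

        bool const ok = shape[0] == 1 && shape[1] == numLayers && shape[2] == kvFactor
            && shape[3] == static_cast<std::int64_t>(numKvHeads) * tokensPerBlock * sizePerHead;
        std::printf("block shape [1, %d, %d, %lld] %s\n", numLayers, kvFactor, (long long) shape[3],
            ok ? "matches the expected layout" : "does not match");
        return 0;
    }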

cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 45 additions & 6 deletions
@@ -17,6 +17,7 @@
 #include "attentionOp.h"
 #include "tensorrt_llm/common/assert.h"
 #include "tensorrt_llm/common/envUtils.h"
+#include "tensorrt_llm/common/logger.h"
 #include "tensorrt_llm/common/memoryUtils.h"
 #include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h"
 #include "tensorrt_llm/kernels/flashMLA/flash_mla.h"
@@ -1528,12 +1529,44 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
                 == false,
             "Found invalid number (NaN or Inf) in " + beforeRopeStr);
     }
+
+    KVBlockArray mla_context_paged_kv_cache_buffer;
     if (mIsMLAEnabled)
     {
         params.mla_param->cache_type = cache_type;
         params.mla_param->cu_q_seqlens = cu_q_seqlens;
         params.mla_param->quant_scale_kv = params.kv_scale_orig_quant;
-        invokeMLARopeContext<T, KVCacheBuffer>(*params.mla_param, kv_cache_buffer, stream);
+        if (mPagedContextFMHA && mPagedKVCache)
+        {
+            TLLM_CHECK_WITH_INFO(params.mla_param->context_paged_kv_ptr != nullptr,
+                "Paged kv cache is not set for MLA context kernel");
+            TLLM_CHECK_WITH_INFO(params.mla_param->context_kv_cache_block_offsets_ptr != nullptr,
+                "Paged kv cache block offsets is not set for MLA context kernel");
+            // Build another KVBlockArray for the MLA context kernel to read the paged kv cache built by the
+            // PyTorch backend; assume the dtype of the paged kv cache is the same as T.
+            auto const elemSize = sizeof(T);
+            auto const headSize = params.mla_param->meta.qk_nope_head_dim + params.mla_param->meta.qk_rope_head_dim;
+            // mNumKVHeads is 1 for writing; we use mNumHeads for reading the paged kv cache.
+            auto sizePerToken = mNumHeads * headSize * elemSize;
+            auto maxBlocksPerSeq = params.mla_param->context_paged_kv_max_blocks_per_seq;
+            TLLM_LOG_DEBUG(
+                "AttentionOp building KVBlockArray for MLA context kernel, elemSize: %d, headSize: %d, mNumHeads: "
+                "%d, sizePerToken: %d, batchSize: %d, maxBlocksPerSeq: %d, tokensPerBlock: %d, maxAttentionWindow: "
+                "%d, sinkTokenLen: %d, canUseOneMoreBlock: %d",
+                elemSize, headSize, mNumHeads, sizePerToken, params.batch_size, maxBlocksPerSeq, mTokensPerBlock,
+                params.cyclic_attention_window_size, params.sink_token_length, params.can_use_one_more_block);
+            mla_context_paged_kv_cache_buffer = KVBlockArray(params.batch_size, maxBlocksPerSeq, mTokensPerBlock,
+                sizePerToken, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size,
+                params.sink_token_length, params.can_use_one_more_block, params.mla_param->context_paged_kv_ptr,
+                nullptr,
+                static_cast<KVBlockArray::DataType*>(params.mla_param->context_kv_cache_block_offsets_ptr));
+        }
+        else
+        {
+            // Compute RoPE and set compressed_kv + k_pe via invokeMLARopeContext when not using paged context FMHA.
+            invokeMLARopeContext<T, KVCacheBuffer>(*params.mla_param, kv_cache_buffer, stream);
+        }
     }
     else
     {
@@ -1596,7 +1629,15 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     fmhaParams.packedMaskPtr = params.attention_packed_mask;
     if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
     {
-        fmhaParams.pagedKvCache = kv_cache_buffer;
+        if (mIsMLAEnabled && mPagedContextFMHA && mPagedKVCache)
+        {
+            fmhaParams.pagedKvCache = mla_context_paged_kv_cache_buffer;
+            fmhaParams.qPtr = reinterpret_cast<void const*>(attention_input);
+        }
+        else
+        {
+            fmhaParams.pagedKvCache = kv_cache_buffer;
+        }
     }
     fmhaParams.cuQSeqLenPtr = cu_q_seqlens;
     fmhaParams.kvSeqLenPtr = decoder_params.seqKVLengths;
@@ -1612,7 +1653,6 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     // Run the fmha kernel.
     mFmhaDispatcher->run(fmhaParams);
     sync_check_cuda_error(stream);
-
     // The kv cache might need to be updated after FMHA (only when sliding window attention + chunked context is
     // used together). Reuse the preprocessingParams.
     invokeKvCachePostprocessing(preprocessingParams, stream);
@@ -2418,9 +2458,8 @@ int AttentionOp::initialize() noexcept
     }
     else
     {
-        fmhaParams.attentionInputLayout = (mPagedKVCache && mPagedContextFMHA && !mIsMLAEnabled)
-            ? AttentionInputLayout::Q_PAGED_KV
-            : AttentionInputLayout::PACKED_QKV;
+        fmhaParams.attentionInputLayout = (mPagedKVCache && mPagedContextFMHA) ? AttentionInputLayout::Q_PAGED_KV
+                                                                               : AttentionInputLayout::PACKED_QKV;
     }
     fmhaParams.isSPadded = !mRemovePadding;
     fmhaParams.numQHeads = mNumAttnHeads;
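
The sizing logic added to enqueueContext is worth working through once: the temporary paged view read by the context kernel is laid out per token as mNumHeads * (qk_nope_head_dim + qk_rope_head_dim) elements, while the real cache is written with a single KV head. A standalone sketch of that arithmetic under assumed DeepSeek-style dimensions (bf16, 128 query heads); the numbers are illustrative, not asserted by the commit.

    // Size-per-token arithmetic for the MLA context KVBlockArray, per the expressions above.
    #include <cstdio>

    int main()
    {
        int const qkNopeHeadDim = 128; // params.mla_param->meta.qk_nope_head_dim (assumed value)
        int const qkRopeHeadDim = 64;  // params.mla_param->meta.qk_rope_head_dim (assumed value)
        int const numHeads = 128;      // mNumHeads: query heads, used for *reading* the paged view
        int const elemSize = 2;        // sizeof(T) for bf16/fp16

        int const headSize = qkNopeHeadDim + qkRopeHeadDim;     // 192
        int const sizePerToken = numHeads * headSize * elemSize; // 128 * 192 * 2 = 49152 bytes
        std::printf("headSize = %d, sizePerToken = %d bytes\n", headSize, sizePerToken);
        return 0;
    }

With Q_PAGED_KV now selected for MLA as well (the initialize() hunk above), the fused context kernel consumes qPtr for the queries and this paged view for K/V instead of the packed-QKV path.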

cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h

Lines changed: 6 additions & 6 deletions
@@ -379,9 +379,9 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_so
 extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_256_softcapping_sm90_cu_cubin[];
 extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_softcapping_sm90_cu_cubin[];
 extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_softcapping_sm90_cu_cubin[];
-extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm90_cu_cubin[];
-extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm90_cu_cubin[];
+extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin[];
 extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm90_cu_cubin[];
+extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin[];
 #endif

 #ifndef EXCLUDE_SM_89
@@ -1661,9 +1661,9 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcap
 extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_256_softcapping_sm90_cu_cubin_len;
 extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_softcapping_sm90_cu_cubin_len;
 extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_softcapping_sm90_cu_cubin_len;
-extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm90_cu_cubin_len;
-extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm90_cu_cubin_len;
+extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin_len;
 extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm90_cu_cubin_len;
+extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin_len;
 #endif

 #ifndef EXCLUDE_SM_89
@@ -3573,9 +3573,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_softcapping_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_softcapping_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_sliding_window_causal_softcapping_sm90_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, true, true, false, true, false},
 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_softcapping_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_softcapping_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_causal_softcapping_sm90_kernel_nl", 49152, 128, 64, 1, 0, false, true, false, true, true, false, true, false},
 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_softcapping_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_softcapping_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_qkv_256_sliding_window_causal_softcapping_sm90_kernel_nl", 49152, 128, 64, 2, 0, false, true, false, true, true, false, true, false},
-{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_sm90_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, false},
-{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_sm90_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, true, true, true, false, false},
+{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 0, false, true, true, true, false, false, false, false},
 { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm90_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, false},
+{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 2, false, true, true, true, false, false, false, false},
 #endif

 #ifndef EXCLUDE_SM_89
