[Attention] Re-enable cross-layer KV cache layout for MLA via stride-aware kernels #45111
Merged
gau-nernst
approved these changes
Jun 11, 2026
This was referenced Jun 18, 2026
    # 3D buffers have no explicit page dim (pages are packed token-major);
    # 4D buffers may have a page stride larger than PAGE_SIZE * token stride
    # (e.g. per-layer views into a cross-layer block-major cache).
    return buf.stride(-4) if buf.dim() >= 4 else page_size * buf.stride(-3)
Collaborator
nit, personally i think
    def _page_stride(buf, page_size):
        if buf.ndim == 3:
            buf = buf.unflatten(-3, (-1, page_size))
        return buf.stride(-4)
would be a bit clearer than manual stride computations
Collaborator
Author
Thanks! Just fixed it
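For reference, the two formulations discussed in this thread give the same page stride on both layouts. The following is a minimal sketch that models shapes and strides as plain tuples, so it runs without PyTorch. `unflatten_strides` is a hypothetical helper (not part of the PR) that mimics how a view-producing unflatten splits a dim of stride `s` into strides `(page_size * s, s)`, and the example shapes are made up for illustration.

```python
def unflatten_strides(shape, strides, dim, page_size):
    """Hypothetical helper: shape/strides after splitting `dim` into (-1, page_size)."""
    dim %= len(shape)
    if shape[dim] % page_size:
        raise ValueError("token dim must be divisible by page_size")
    new_shape = shape[:dim] + (shape[dim] // page_size, page_size) + shape[dim + 1:]
    # Splitting a dim of stride s yields an outer stride of page_size * s
    # and an inner stride of s; all other strides are unchanged.
    new_strides = strides[:dim] + (page_size * strides[dim], strides[dim]) + strides[dim + 1:]
    return new_shape, new_strides


def page_stride_manual(shape, strides, page_size):
    # Mirrors the original one-liner from the diff.
    return strides[-4] if len(shape) >= 4 else page_size * strides[-3]


def page_stride_unflatten(shape, strides, page_size):
    # Mirrors the reviewer's suggestion.
    if len(shape) == 3:
        shape, strides = unflatten_strides(shape, strides, -3, page_size)
    return strides[-4]


PAGE_SIZE = 4
# 3D packed buffer: (tokens, heads, dim), contiguous.
packed = ((12, 2, 8), (16, 8, 1))
# 4D per-layer view into a 2-layer block-major buffer: the page stride (128)
# is twice the packed value PAGE_SIZE * 16 = 64.
layer_view = ((3, 4, 2, 8), (128, 16, 8, 1))

for shape, strides in (packed, layer_view):
    assert page_stride_manual(shape, strides, PAGE_SIZE) == \
        page_stride_unflatten(shape, strides, PAGE_SIZE)
```

One behavioural difference worth keeping in mind: the unflatten form requires the token dimension to be divisible by the page size, while the manual multiplication does not check this.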
LucasWilkinson
approved these changes
Jun 21, 2026
LucasWilkinson
left a comment
Collaborator
overall looks good to me, thanks for fixing this! left 1 nit
…aware kernels

PR vllm-project#37090 disabled the cross-layer (block-major) KV cache layout for all MLA backends, attributing GLM-4.7-Flash garbage output (vllm-project#37032) to "MLA kernels requiring contiguous per-layer views". The actual cause is narrower: a few kernels computed page addresses from block_size * entry_size instead of the cache tensor's block-dim stride. The MLA write path (concat_and_cache_mla), the prefill gather kernels, and several decode kernels already honor tensor strides.

Fix the three kernels that assumed packed pages:
- triton_decode_attention.py: pass the page-dim stride to both stage-1 kernels instead of linearizing token slots against the intra-page stride (the kernel that produced the original bug report on A100).
- sm100_cutlass_mla_kernel.cu: build stride_C from kv_c_and_k_pe_cache.stride(0)/stride(1) instead of hardcoding page_size * (D_latent + D_rope).
- indexer_k_quant_and_cache (cache_kernels.cu): use kv_cache.stride(0) for the block base instead of block_size * size(2). This also unblocks the DeepSeek V3.2/V4 indexer KV cache group under cross-layer layouts.

Re-enable cross-layer per backend (opt-in) instead of reverting the MLACommonBackend identity default, so unverified backends (ROCm AITER, tokenspeed, XPU) stay safely opted out:
- TritonMLA, CutlassMLA: kernels fixed above.
- FlashAttnMLA: FA3 reads k_batch_stride = kcache.stride(0).
- FlashMLA: dense decode reads k_batch_stride = kcache.stride(0).
- FlashInferMLA: trtllm_batch_decode_with_kv_cache_mla verified bit-exact on a strided view (sm100).

Add bit-exact strided-view equivalence tests for the triton decode kernel (all three address paths), CUTLASS sm100 MLA decode, FlashMLA dense (sm90-gated) and fp8 sparse decode, FlashInfer MLA dense decode, concat_and_cache_mla, and indexer_k_quant_and_cache. All pass on GB200 (sm100) after rebuilding the C extensions; the FlashMLA dense test requires Hopper.
Co-authored-by: Claude
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
…-layer

Close the residual coverage gaps from review: bit-exact unified-slot-view equivalence tests for the FA3 decode path (FLASH_ATTN_MLA), FlashMLA dense fp8 decode, and FlashInfer MLA dense fp8 decode. The FlashInfer fp8 test passes on GB200 (sm100); the FA3 and FlashMLA dense tests are gated on Hopper support and skip elsewhere.

Co-authored-by: Claude
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
f3134d6 to
fc2a5a9
Compare
nkzhenhua
pushed a commit
to nkzhenhua/vllm
that referenced
this pull request
Jun 24, 2026
…aware kernels (vllm-project#45111) Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
qli88
pushed a commit
to qli88/vllm
that referenced
this pull request
Jun 26, 2026
…aware kernels (vllm-project#45111) Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai> Signed-off-by: Qiang Li <qiang.li2@amd.com>
Purpose
#37090 disabled the cross-layer (block-major) KV cache layout for all MLA backends after #37032 (GLM-4.7-Flash garbage output with KV offloading), attributing the bug to "MLA kernels requiring contiguous per-layer KV cache views". The actual cause is narrower: a few kernels computed page addresses from block_size * entry_size instead of reading the cache tensor's block-dim stride. The MLA write path (concat_and_cache_mla), the prefill gather kernels (cp_gather_cache, gather_and_maybe_dequant_cache), DeepGEMM fp8_paged_mqa_logits, and several decode kernels are already stride-aware.

This PR fixes the three kernels that genuinely assumed packed pages and re-enables the cross-layer layout per backend (opt-in), keeping the safe identity default on MLACommonBackend for backends not yet verified (ROCm AITER, tokenspeed, XPU).

Kernel fixes
- vllm/v1/attention/ops/triton_decode_attention.py: both stage-1 decode kernels addressed the cache as (page_number * PAGE_SIZE + offset) * stride(-3), baking in stride(block) == PAGE_SIZE * stride(token). They now take the page-dim stride separately. This is the kernel behind the original #37032 report ([Bug]: GLM 4.7-flash returns gibberish when native KV cache offloading is on); A100 falls back to TRITON_MLA.
- csrc/libtorch_stable/attention/mla/sm100_cutlass_mla_kernel.cu: stride_C hardcoded page_size * (D_latent + D_rope); now built from kv_c_and_k_pe_cache.stride(0)/stride(1) (identical values for contiguous caches).
- csrc/libtorch_stable/cache_kernels.cu indexer_k_quant_and_cache: block base was block_idx * cache_block_size * kv_cache.size(2); now block_idx * kv_cache.stride(0). Writing through a strided view previously corrupted the target layer and bled into neighbouring layers' segments. This also unblocks the DeepSeek V3.2/V4 indexer KV cache group under packed/cross-layer layouts.

Per-backend opt-in
get_kv_cache_stride_order(include_num_layers_dimension=True) returns (1, 0, 2, 3) on backends whose decode kernels verifiably honor the cache's block-dim stride: TritonMLA and CutlassMLA (fixed above), FlashAttnMLA (FA3 reads k_batch_stride = kcache.stride(0)), FlashMLA (dense decode reads kcache.stride(0)), and FlashInferMLA (verified bit-exact). MLACommonBackend keeps the identity permutation as the safe default, so unverified backends remain opted out and can opt in individually once verified.

Tests
tests/kernels/attention/test_mla_cross_layer_kernel_equivalence.py (new) runs each kernel on a contiguous cache vs a per-layer view carved from a cross-layer buffer (inflated stride(0), non-zero storage offset, neighbour layers filled with garbage) and asserts bit-exact equality: concat_and_cache_mla write (incl. zero bleed), FlashMLA dense decode (Hopper-gated), FlashMLA dense fp8 decode (Hopper-gated), FA3 decode (Hopper-gated), FlashInfer MLA dense decode (bf16 + fp8), FlashMLA fp8 sparse decode, and indexer_k_quant_and_cache (incl. zero bleed). Similar strided-view tests are added for the triton decode kernels (all three address paths: MLA grouped, GQA grouped, MHA normal) and CUTLASS sm100 MLA decode. The MLA stride-order unit tests are updated for the opt-in design.

Why this is not duplicating an existing PR
Searched open PRs for cross-layer/stride-order work: #44577 packs DSv4 KV caches into contiguous per-block allocations but only touches allocation/connector/runner plumbing, no kernels — this PR is complementary (the kernel stride fixes here are what make such packed per-block layouts safe for MLA decode/write kernels). #41093 adds cross-layer support on the Mooncake connector side only. The KV-layout refactor series (#44458 draft, #44455, #42374) standardizes layout plumbing and overlaps some files but does not address the packed-page stride bugs or the MLA cross-layer opt-ins. #34742 is the stride-order default refactor referenced in #37090 review and is orthogonal.
Test commands and results
The three Hopper-gated tests were verified to skip cleanly here; they run on sm90 CI. Before the fixes, the strided-view tests reproduce the #37032 failure mode: triton/CUTLASS decode reads the wrong blocks for block_id > 0, and the indexer write corrupts neighbouring layers.
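The read side of this failure mode can be illustrated in miniature without a GPU. The following is a rough sketch, not the PR's test code: it models a block-major cross-layer buffer as a flat Python list and compares packed-page addressing against stride-based addressing on a per-layer view. All sizes and the names `read_packed` and `read_strided` are assumptions made for illustration.

```python
NUM_BLOCKS, NUM_LAYERS, PAGE_SIZE, ENTRY = 3, 2, 4, 8

# Flat storage for a block-major buffer of logical shape
# (NUM_BLOCKS, NUM_LAYERS, PAGE_SIZE, ENTRY). Each element records its own
# (block, layer, token, channel) coordinates so a wrong read is easy to spot.
storage = [
    (b, l, t, c)
    for b in range(NUM_BLOCKS)
    for l in range(NUM_LAYERS)
    for t in range(PAGE_SIZE)
    for c in range(ENTRY)
]

# Per-layer view of layer 1, logical shape (NUM_BLOCKS, PAGE_SIZE, ENTRY).
LAYER = 1
offset = LAYER * PAGE_SIZE * ENTRY              # non-zero storage offset
block_stride = NUM_LAYERS * PAGE_SIZE * ENTRY   # inflated stride(0)
token_stride = ENTRY                            # stride(1)


def read_packed(block, token):
    # Buggy addressing: assumes stride(block) == PAGE_SIZE * stride(token).
    return storage[offset + (block * PAGE_SIZE + token) * token_stride]


def read_strided(block, token):
    # Fixed addressing: uses the view's actual block-dim stride.
    return storage[offset + block * block_stride + token * token_stride]


# Block 0 is read correctly either way, which is why the bug can hide.
assert read_packed(0, 2) == read_strided(0, 2) == (0, LAYER, 2, 0)

# For block_id > 0 the packed formula lands in the wrong place:
# block 1 resolves into the neighbouring layer, block 2 into block 1.
assert read_strided(1, 0) == (1, LAYER, 0, 0)
assert read_packed(1, 0) == (1, 0, 0, 0)
assert read_packed(2, 0) == (1, LAYER, 0, 0)
```

In this toy model the two formulas coincide exactly when NUM_LAYERS is 1, which matches the observation that contiguous per-layer caches are unaffected.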
AI assistance disclosure
This PR was developed with AI assistance (Claude Code). The root-cause analysis, kernel fixes, and tests were reviewed line-by-line and the test suite was run on GB200 hardware by the submitter; an independent automated review (Codex) of the final diff reported no findings.
🤖 Generated with Claude Code