[Attention] Re-enable cross-layer KV cache layout for MLA via stride-aware kernels#45111

Merged
njhill merged 4 commits into
vllm-project:mainfrom
ivanium:fix/mla-cross-layer
Jun 22, 2026

@ivanium ivanium commented Jun 10, 2026

Collaborator

Purpose

#37090 disabled the cross-layer (block-major) KV cache layout for all MLA backends after #37032 (GLM-4.7-Flash garbage output with KV offloading), attributing the bug to "MLA kernels requiring contiguous per-layer KV cache views". The actual cause is narrower: a few kernels computed page addresses from block_size * entry_size instead of reading the cache tensor's block-dim stride. The MLA write path (concat_and_cache_mla), the prefill gather kernels (cp_gather_cache, gather_and_maybe_dequant_cache), DeepGEMM fp8_paged_mqa_logits, and several decode kernels are already stride-aware.
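To make the distinction concrete, here is a minimal sketch (plain Python, no vLLM code) of why a per-layer view into a cross-layer buffer has a block stride larger than block_size * entry_size. The sizes and the helper `contiguous_strides` are hypothetical and chosen only for illustration.

```python
# Toy model of a cross-layer (block-major) KV cache. All sizes are made up.
NUM_BLOCKS, NUM_LAYERS, BLOCK_SIZE, ENTRY_SIZE = 4, 3, 2, 5


def contiguous_strides(shape):
    """Row-major element strides for a shape (hypothetical helper)."""
    strides, step = [], 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


# Physical buffer laid out as (num_blocks, num_layers, block_size, entry_size).
s_block, s_layer, s_token, s_elem = contiguous_strides(
    (NUM_BLOCKS, NUM_LAYERS, BLOCK_SIZE, ENTRY_SIZE)
)

# A per-layer view has logical shape (num_blocks, block_size, entry_size):
# it starts at that layer's offset and keeps the buffer's block stride.
layer = 1
view_offset = layer * s_layer
view_strides = (s_block, s_token, s_elem)

# What a kernel assumes if it derives the block stride from sizes alone.
packed_block_stride = BLOCK_SIZE * ENTRY_SIZE

print("view block stride:", view_strides[0])      # inflated by NUM_LAYERS
print("packed assumption:", packed_block_stride)
print("storage offset:", view_offset)             # non-zero for layer > 0
```

In this model the two values agree only when NUM_LAYERS is 1, which is the contiguous per-layer case.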

This PR fixes the three kernels that genuinely assumed packed pages and re-enables the cross-layer layout per backend (opt-in), keeping the safe identity default on MLACommonBackend for backends not yet verified (ROCm AITER, tokenspeed, XPU).

Kernel fixes

  • vllm/v1/attention/ops/triton_decode_attention.py: both stage-1 decode kernels addressed the cache as (page_number * PAGE_SIZE + offset) * stride(-3), baking in stride(block) == PAGE_SIZE * stride(token). They now take the page-dim stride separately. This is the kernel behind the original #37032 report, "[Bug]: GLM 4.7-flash returns gibberish when native KV cache offloading is on" (A100 falls back to TRITON_MLA).
  • csrc/libtorch_stable/attention/mla/sm100_cutlass_mla_kernel.cu: stride_C hardcoded page_size * (D_latent + D_rope); now built from kv_c_and_k_pe_cache.stride(0)/stride(1) (identical values for contiguous caches).
  • csrc/libtorch_stable/cache_kernels.cu indexer_k_quant_and_cache: block base was block_idx * cache_block_size * kv_cache.size(2); now block_idx * kv_cache.stride(0). Writing through a strided view previously corrupted the target layer and bled into neighbouring layers' segments. This also unblocks the DeepSeek V3.2/V4 indexer KV cache group under packed/cross-layer layouts.
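The read-side failure can be sketched in a few lines of plain Python. This is a toy model under assumed sizes, not the Triton kernel: `read_packed` mirrors the old linearized address and `read_strided` mirrors the fixed one that takes the page stride separately. Both function names are hypothetical.

```python
# Toy cross-layer buffer: each slot records which (layer, block, token) owns it.
NUM_BLOCKS, NUM_LAYERS, PAGE_SIZE, ENTRY = 3, 2, 2, 1

stride_token = ENTRY
stride_layer = PAGE_SIZE * stride_token
stride_page = NUM_LAYERS * stride_layer  # block-major: layers sit inside a block

buf = [None] * (NUM_BLOCKS * stride_page)
for b in range(NUM_BLOCKS):
    for l in range(NUM_LAYERS):
        for t in range(PAGE_SIZE):
            buf[b * stride_page + l * stride_layer + t * stride_token] = (l, b, t)


def read_packed(layer, page, tok):
    """Old addressing: assumes stride(block) == PAGE_SIZE * stride(token)."""
    base = layer * stride_layer
    return buf[base + (page * PAGE_SIZE + tok) * stride_token]


def read_strided(layer, page, tok):
    """Fixed addressing: the page-dim stride is passed in separately."""
    base = layer * stride_layer
    return buf[base + page * stride_page + tok * stride_token]


# Page 0 is addressed identically by both formulas ...
print(read_packed(0, 0, 1), read_strided(0, 0, 1))
# ... but for page > 0 the packed formula lands in a neighbouring layer.
print(read_packed(0, 1, 0), read_strided(0, 1, 0))
```

In the model, the packed formula returns another layer's data for any page index above 0, which matches the "wrong blocks for block_id > 0" symptom described in this PR.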

Per-backend opt-in

get_kv_cache_stride_order(include_num_layers_dimension=True) returns (1, 0, 2, 3) on backends whose decode kernels verifiably honor the cache's block-dim stride: TritonMLA and CutlassMLA (fixed above), FlashAttnMLA (FA3 reads k_batch_stride = kcache.stride(0)), FlashMLA (dense decode reads kcache.stride(0)), and FlashInferMLA (verified bit-exact). MLACommonBackend keeps the identity permutation as the safe default, so unverified backends remain opted out and can opt in individually once verified.
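As a rough illustration of what the permutation means, the sketch below applies a stride order to a logical shape and derives the resulting strides. It assumes the logical shape with the layers dimension is (num_layers, num_blocks, block_size, head_size) and that the stride order lists logical dims in physical allocation order; both are assumptions made for this example, and the helper names are hypothetical.

```python
def contiguous_strides(shape):
    """Row-major element strides for a shape (hypothetical helper)."""
    strides, step = [], 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def logical_strides(logical_shape, stride_order):
    """Strides seen through the logical shape when the buffer is allocated
    contiguously in the permuted (physical) order."""
    physical_shape = tuple(logical_shape[i] for i in stride_order)
    physical = contiguous_strides(physical_shape)
    out = [0] * len(logical_shape)
    for physical_pos, logical_dim in enumerate(stride_order):
        out[logical_dim] = physical[physical_pos]
    return tuple(out)


# Assumed logical shape: (num_layers, num_blocks, block_size, head_size).
shape = (3, 4, 2, 5)

identity = logical_strides(shape, (0, 1, 2, 3))     # layer-major default
cross_layer = logical_strides(shape, (1, 0, 2, 3))  # block-major opt-in

print("identity strides:   ", identity)
print("cross-layer strides:", cross_layer)
```

Under the identity order the block stride equals block_size * head_size; under (1, 0, 2, 3) it is multiplied by num_layers, which is why only stride-aware kernels can opt in.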

Tests

tests/kernels/attention/test_mla_cross_layer_kernel_equivalence.py (new) runs each kernel on a contiguous cache vs a per-layer view carved from a cross-layer buffer (inflated stride(0), non-zero storage offset, neighbour layers filled with garbage) and asserts bit-exact equality: concat_and_cache_mla write (incl. zero bleed), FlashMLA dense decode (Hopper-gated), FlashMLA dense fp8 decode (Hopper-gated), FA3 decode (Hopper-gated), FlashInfer MLA dense decode (bf16 + fp8), FlashMLA fp8 sparse decode, and indexer_k_quant_and_cache (incl. zero bleed). Similar strided-view tests are added for the triton decode kernels (all three address paths: MLA grouped, GQA grouped, MHA normal) and CUTLASS sm100 MLA decode. The MLA stride-order unit tests are updated for the opt-in design.
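The shape of these equivalence tests can be sketched with a toy write kernel in plain Python. This is an illustration of the test pattern only (contiguous reference, strided view with garbage neighbours, exact comparison, zero-bleed check); `write_token` is a hypothetical stand-in and does not reproduce any vLLM kernel.

```python
GARBAGE = -1.0
NUM_BLOCKS, NUM_LAYERS, BLOCK_SIZE, ENTRY = 3, 3, 2, 2
LAYER = 1  # the layer under test


def write_token(storage, offset, block_stride, block, tok, values):
    """Toy stride-aware cache write (stand-in for a real kernel)."""
    base = offset + block * block_stride + tok * ENTRY
    storage[base:base + ENTRY] = values


writes = [(0, 0, [1.0, 2.0]), (2, 1, [3.0, 4.0]), (1, 0, [5.0, 6.0])]

# Reference: contiguous per-layer cache, block stride == BLOCK_SIZE * ENTRY.
ref = [0.0] * (NUM_BLOCKS * BLOCK_SIZE * ENTRY)
for b, t, v in writes:
    write_token(ref, 0, BLOCK_SIZE * ENTRY, b, t, v)

# Cross-layer buffer: neighbours hold garbage, the target layer starts zeroed.
layer_stride = BLOCK_SIZE * ENTRY
block_stride = NUM_LAYERS * layer_stride
cross = [GARBAGE] * (NUM_BLOCKS * block_stride)
for b in range(NUM_BLOCKS):
    start = b * block_stride + LAYER * layer_stride
    cross[start:start + layer_stride] = [0.0] * layer_stride
before = list(cross)

# Same writes through the strided view (inflated stride, non-zero offset).
for b, t, v in writes:
    write_token(cross, LAYER * layer_stride, block_stride, b, t, v)

# Gather the target layer back and check exact equality with the reference.
got = []
for b in range(NUM_BLOCKS):
    start = b * block_stride + LAYER * layer_stride
    got.extend(cross[start:start + layer_stride])
matches_reference = got == ref

# Zero bleed: every slot outside the target layer is unchanged.
neighbours_untouched = all(
    cross[i] == before[i]
    for i in range(len(cross))
    if (i % block_stride) // layer_stride != LAYER
)
print(matches_reference, neighbours_untouched)
```

Swapping `block_stride` for `BLOCK_SIZE * ENTRY` in the strided write reproduces the failure in this model: the target layer no longer matches and the neighbour check fails.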

Why this is not duplicating an existing PR

Searched open PRs for cross-layer/stride-order work: #44577 packs DSv4 KV caches into contiguous per-block allocations but only touches allocation/connector/runner plumbing, no kernels — this PR is complementary (the kernel stride fixes here are what make such packed per-block layouts safe for MLA decode/write kernels). #41093 adds cross-layer support on the Mooncake connector side only. The KV-layout refactor series (#44458 draft, #44455, #42374) standardizes layout plumbing and overlaps some files but does not address the packed-page stride bugs or the MLA cross-layer opt-ins. #34742 is the stride-order default refactor referenced in #37090 review and is orthogonal.

Test commands and results

pytest tests/kernels/attention/test_mla_cross_layer_kernel_equivalence.py -v
# GB200 (sm100), extensions built from source: 5 passed, 3 skipped (Hopper-gated: FA3, FlashMLA dense bf16/fp8)
pytest tests/kernels/attention/test_triton_decode_attention.py -v
# 118 passed (incl. new cross-layer strided-view tests; no regression in existing paged/fp8 paths)
pytest tests/kernels/attention/test_cutlass_mla_decode.py::test_cutlass_mla_decode_cross_layer_view -v
# 1 passed on GB200 (fails against the pre-fix kernel with max diff 0.7, confirming the bug)
pytest tests/v1/kv_connector/unit/test_kv_cache_layout.py -v
# 7 passed

The three Hopper-gated tests were verified to skip cleanly on GB200; they run on sm90 CI. Before the fixes, the strided-view tests reproduce the #37032 failure mode: the triton and CUTLASS decode kernels read the wrong blocks for block_id > 0, and the indexer write corrupts neighbouring layers.

AI assistance disclosure

This PR was developed with AI assistance (Claude Code). The root-cause analysis, kernel fixes, and tests were reviewed line-by-line and the test suite was run on GB200 hardware by the submitter; an independent automated review (Codex) of the final diff reported no findings.

🤖 Generated with Claude Code

# 3D buffers have no explicit page dim (pages are packed token-major);
# 4D buffers may have a page stride larger than PAGE_SIZE * token stride
# (e.g. per-layer views into a cross-layer block-major cache).
return buf.stride(-4) if buf.dim() >= 4 else page_size * buf.stride(-3)

Collaborator

nit, personally i think

def _page_stride(buf, page_size):
    if buf.ndim == 3:
        buf = buf.unflatten(-3, (-1, page_size))
    return buf.stride(-4)

would be a bit clearer than manual stride computations

Collaborator Author

Thanks! Just fixed it

@LucasWilkinson LucasWilkinson left a comment

Collaborator

overall looks good to me, thanks for fixing this! left 1 nit

@github-project-automation github-project-automation Bot moved this to Ready in NVIDIA Jun 21, 2026
ivanium added 4 commits June 21, 2026 20:10
…aware kernels

PR vllm-project#37090 disabled the cross-layer (block-major) KV cache layout for all
MLA backends, attributing GLM-4.7-Flash garbage output (vllm-project#37032) to "MLA
kernels requiring contiguous per-layer views". The actual cause is
narrower: a few kernels computed page addresses from
block_size * entry_size instead of the cache tensor's block-dim stride.
The MLA write path (concat_and_cache_mla), the prefill gather kernels,
and several decode kernels already honor tensor strides.

Fix the three kernels that assumed packed pages:
- triton_decode_attention.py: pass the page-dim stride to both stage-1
  kernels instead of linearizing token slots against the intra-page
  stride (the kernel that produced the original bug report on A100).
- sm100_cutlass_mla_kernel.cu: build stride_C from
  kv_c_and_k_pe_cache.stride(0)/stride(1) instead of hardcoding
  page_size * (D_latent + D_rope).
- indexer_k_quant_and_cache (cache_kernels.cu): use kv_cache.stride(0)
  for the block base instead of block_size * size(2). This also
  unblocks the DeepSeek V3.2/V4 indexer KV cache group under
  cross-layer layouts.

Re-enable cross-layer per backend (opt-in) instead of reverting the
MLACommonBackend identity default, so unverified backends (ROCm AITER,
tokenspeed, XPU) stay safely opted out:
- TritonMLA, CutlassMLA: kernels fixed above.
- FlashAttnMLA: FA3 reads k_batch_stride = kcache.stride(0).
- FlashMLA: dense decode reads k_batch_stride = kcache.stride(0).
- FlashInferMLA: trtllm_batch_decode_with_kv_cache_mla verified
  bit-exact on a strided view (sm100).

Add bit-exact strided-view equivalence tests for the triton decode
kernel (all three address paths), CUTLASS sm100 MLA decode, FlashMLA
dense (sm90-gated) and fp8 sparse decode, FlashInfer MLA dense decode,
concat_and_cache_mla, and indexer_k_quant_and_cache. All pass on GB200
(sm100) after rebuilding the C extensions; the FlashMLA dense test
requires Hopper.

Co-authored-by: Claude
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
…-layer

Close the residual coverage gaps from review: bit-exact unified-slot-view
equivalence tests for the FA3 decode path (FLASH_ATTN_MLA), FlashMLA dense
fp8 decode, and FlashInfer MLA dense fp8 decode. The FlashInfer fp8 test
passes on GB200 (sm100); the FA3 and FlashMLA dense tests are gated on
Hopper support and skip elsewhere.

Co-authored-by: Claude
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
@ivanium ivanium force-pushed the fix/mla-cross-layer branch from f3134d6 to fc2a5a9 Compare June 21, 2026 20:10
@zyongye zyongye added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 22, 2026
@njhill njhill merged commit aa4990a into vllm-project:main Jun 22, 2026
203 checks passed
@github-project-automation github-project-automation Bot moved this from Ready to Done in NVIDIA Jun 22, 2026
nkzhenhua pushed a commit to nkzhenhua/vllm that referenced this pull request Jun 24, 2026
…aware kernels (vllm-project#45111)

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
qli88 pushed a commit to qli88/vllm that referenced this pull request Jun 26, 2026
…aware kernels (vllm-project#45111)

Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Qiang Li <qiang.li2@amd.com>
Labels

kv-connector nvidia ready ONLY add when PR is ready to merge/full CI is needed v1

5 participants