[v1][kvcache] Honor prefix-cache retention interval for Mamba/linear attention #45845
Wire VLLM_PREFIX_CACHE_RETENTION_INTERVAL to Mamba groups, completing the
existing `# TODO: Support Mamba/linear attention` (only sliding-window
attention honored it before).
Mamba/KDA prefix caching retains a full recurrent-state snapshot once per
block_size-token boundary. At small attention block sizes (e.g. 128 under
decoupled hybrid paging) each snapshot spans several base blocks, and dense
per-boundary retention saturates the KV pool — at block_size 128 the Mamba
snapshots occupy ~80% of the blocks — leaving no uncached headroom, so the
allocator is forced to evict live attention prefixes. The prefix-cache hit
rate then collapses late in long multi-turn runs (~85%, down to ~75% under
load) with ~18% lower throughput and ~3x worse p99, while larger block sizes
(256/512) are unaffected.
MambaManager.reachable_block_mask now sparsifies state-snapshot retention the
same way SlidingWindowManager does: keep one cached state per
retention_interval-sized segment (plus the latest replay boundary) instead of
one per block. A hit resumes from the nearest retained boundary (at most
retention_interval tokens coarser), costing negligible extra prefill while
freeing the intermediate snapshots for reuse. Also fixes
MambaManager.cache_blocks to tolerate sparse (unhashed) blocks in the cached
range, and relaxes _validate_prefix_cache_retention_interval to accept models
with a Mamba group.
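The retention rule above can be sketched as follows. This is a minimal illustration, not the actual `MambaManager.reachable_block_mask` code: it assumes `retention_interval` is a multiple of `block_size` and that the snapshot kept in each segment is the one ending on the segment boundary. The helper names `sparse_retention_mask` and `resume_point` are hypothetical.

```python
def sparse_retention_mask(
    num_blocks: int, block_size: int, retention_interval: int
) -> list[bool]:
    """Return True for blocks whose state snapshot is retained.

    Assumption: block i holds the state after (i + 1) * block_size tokens,
    and only snapshots ending on a retention_interval boundary are kept.
    """
    if retention_interval <= 0 or retention_interval % block_size != 0:
        raise ValueError("retention_interval must be a positive multiple of block_size")
    blocks_per_segment = retention_interval // block_size
    return [(i + 1) % blocks_per_segment == 0 for i in range(num_blocks)]


def resume_point(num_hit_tokens: int, boundary_tokens: int) -> int:
    """Largest boundary (in tokens) at or below the matched prefix length."""
    return num_hit_tokens // boundary_tokens * boundary_tokens


# Illustrative values: block_size 128, retention interval 512, 8 blocks.
mask = sparse_retention_mask(num_blocks=8, block_size=128, retention_interval=512)
dense_resume = resume_point(1000, 128)   # dense retention: every block boundary
sparse_resume = resume_point(1000, 512)  # sparse retention: every interval boundary
extra_prefill = dense_resume - sparse_resume
```

With these values only two of the eight snapshots stay cached, and a 1000-token match resumes from token 512 instead of 896, so the extra prefill stays below one retention interval.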
Validated on Kimi-Linear-48B-A3B-Instruct (decoupled hybrid paging, block_size
128, TP4, multi-turn prefix-on, 50 prompts x 60 turns): with
VLLM_PREFIX_CACHE_RETENTION_INTERVAL=512 the prefix-cache hit rate recovers to
98.5% (parity with block_size 512), throughput to ~200K tok/s, with zero
failed requests. Default behavior (interval unset) is unchanged: Mamba caches
densely.
Test commands run:
.venv/bin/python -m pytest tests/v1/core/test_prefix_caching.py \
-k "retention or reachable" -v # 10 passed
.venv/bin/python -m pytest tests/v1/core/ \
-k "retention or decoupled or buddy or mamba or prefix_cach" -q # 136 passed
This is not a duplicate: no open PR wires retention-interval sparsification to
Mamba/linear-attention groups (the codebase carried it as a TODO).
AI assistance (Claude) was used for this change; the human submitter has
reviewed every changed line.
Signed-off-by: Dao Le <daole@inferact.ai>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Dao Le <Dao007forever@gmail.com>
|
Looks good to me. Thanks! Just need to update the description for Line 1055 in 475a6ad |
|
Thanks for the PR. I have a small clarification question. My understanding is that Kimi-Linear-48B-A3B-Instruct prefix caching currently only supports |
|
Hi @QilaiZhang, in align mode it still deposits a snapshot every block_size tokens. Those linger as unreferenced-but-cached blocks and accumulate as context grows. This causes blocks to be reused and the prefix-cache hit rate to drop. The live footprint stays at ~2 blocks; the cached footprint grows by ~1 block per block_size tokens, which is what the retention interval thins. (I was running in a test setting that works around the uniform block size of HMA.) |
|
@Dao007forever Thanks, that makes sense regarding freed state blocks remaining as unreferenced cached blocks. One follow-up: in upstream Was your Kimi-Linear benchmark configured so that aligned prefill chunks are effectively one |
|
You are right that in prefill we snapshot at each chunk end, but in decode we still snapshot every block_size tokens. |
|
Thanks, that makes sense. I was only thinking about the prefill path and missed the decode-side behavior across block boundaries. The distinction between live footprint and unreferenced-but-cached footprint is helpful. Thanks for clarifying. |
    # (2) Replay boundary. ``get_computed_blocks`` caps hits at
    # ``num_prompt - 1``, so an exact prompt replay lands on the latest
    # fine-aligned boundary. Sparse retention would otherwise skip its
    # state, so keep it explicitly.
    if num_prompt_tokens is not None:
        latest = (num_prompt_tokens - 1) // alignment_tokens * alignment_tokens
        boundary_block = latest // block_size - 1
        if start_block <= boundary_block < end_block:
            mask[boundary_block - start_block] = True

    return mask
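To make the boundary arithmetic in the snippet above concrete, here is a small standalone sketch. The function names `replay_boundary_block` and `mark_replay_boundary` are hypothetical wrappers around the same two expressions, and the values of `alignment_tokens` and `block_size` are illustrative assumptions rather than settings taken from the PR.

```python
from typing import Optional


def replay_boundary_block(
    num_prompt_tokens: int, alignment_tokens: int, block_size: int
) -> int:
    """Index of the block holding the state at the latest aligned boundary
    strictly below num_prompt_tokens (-1 if no such boundary exists)."""
    latest = (num_prompt_tokens - 1) // alignment_tokens * alignment_tokens
    return latest // block_size - 1


def mark_replay_boundary(
    mask: list[bool],
    start_block: int,
    end_block: int,
    num_prompt_tokens: Optional[int],
    alignment_tokens: int,
    block_size: int,
) -> list[bool]:
    """Set the replay-boundary entry to True if it falls inside the mask window."""
    if num_prompt_tokens is not None:
        boundary_block = replay_boundary_block(
            num_prompt_tokens, alignment_tokens, block_size
        )
        if start_block <= boundary_block < end_block:
            mask[boundary_block - start_block] = True
    return mask


# A 1024-token prompt with 128-token alignment: hits are capped at 1023 tokens,
# so the latest usable boundary is 896 tokens, which is the end of block 6.
exact_multiple = replay_boundary_block(1024, 128, 128)
one_past = replay_boundary_block(1025, 128, 128)
short_prompt = replay_boundary_block(100, 128, 128)
marked = mark_replay_boundary([False] * 8, 0, 8, 1024, 128, 128)
```

A prompt shorter than one alignment unit yields index -1, which the range check then ignores.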
Not sure this part can work: for Mamba, we need scheduler-side changes to cache the end of the prompt. Maybe we can raise NotImplementedError when num_prompt_tokens is given?
Good catch to raise it! I think we're safe here though — the mask is purely subtractive. A True never forces a block to be cached, it just declines to skip it; the real gate is blk.is_null in cache_full_blocks:
if blk.is_null or (block_mask is not None and not block_mask[i]):
continue
So the layers stay separate: the scheduler decides where a snapshot lands (which blocks are non-null), and the mask just picks which existing ones to keep. If there's no state at the boundary, it's a null_block and gets skipped regardless of the mask — so we can never fabricate a cache entry over a stateless block. Worst case without scheduler changes is that the branch is inert, not incorrect (and in decode we snapshot every block_size, so it usually does fire).
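The "purely subtractive" argument can be checked with a toy model. This is a sketch under stated assumptions: `Block` and `blocks_to_cache` are hypothetical stand-ins, and only the skip condition is copied from the reply above; the rest of `cache_full_blocks` is not reproduced.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Block:
    """Hypothetical stand-in for a KV-cache block; is_null marks a stateless block."""

    block_id: int
    is_null: bool = False


def blocks_to_cache(
    blocks: list[Block], block_mask: Optional[list[bool]] = None
) -> list[int]:
    """Return the ids of blocks that would be cached under the skip condition."""
    kept = []
    for i, blk in enumerate(blocks):
        # Same condition as quoted above: null blocks are always skipped,
        # and a False mask entry skips an otherwise cacheable block.
        if blk.is_null or (block_mask is not None and not block_mask[i]):
            continue
        kept.append(blk.block_id)
    return kept


blocks = [Block(0), Block(1, is_null=True), Block(2)]
no_mask = blocks_to_cache(blocks)
all_true = blocks_to_cache(blocks, [True, True, True])
thinned = blocks_to_cache(blocks, [False, True, True])
```

Setting the mask entry for the null block to True changes nothing, while a False entry can only remove a block from the cached set.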
…attention (vllm-project#45845) Signed-off-by: Dao Le <daole@inferact.ai> Signed-off-by: Dao Le <Dao007forever@gmail.com> Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…attention (vllm-project#45845) Signed-off-by: Dao Le <daole@inferact.ai> Signed-off-by: Dao Le <Dao007forever@gmail.com> Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Qiang Li <qiang.li2@amd.com>
Purpose
Wire `VLLM_PREFIX_CACHE_RETENTION_INTERVAL` to Mamba/linear-attention KV-cache groups, completing the in-code `# TODO: Support Mamba/linear attention` left by #43447 (which added the mechanism for sliding-window attention only).

Background. Mamba/KDA prefix caching retains a full recurrent-state snapshot once per `block_size`-token boundary. At small attention block sizes (e.g. 128) each snapshot spans several base blocks, and dense per-boundary retention saturates the KV pool (at `block_size` 128 the Mamba snapshots occupy ~80% of the blocks), leaving no uncached headroom, so the allocator is forced to evict live attention prefixes. The prefix-cache hit rate then collapses late in long multi-turn runs, dragging down throughput and tail latency, while larger block sizes (256/512) are unaffected. (This was observed in a test setting that allows a shorter block size.)

Change. `MambaManager.reachable_block_mask` now sparsifies state-snapshot retention the same way `SlidingWindowManager` does: keep one cached state per `retention_interval`-sized segment (plus the latest replay boundary) instead of one per block. A hit resumes from the nearest retained boundary (at most `retention_interval` tokens coarser), costing negligible extra prefill while freeing the intermediate snapshots for reuse. Also:
- `MambaManager.cache_blocks` now tolerates sparse (unhashed) blocks in the cached range.
- `_validate_prefix_cache_retention_interval` now accepts models with a Mamba group.

Default behavior is unchanged: with the interval unset, Mamba caches densely (every boundary), exactly as before.
Why this is not a duplicate
- #43447 added `VLLM_PREFIX_CACHE_RETENTION_INTERVAL` for sliding-window KV cache only, and explicitly left `# TODO: Support Mamba/linear attention`. This PR completes that TODO.
- `reachable_block_mask`.

Test Plan
E2E: Kimi-Linear-48B-A3B-Instruct (`block_size` 512, TP4, multi-turn prefix-on, 50 prompts × 60 turns), comparing default vs `VLLM_PREFIX_CACHE_RETENTION_INTERVAL=2048`.

Test Result
Unit tests (CPU; this change is pure-Python KV-cache scheduling logic):
- `test_prefix_caching.py -k "retention or reachable"`, including the new `test_mamba_reachable_block_mask_sparsifies_retention`: 10 passed.
- `tests/v1/core/ -k "retention or mamba or prefix_cach"`: all change-relevant tests pass in a CPU-only env; the remaining e2e/flash-attn tests in this selection require a full CUDA build + GPU and were not run here.

E2E: validated on Kimi-Linear-48B-A3B-Instruct with the config above; with `VLLM_PREFIX_CACHE_RETENTION_INTERVAL=2048`

AI assistance disclosure: AI assistance (Claude) was used for this change. The human submitter has reviewed every changed line and run the tests above.
🤖 Generated with Claude Code