[Attention] add triton diff-kv backend for mimo #41797
Documentation preview: https://vllm--41797.org.readthedocs.build/en/41797/
Code Review
This pull request introduces the TRITON_ATTN_DIFFKV backend and a corresponding Triton kernel to support models with differing K and V head dimensions, such as MiMo-V2. It also updates the model executor to dynamically select between FlashAttention and Triton DiffKV backends based on device compatibility. Review feedback suggests including float16 in the supported KV cache data types and explicitly overriding supports_attn_type to restrict the backend to decoder-only attention, aligning with the kernel's implementation.
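A minimal sketch of the two ideas in this summary: a backend that restricts itself to decoder-only attention, and a selector that falls back from FlashAttention to the Triton diff-KV backend when the device cannot run FlashAttention. Every class, function, and constant here (`BaseBackend`, `FlashAttnBackend`, `TritonDiffKVBackend`, `pick_backend`, `Device`, the string attention types, and the `"FLASH_ATTN"` name) is hypothetical and only illustrates the control flow; it is not vLLM's actual API. Only the name `TRITON_ATTN_DIFFKV` comes from the PR.

```python
from dataclasses import dataclass

# Hypothetical attention-type labels; vLLM's real enum may differ.
DECODER = "decoder"
ENCODER = "encoder"


@dataclass(frozen=True)
class Device:
    """Toy device description used only in this sketch."""
    name: str
    supports_flash_attn: bool


class BaseBackend:
    name = "BASE"

    @classmethod
    def is_compatible(cls, device: Device) -> bool:
        return True

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        # Assumed permissive default that subclasses may narrow.
        return True


class FlashAttnBackend(BaseBackend):
    name = "FLASH_ATTN"  # hypothetical label

    @classmethod
    def is_compatible(cls, device: Device) -> bool:
        return device.supports_flash_attn


class TritonDiffKVBackend(BaseBackend):
    name = "TRITON_ATTN_DIFFKV"

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        # Explicit override: the kernel is described as decoder-only.
        return attn_type == DECODER


def pick_backend(device: Device, attn_type: str = DECODER) -> type[BaseBackend]:
    """Return the first backend that fits both the device and the attention type."""
    for backend in (FlashAttnBackend, TritonDiffKVBackend):
        if backend.is_compatible(device) and backend.supports_attn_type(attn_type):
            return backend
    raise ValueError(f"no diff-KV backend for {attn_type!r} on {device.name}")
```

With the override in place, an unsupported attention type is rejected at selection time instead of reaching a kernel that was not written for it.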
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
    "auto",
    "bfloat16",
]
The supported_kv_cache_dtypes list is missing float16, although the comment above explicitly states that fp16 is supported. This will cause validation errors if a user explicitly sets kv_cache_dtype="float16" for a model using this backend.
Suggested change:
- supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
-     "auto",
-     "bfloat16",
- ]
+ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
+     "auto",
+     "float16",
+     "bfloat16",
+ ]
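To illustrate why the missing entry matters, here is a small sketch of a membership-style capability check. `validate_kv_cache_dtype` is a hypothetical stand-in for whatever validation vLLM performs; only the two dtype lists are taken from the suggestion above.

```python
# Lists copied from the suggestion; the validator itself is an assumption.
SUPPORTED_BEFORE = ["auto", "bfloat16"]
SUPPORTED_AFTER = ["auto", "float16", "bfloat16"]


def validate_kv_cache_dtype(requested: str, supported: list[str]) -> None:
    """Raise if the requested KV cache dtype is not advertised by the backend."""
    if requested not in supported:
        raise ValueError(
            f"kv_cache_dtype={requested!r} not supported; choose from {supported}"
        )


def is_accepted(requested: str, supported: list[str]) -> bool:
    """Convenience wrapper that turns the validation error into a boolean."""
    try:
        validate_kv_cache_dtype(requested, supported)
    except ValueError:
        return False
    return True
```

Under this assumed check, an explicit `float16` request is rejected with the original list and accepted with the suggested one, while `auto` passes in both cases.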
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: be535f6d0e
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
    "auto",
    "bfloat16",
]
Don't advertise explicit bfloat16 KV cache
When MiMo uses TRITON_ATTN_DIFFKV with --kv-cache-dtype bfloat16, this entry lets the backend pass capability checks, but do_kv_cache_update() immediately calls triton_reshape_and_cache_flash_diffkv(), whose assertion only accepts "auto" or quantized cache dtype strings; quantized modes are rejected by this impl as well. In that explicit-bfloat16 configuration the first cache update will fail at runtime, so this backend should either only advertise "auto" or update the cache helper to accept explicit bfloat16.
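The mismatch described here can be sketched as a consistency check between what the backend advertises and what the cache-update helper accepts. `cache_update_accepts` is a hypothetical model of the helper's assertion, and treating quantized dtype strings as those starting with `"fp8"` is an assumption made for illustration.

```python
# Advertised list as shown in the reviewed snippet.
ADVERTISED = ["auto", "bfloat16"]


def cache_update_accepts(kv_cache_dtype: str) -> bool:
    """Model of the helper's assertion: "auto" or a quantized dtype string.

    Assumption: quantized dtype strings are the ones prefixed with "fp8".
    """
    return kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8")


def advertised_but_rejected(advertised: list[str]) -> list[str]:
    """Dtypes that pass the capability check but would fail at cache update."""
    return [dtype for dtype in advertised if not cache_update_accepts(dtype)]
```

A check of this kind in a unit test would flag `bfloat16` in the current list, which matches the two remedies the review offers: advertise only `auto`, or widen the helper.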
@ZJY0516, testing this PR on our 8× DGX Spark / TP=8 stack tonight. Wanted to flag a couple of practical observations and a likely blocker for other people trying it.
What works (confirmed before runtime test)
Practical blocker for users running NVFP4 quants
This isn't your PR's fault, but worth flagging here for anyone reading and trying the same thing:
Workaround we used: copied For
A small suggestion
Could be worth adding a
On deck
Will report back with full launch + 30-prompt sweep + accuracy spot-check shortly. Pairing this PR's branch with NCCL 2.30.4 (per @jasl's suggestion in #40969 that resolved the Ray Compiled-DAG wedge for our 8× Spark setup), so MiMo on this stack gets a clean baseline. cc @jasl, @shadowlilac-oss, @haosdent, sharing context.
Update from our 8× DGX Spark / TP=8 test of
Boot trace + error
Engine got past auto_map, past arch resolution (
Shape math
MiMo-V2.5 config:
At TP=8,
The lukealonso NVFP4 quant was apparently prepared assuming approach (2), fractional sharding. vLLM at TP=8 expects approach (1), replication. Off by 160 (the missing 0.5 KV head's worth of K+V dims).
Practical workarounds for users
Note: this isn't your PR's issue (#41797 is purely the FA3 → Triton diff-KV backend fallback, which appears to dispatch correctly based on the boot logs). It's a quant-vs-vLLM-loader layout mismatch for the specific MiMo arch + TP combination.
Net impact
For us: defer MiMo-V2.5 deployment on this 8-node cluster until either a TP=8-compatible NVFP4 quant exists, or until we're willing to run TP=4 with 4 idle nodes. Will keep tracking. Posting cross-reference to #41519 since this is likely something other DGX Spark / TP=8 users will hit.
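The shape arithmetic in this comment can be sketched as follows. A per-head K+V width of 320 follows from the comment's "off by 160" for half a KV head; the split into 192 + 128 and the count of 4 KV heads are assumptions chosen to match that figure, since the original config values did not survive in this thread. `kv_width_per_rank` is a hypothetical helper, not a vLLM function.

```python
def kv_width_per_rank(
    num_kv_heads: int,
    tp_size: int,
    k_head_dim: int,
    v_head_dim: int,
    replicate: bool,
) -> int:
    """Combined K+V projection width held by one tensor-parallel rank."""
    per_head = k_head_dim + v_head_dim
    if num_kv_heads >= tp_size:
        # Enough heads to give every rank whole heads; no special handling.
        return (num_kv_heads // tp_size) * per_head
    if replicate:
        # Approach (1): each rank keeps one full copy of a KV head.
        return per_head
    # Approach (2): a head's K and V dims are split across ranks.
    return per_head * num_kv_heads // tp_size


# Assumed values: 4 KV heads, K head_dim 192, V head_dim 128, TP=8.
replicated = kv_width_per_rank(4, 8, 192, 128, replicate=True)
fractional = kv_width_per_rank(4, 8, 192, 128, replicate=False)
mismatch = replicated - fractional
```

Under these assumptions the loader would expect 320 columns per rank while a fractionally sharded checkpoint provides 160, which reproduces the 160 gap reported above.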
Quick follow-up for anyone tracking this PR: thanks @ZJY0516, the DiffKV Triton backend in this PR cleanly resolves the first two blockers I flagged earlier (model_type discovery + attention-backend dispatch for unequal Q/V head_dim). The third blocker (degenerate logits at inference, Filed separately at #42803 with empirical evidence + a proposed patch. Verified working end-to-end on the same 8× DGX Spark / TP=8 setup with
mgoin left a comment
Looks reasonable to me, although a few items for cleanup and a unit test would be good to get in. cc @LucasWilkinson @MatthewBonanni as another attention backend
Hi @ZJY0516, the pre-commit checks have failed. Please run:
uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files
Then, commit the changes and push to your branch. For future commits,
mgoin left a comment
Okay since this is pretty separable and thus easy to delete if there are issues, I'm good with merging. Thanks for keeping it clean!
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: divineearthly <divineearthly@gmail.com>
Purpose
Fix #41519
Test Plan
Test Result
FA
triton