[Minimax-M3] BF16/FP8 Indexer using MSA#45892
Merged
Merged
Conversation
Contributor
|
This pull request has merge conflicts that must be resolved before it can be |
…Max M3 Add an SM100/Blackwell lightning-indexer impl that computes the per-128-block QK max-scores with fmha_sm100's score-only (OnlyScore) path and selects the top-k blocks with the existing Triton minimax_m3_index_topk kernel, mirroring how the main MSA attention pairs the SM100 attend with Triton. Decode and prefill requests are split manually (decode-first batch) and each side gets its own _fmha_sm100_plan / _fmha_sm100 call. Auto-selected on SM100 when topk_blocks in (4, 8, 16, 32) for both bf16 and fp8 index caches; falls back to the Triton indexer otherwise. The builder declares AttentionCGSupport.NEVER (eager; the attention is broken out of the graph by _run_attention). Extend the fused qknorm+rope+kv-insert kernel to optionally emit fp8 (e4m3) for the index-K cache and index-Q via a direct cast with no scale tensors (RMSNorm outputs are O(1) and scalar scales do not change top-k ordering). Only the index outputs go fp8; q/k/v and q_out stay bf16 and bit-identical to the existing path. MiniMaxM3IndexerCache now accepts fp8 caches and the model allocates index_q in the cache dtype. Tests: test_fmha_sm100_indexer_matches_reference (bf16/fp8 x prefill/decode) and test_msa_indexer_impl_matches_triton (full impl parity vs the Triton indexer through the real metadata builders); fp8 fused-kernel parity is covered in test_fused_minimax_m3_qknorm_rope_kv_insert. AI assistance (Claude Code) was used for this change. Signed-off-by: Yongye Zhu <yongye@inferact.ai> Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
The SM100 (MSA) lightning indexer declared AttentionCGSupport.NEVER and ran eager: its fmha_sm100 score plan allocated fresh buffers every build() and the run allocated a fresh max_score / top-k each call. Make the decode side cudagraph-replay-safe: - Reserve persistent plan buffers in the builder __init__ (sized over a scan of decode sizes [1, max_num_seqs]); build() fills them in place and calls the plan kernel directly with a fixed, batch-size-only num_kv_splits (estimate_num_kv_splits, a uniform-context replica of the planner's auto-split math; tunable via VLLM_M3_INDEXER_CONTEXT_LEN). A positive split count takes the deterministic plan path with no device->host sync. - workspace_o / workspace_lse / cute_workspace are builder-owned dedicated tensors (not the shared global _alloc_workspace_buf cache) so a larger fmha call elsewhere can't realloc and move an address a captured graph baked. - max_score is a 1-D-backed contiguous [H, max_k_tiles, nnz_qo] view (a sliced 3-D buffer is non-contiguous; the kernel assumes contiguous); max_k_tiles pinned for a stable shape. - Top-k output goes to a model-level topk_indices_buffer (DeepSeek-V3.2 pattern), threaded model -> decoder layer -> sparse attention -> indexer; index_topk gains an out= param. - Narrow _run_attention's eager break so the indexer runs in the captured segment and only the sparse attention is eager-broken. Builder reports UNIFORM_BATCH; prefill stays eager. Verified: tests/kernels/attention/test_minimax_m3.py 42/42 pass (incl. impl-level parity with num_kv_splits > 1 over short context); GPQA accuracy matches eager. AI assistance (Claude Code) was used for this change. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Yongye Zhu <yongye@inferact.ai> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
…de plan Follow-ups on the cudagraph-capturable MSA indexer: - Top-k: both decode and prefill now write into the single shared, persistent topk_indices_buffer (decode at [:, :nd], prefill at [:, nd:]) and return views into it -- no fresh per-step top-k allocations. - Build the decode plan + flat page table entirely with torch on-GPU: drop numpy and CpuGpuBuffer; segment offsets/lengths are computed via torch.cumsum into the persistent int32 buffers, and the request-major page table is scattered into the buffer via the on-GPU page indptr (the run bounds reads by indptr, so the full buffer is passed and no host page count is needed). - No GPU->CPU sync on the decode path: scalars come from host ints (num_decode_tokens // num_decodes), and seq_lens.cpu() is confined to the eager prefill branch. The impl forward (fmha OnlyScore + Triton top-k) was already sync-free. test_msa_indexer_impl_matches_triton now also asserts both outputs are views into the persistent buffer. 42/42 in test_minimax_m3.py pass. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Yongye Zhu <yongye@inferact.ai> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
…ffer Unify the Triton indexer impl onto the same persistent top-k buffer as the MSA impl, and thread the buffer through the AMD model too. - minimax_m3_index_decode gains an out= param (writes out[:, :total_q]); the merge kernel already writes via strides, so a buffer view works. - MiniMaxM3IndexerTritonImpl.forward writes decode ([:, :nd]) and prefill ([:, nd:]) into the shared topk_indices_buffer and returns views into it (no fresh per-step top-k allocations), matching the MSA impl. - amd/model.py: allocate the model-level topk_indices_buffer and thread it model -> decoder layer -> sparse attention -> indexer (mirrors nvidia). AMD keeps its eager break, so this is purely allocation reuse there. test_msa_indexer_impl_matches_triton now gives each impl its own buffer and asserts both decode/prefill outputs are views into it. 42/42 pass. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Yongye Zhu <yongye@inferact.ai> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Emit an info_once line naming the chosen indexer impl (MSA fmha_sm100 vs Triton) and the deciding inputs (topk_blocks, indexer_kv_dtype, sm100), so the active kernel path is visible at startup. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Yongye Zhu <yongye@inferact.ai> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Fixes a cudagraph-replay AssertionError on bs=1 decode: with the indexer in the captured segment and only the attend eager-broken, the (decode_topk, prefill_topk) Python tuple handed across the eager break was frozen at capture and replayed stale against the current step's metadata. The attend now reads its top-k directly from the shared persistent ``topk_indices_buffer`` (decode at [:, :nd], prefill at [:, nd:num_tokens]), sliced by the current step's metadata -- nothing crosses the eager break as a Python value. The indexer (captured) writes the buffer; the attend (eager) reads it; the breakable-cudagraph segment ordering guarantees the write precedes the read on replay. - MiniMaxM3SparseImpl.forward (+ Triton and MSA subclasses): drop the topk_idx arg, read layer.topk_indices_buffer. - nvidia model: store the buffer on the sparse-attention layer; narrow the eager break so the indexer is captured and only the attend (_run_sparse_attn) is eager-broken. - amd model: store the buffer on the layer; indexer + attend stay in one eager break (no capture), but the attend reads the buffer too. 42/42 in test_minimax_m3.py pass. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Yongye Zhu <yongye@inferact.ai> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
estimate_num_kv_splits was a static replica of fmha_sm100's auto cost heuristic, which over-splits the decode score: an fp8 e4m3 cudagraph-timed split sweep shows the fixed-split (cudagraph-safe) path hits a sharp split-KV cliff above ~16 splits with >1 work row -- up to ~9x slower (e.g. reqs=2 ctx=65536: s64=153us vs s16=15us). The previous formula picked 52-64 splits there. Replace it with a measured, static fit (no kernel calls, so it stays cudagraph-stable and sync-free): - bs==1: min(64, kv_tiles) -- no cliff, fill with many splits. - bs>=2: min(num_sms // work_rows, cap, kv_tiles), where the cliff-safe cap is 16 through the ~60-100k context target and 32 for longer context (the cliff ceiling rises with context). Across an fp8 cudagraph sweep this is within ~3% of the best fixed split for every (batch, context) measured. The planner's own auto path is faster for bs>=2 (adaptive split distribution) but is catastrophic for bs==1 (up to ~11x) and needs a device->host sync, so it is unsuitable for the captured decode path. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Yongye Zhu <yongye@inferact.ai> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Route the MSA (SM100) indexer decode through the Triton fused minimax_m3_index_decode kernel (the same kernel the Triton indexer impl uses) instead of fmha_sm100's OnlyScore path. For q_len==1 decode the Triton kernel is a purpose-built vector x matrix score with a 256-way split-K and fused split-K top-k, which beats fmha's OnlyScore (wasted MMA tiles on a single query, 64-split cap) by ~1.1-3.7x in benchmarks. It is cudagraph-safe by construction (shape-constant split grids) and writes the shared topk_indices_buffer via out=. Prefill keeps fmha OnlyScore + the single-pass Triton top-k, where fmha is ~3-5x faster for the wide score. This drops the persistent fmha decode plan buffers, the num_kv_splits estimate, and the now-unused VLLM_M3_INDEXER_CONTEXT_LEN env. Add fp8 (e4m3) index KV cache support to the Triton decode score kernel: make the QK MMA accumulate in fp32 (out_dtype=tl.float32) so the per-block max score is exact for the e4m3 cache. Top-k is invariant to positive scalar scaling, so fp8 needs no scale tuning. Tests: tests/kernels/attention/test_minimax_m3.py (44 passed), including a new test_decode_index_topk_fp8 validating fp8 decode top-k vs a dequantized-fp32 reference. mypy + ruff clean. AI assistance was used. Co-authored-by: Claude <noreply@anthropic.com> Signed-off-by: Yongye Zhu <yongye@inferact.ai> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
The MSA builder/impl were written against the pre-vllm-project#45743 indexer decode API. After rebasing onto main, MiniMaxM3IndexerDecodeMetadata gained a required max_decode_query_len field and minimax_m3_index_decode dropped its sm_scale parameter, so the MSA path crashed at engine init during cudagraph profiling (missing max_decode_query_len). - Pass max_decode_query_len when building decode metadata. - Drop the stale sm_scale arg and pass max_decode_query_len in the decode kernel call, matching the Triton builder/impl. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
1b2e995 to
0b70ba3
Compare
PR vllm-project#93 integrated the MSA indexer but did not update the cmake install rules / setup.py package_data to vendor the new files. Add install rules for csrc/, cutlass/include, cutlass/tools/util/include and the matching package_data globs. (cherry picked from commit ea59ac42ab8e4a68f3bc362034de5199b59246df) Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Fixes mypy call-arg errors: two fp8 call sites pass sm_scale (by keyword and positionally) that the reference signature lacked. Add sm_scale (default 1.0) and apply it to the score; top-k selection is invariant to a positive scalar, so existing callers are unchanged. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
4 tasks
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
e581190 to
0c5311e
Compare
gau-nernst
approved these changes
Jun 23, 2026
khluu
pushed a commit
that referenced
this pull request
Jun 24, 2026
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Co-authored-by: Thien Tran <gau.nernst@yahoo.com.sg> (cherry picked from commit 6691f08)
nkzhenhua
pushed a commit
to nkzhenhua/vllm
that referenced
this pull request
Jun 24, 2026
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Co-authored-by: Thien Tran <gau.nernst@yahoo.com.sg>
qli88
pushed a commit
to qli88/vllm
that referenced
this pull request
Jun 26, 2026
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Co-authored-by: Thien Tran <gau.nernst@yahoo.com.sg> Signed-off-by: Qiang Li <qiang.li2@amd.com>
wincent8
pushed a commit
to wincent8/vllm
that referenced
this pull request
Jun 29, 2026
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Co-authored-by: Thien Tran <gau.nernst@yahoo.com.sg>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Purpose
Integrating MSA indexer prefill kernel. Decode remain triton for performance reason
Test Plan
Minimax M3 gsm8k, AIME25, GPQA in TP4
Test Result
with FP8 attention and FP8 Indexer cache.
gsm8k: 92
GPQA-D: 93.6
AIME25: 92.5
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.