[Minimax-M3] BF16/FP8 Indexer using MSA#45892

Merged
WoosukKwon merged 14 commits into vllm-project:main from zyongye:feat/fp8_indexer on Jun 23, 2026

@zyongye zyongye commented Jun 17, 2026

Member

Purpose

Integrate the MSA indexer prefill kernel. Decode remains on the Triton kernel for performance reasons.

Test Plan

Run Minimax M3 on gsm8k, AIME25, and GPQA with TP4.

Test Result

With FP8 attention and FP8 indexer cache:
gsm8k: 92
GPQA-D: 93.6
AIME25: 92.5

mergify Bot commented Jun 17, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @zyongye.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Jun 17, 2026
zyongye and others added 9 commits June 17, 2026 05:32
…Max M3

Add an SM100/Blackwell lightning-indexer impl that computes the per-128-block
QK max-scores with fmha_sm100's score-only (OnlyScore) path and selects the
top-k blocks with the existing Triton minimax_m3_index_topk kernel, mirroring
how the main MSA attention pairs the SM100 attend with Triton. Decode and
prefill requests are split manually (decode-first batch) and each side gets its
own _fmha_sm100_plan / _fmha_sm100 call. Auto-selected on SM100 when
topk_blocks in (4, 8, 16, 32) for both bf16 and fp8 index caches; falls back to
the Triton indexer otherwise. The builder declares AttentionCGSupport.NEVER
(eager; the attention is broken out of the graph by _run_attention).
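
The score-then-select flow described above can be sketched in plain Python. This is only an illustration, not the fmha_sm100 or Triton kernel: `block_max_scores` and `topk_blocks` are hypothetical helper names, a single query vector and a single head are assumed, and the toy block size is 2 rather than the 128 used in the PR.

```python
def block_max_scores(q, keys, block_size):
    """Max q.k score over each block of `block_size` consecutive keys."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    return [max(scores[i:i + block_size])
            for i in range(0, len(scores), block_size)]


def topk_blocks(block_scores, k):
    """Indices of the k highest-scoring blocks, best first (ties: lower index)."""
    order = sorted(range(len(block_scores)),
                   key=lambda i: (-block_scores[i], i))
    return order[:k]


# Toy case: 6 keys in blocks of 2 (assumed for brevity; the PR uses 128).
q = [1.0, 0.0]
keys = [[0.1, 5.0], [0.3, 0.0], [0.9, 0.0],
        [0.2, 1.0], [0.5, 0.0], [0.4, 2.0]]
block_scores = block_max_scores(q, keys, block_size=2)
selected = topk_blocks(block_scores, k=2)
```

Only the per-block maxima are kept, so the selection step works on one value per block rather than one value per key.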

Extend the fused qknorm+rope+kv-insert kernel to optionally emit fp8 (e4m3) for
the index-K cache and index-Q via a direct cast with no scale tensors (RMSNorm
outputs are O(1) and scalar scales do not change top-k ordering). Only the index
outputs go fp8; q/k/v and q_out stay bf16 and bit-identical to the existing
path. MiniMaxM3IndexerCache now accepts fp8 caches and the model allocates
index_q in the cache dtype.
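
The claim that a scalar scale does not change top-k ordering can be checked with a small sketch. `topk_indices` is a hypothetical helper, and the example covers only the scaling argument; it does not model e4m3 rounding.

```python
def topk_indices(scores, k):
    """Indices of the k largest scores, best first (ties: lower index)."""
    order = sorted(range(len(scores)), key=lambda i: (-scores[i], i))
    return order[:k]


scores = [0.25, -1.5, 3.0, 0.75, 2.0]
base = topk_indices(scores, 3)

# A positive scalar scale preserves the ordering, so the selection is unchanged.
scaled = topk_indices([s * 0.125 for s in scores], 3)

# A negative scale reverses the ordering, so the invariance needs scale > 0.
flipped = topk_indices([s * -1.0 for s in scores], 3)
```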

Tests: test_fmha_sm100_indexer_matches_reference (bf16/fp8 x prefill/decode) and
test_msa_indexer_impl_matches_triton (full impl parity vs the Triton indexer
through the real metadata builders); fp8 fused-kernel parity is covered in
test_fused_minimax_m3_qknorm_rope_kv_insert.

AI assistance (Claude Code) was used for this change.

Signed-off-by: Yongye Zhu <yongye@inferact.ai>
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
The SM100 (MSA) lightning indexer declared AttentionCGSupport.NEVER and ran
eager: its fmha_sm100 score plan allocated fresh buffers every build() and the
run allocated a fresh max_score / top-k each call.

Make the decode side cudagraph-replay-safe:
- Reserve persistent plan buffers in the builder __init__ (sized over a scan of
  decode sizes [1, max_num_seqs]); build() fills them in place and calls the plan
  kernel directly with a fixed, batch-size-only num_kv_splits
  (estimate_num_kv_splits, a uniform-context replica of the planner's auto-split
  math; tunable via VLLM_M3_INDEXER_CONTEXT_LEN). A positive split count takes the
  deterministic plan path with no device->host sync.
- workspace_o / workspace_lse / cute_workspace are builder-owned dedicated tensors
  (not the shared global _alloc_workspace_buf cache) so a larger fmha call
  elsewhere can't realloc and move an address a captured graph baked.
- max_score is a 1-D-backed contiguous [H, max_k_tiles, nnz_qo] view (a sliced 3-D
  buffer is non-contiguous; the kernel assumes contiguous); max_k_tiles pinned for
  a stable shape.
- Top-k output goes to a model-level topk_indices_buffer (DeepSeek-V3.2 pattern),
  threaded model -> decoder layer -> sparse attention -> indexer; index_topk gains
  an out= param.
- Narrow _run_attention's eager break so the indexer runs in the captured segment
  and only the sparse attention is eager-broken. Builder reports UNIFORM_BATCH;
  prefill stays eager.
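
The `out=` pattern from the list above can be sketched with the standard library. This is a toy stand-in, not the vLLM `index_topk`: the function body, the buffer size, and the use of `array` in place of a device tensor are all assumptions made for illustration.

```python
from array import array


def index_topk(scores, k, out=None):
    """Write the indices of the k largest scores into `out`.

    Without `out`, a fresh buffer is allocated on every call, which is the
    behavior a captured graph cannot tolerate.
    """
    if out is None:
        out = array("i", [0] * k)
    order = sorted(range(len(scores)), key=lambda i: (-scores[i], i))
    for slot, idx in enumerate(order[:k]):
        out[slot] = idx  # in-place write; the buffer is never reallocated
    return out


# Allocated once (in this sketch, standing in for a model-level buffer).
topk_buffer = array("i", [-1] * 4)
address_before = topk_buffer.buffer_info()[0]

first = index_topk([0.1, 0.9, 0.4, 0.7, 0.2], k=4, out=topk_buffer)
second = index_topk([0.8, 0.1, 0.6, 0.3, 0.5], k=4, out=topk_buffer)
address_after = topk_buffer.buffer_info()[0]
```

Both calls return the same object at the same address, which is the property a replayed graph relies on.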

Verified: tests/kernels/attention/test_minimax_m3.py 42/42 pass (incl. impl-level
parity with num_kv_splits > 1 over short context); GPQA accuracy matches eager.

AI assistance (Claude Code) was used for this change.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Yongye Zhu <yongye@inferact.ai>

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
…de plan

Follow-ups on the cudagraph-capturable MSA indexer:

- Top-k: both decode and prefill now write into the single shared, persistent
  topk_indices_buffer (decode at [:, :nd], prefill at [:, nd:]) and return views
  into it -- no fresh per-step top-k allocations.
- Build the decode plan + flat page table entirely with torch on-GPU: drop numpy
  and CpuGpuBuffer; segment offsets/lengths are computed via torch.cumsum into the
  persistent int32 buffers, and the request-major page table is scattered into the
  buffer via the on-GPU page indptr (the run bounds reads by indptr, so the full
  buffer is passed and no host page count is needed).
- No GPU->CPU sync on the decode path: scalars come from host ints
  (num_decode_tokens // num_decodes), and seq_lens.cpu() is confined to the eager
  prefill branch. The impl forward (fmha OnlyScore + Triton top-k) was already
  sync-free.
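
The cumsum-based construction of the flat page table can be sketched as follows. This uses `itertools.accumulate` and Python lists in place of `torch.cumsum` and persistent int32 device buffers; `build_flat_page_table` and the page numbers are hypothetical.

```python
from itertools import accumulate


def build_flat_page_table(pages_per_request, buffer):
    """Fill `buffer` with a request-major flat page table; return its indptr."""
    counts = [len(pages) for pages in pages_per_request]
    indptr = [0] + list(accumulate(counts))  # prefix sum of page counts
    for req, pages in enumerate(pages_per_request):
        # Scatter this request's pages into its segment of the buffer.
        buffer[indptr[req]:indptr[req + 1]] = pages
    return indptr


page_buffer = [-1] * 8  # preallocated; larger than needed on purpose
indptr = build_flat_page_table([[7, 2], [5], [9, 4, 1]], page_buffer)

# Readers bound their accesses by indptr, so the unused tail is never read.
pages_of_req2 = page_buffer[indptr[2]:indptr[3]]
```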

test_msa_indexer_impl_matches_triton now also asserts both outputs are views into
the persistent buffer. 42/42 in test_minimax_m3.py pass.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Yongye Zhu <yongye@inferact.ai>

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
…ffer

Unify the Triton indexer impl onto the same persistent top-k buffer as the MSA
impl, and thread the buffer through the AMD model too.

- minimax_m3_index_decode gains an out= param (writes out[:, :total_q]); the merge
  kernel already writes via strides, so a buffer view works.
- MiniMaxM3IndexerTritonImpl.forward writes decode ([:, :nd]) and prefill ([:, nd:])
  into the shared topk_indices_buffer and returns views into it (no fresh per-step
  top-k allocations), matching the MSA impl.
- amd/model.py: allocate the model-level topk_indices_buffer and thread it
  model -> decoder layer -> sparse attention -> indexer (mirrors nvidia). AMD keeps
  its eager break, so this is purely allocation reuse there.

test_msa_indexer_impl_matches_triton now gives each impl its own buffer and asserts
both decode/prefill outputs are views into it. 42/42 pass.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Yongye Zhu <yongye@inferact.ai>

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Emit an info_once line naming the chosen indexer impl (MSA fmha_sm100 vs Triton)
and the deciding inputs (topk_blocks, indexer_kv_dtype, sm100), so the active
kernel path is visible at startup.
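
A minimal sketch of the selection rule and the once-only log line, combining the conditions stated in the earlier commit (SM100, `topk_blocks` in (4, 8, 16, 32), bf16 or fp8 index cache). The function names, the dtype strings, and the returned labels are hypothetical, and `functools.lru_cache` stands in for vLLM's `info_once`.

```python
import logging
from functools import lru_cache

logger = logging.getLogger("indexer_select")


@lru_cache(maxsize=None)
def _log_once(message):
    """Emit each distinct message at most once per process."""
    logger.info(message)


def select_indexer_impl(topk_blocks, indexer_kv_dtype, is_sm100):
    """Pick the indexer implementation and log the deciding inputs."""
    use_msa = (
        is_sm100
        and topk_blocks in (4, 8, 16, 32)
        and indexer_kv_dtype in ("bf16", "fp8")  # dtype labels are assumed
    )
    impl = "msa_fmha_sm100" if use_msa else "triton"
    _log_once(
        f"indexer impl={impl} (topk_blocks={topk_blocks}, "
        f"indexer_kv_dtype={indexer_kv_dtype}, sm100={is_sm100})"
    )
    return impl
```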

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Yongye Zhu <yongye@inferact.ai>

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Fixes a cudagraph-replay AssertionError on bs=1 decode: with the indexer in the
captured segment and only the attend eager-broken, the (decode_topk, prefill_topk)
Python tuple handed across the eager break was frozen at capture and replayed
stale against the current step's metadata.

The attend now reads its top-k directly from the shared persistent
``topk_indices_buffer`` (decode at [:, :nd], prefill at [:, nd:num_tokens]),
sliced by the current step's metadata -- nothing crosses the eager break as a
Python value. The indexer (captured) writes the buffer; the attend (eager) reads
it; the breakable-cudagraph segment ordering guarantees the write precedes the
read on replay.
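
A toy model of the failure mode and the fix, not the actual vLLM code path: the split between decode and prefill is represented by `slice` objects, the buffer holds one entry per token for brevity, and `StepMetadata` and `topk_slices` are hypothetical names. A split computed once at capture goes stale, while a split recomputed from the current step's metadata does not.

```python
from dataclasses import dataclass


@dataclass
class StepMetadata:
    num_decodes: int  # decode tokens come first in the batch
    num_tokens: int   # decode + prefill tokens


def topk_slices(meta):
    """Ranges of the shared buffer holding (decode, prefill) top-k."""
    return (slice(0, meta.num_decodes),
            slice(meta.num_decodes, meta.num_tokens))


topk_buffer = [0] * 8  # persistent buffer

# Capture time: the split is computed once and kept as a Python value.
capture_meta = StepMetadata(num_decodes=1, num_tokens=1)
frozen = topk_slices(capture_meta)

# Replay: the indexer rewrites the buffer for a different batch.
replay_meta = StepMetadata(num_decodes=2, num_tokens=5)
topk_buffer[:5] = [11, 12, 21, 22, 23]

# Stale: reuses the split frozen at capture.
stale_decode = topk_buffer[frozen[0]]

# Fixed: slice the buffer with the current step's metadata.
fresh_decode, fresh_prefill = (
    topk_buffer[s] for s in topk_slices(replay_meta)
)
```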

- MiniMaxM3SparseImpl.forward (+ Triton and MSA subclasses): drop the topk_idx
  arg, read layer.topk_indices_buffer.
- nvidia model: store the buffer on the sparse-attention layer; narrow the eager
  break so the indexer is captured and only the attend (_run_sparse_attn) is
  eager-broken.
- amd model: store the buffer on the layer; indexer + attend stay in one eager
  break (no capture), but the attend reads the buffer too.

42/42 in test_minimax_m3.py pass.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Yongye Zhu <yongye@inferact.ai>

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
estimate_num_kv_splits was a static replica of fmha_sm100's auto cost heuristic,
which over-splits the decode score: an fp8 e4m3 cudagraph-timed split sweep shows
the fixed-split (cudagraph-safe) path hits a sharp split-KV cliff above ~16
splits with >1 work row -- up to ~9x slower (e.g. reqs=2 ctx=65536: s64=153us vs
s16=15us). The previous formula picked 52-64 splits there.

Replace it with a measured, static fit (no kernel calls, so it stays
cudagraph-stable and sync-free):
- bs==1: min(64, kv_tiles) -- no cliff, fill with many splits.
- bs>=2: min(num_sms // work_rows, cap, kv_tiles), where the cliff-safe cap is
  16 through the ~60-100k context target and 32 for longer context (the cliff
  ceiling rises with context).

Across an fp8 cudagraph sweep this is within ~3% of the best fixed split for
every (batch, context) measured. The planner's own auto path is faster for
bs>=2 (adaptive split distribution) but is catastrophic for bs==1 (up to ~11x)
and needs a device->host sync, so it is unsuitable for the captured decode path.
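
The static fit described above can be written as a small pure function. This is a sketch under stated assumptions rather than the merged code: `work_rows` and `long_context` are taken as inputs because the message does not say how they are derived, the SM count in the usage lines is hypothetical, and the clamp to at least 1 is added here for safety.

```python
def estimate_num_kv_splits(batch_size, kv_tiles, num_sms, work_rows,
                           long_context):
    """Static split-KV estimate; no kernel calls and no device sync."""
    if batch_size == 1:
        # bs == 1: no cliff was observed, so fill with many splits.
        return max(1, min(64, kv_tiles))
    # bs >= 2: stay under the cliff; the cap rises for longer context.
    cap = 32 if long_context else 16
    return max(1, min(num_sms // work_rows, cap, kv_tiles))


# Hypothetical device with 128 SMs.
single = estimate_num_kv_splits(1, kv_tiles=512, num_sms=128,
                                work_rows=1, long_context=False)
pair = estimate_num_kv_splits(2, kv_tiles=512, num_sms=128,
                              work_rows=2, long_context=False)
pair_long = estimate_num_kv_splits(2, kv_tiles=512, num_sms=128,
                                   work_rows=2, long_context=True)
```

Because the result depends only on host-side integers, it can be evaluated in `build()` without touching the device.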

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Yongye Zhu <yongye@inferact.ai>

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Route the MSA (SM100) indexer decode through the Triton fused
minimax_m3_index_decode kernel (the same kernel the Triton indexer impl
uses) instead of fmha_sm100's OnlyScore path. For q_len==1 decode the
Triton kernel is a purpose-built vector x matrix score with a 256-way
split-K and fused split-K top-k, which beats fmha's OnlyScore (wasted MMA
tiles on a single query, 64-split cap) by ~1.1-3.7x in benchmarks. It is
cudagraph-safe by construction (shape-constant split grids) and writes the
shared topk_indices_buffer via out=. Prefill keeps fmha OnlyScore + the
single-pass Triton top-k, where fmha is ~3-5x faster for the wide score.
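
The split-K top-k idea can be sketched on the host: each split produces its local top-k candidates, and a merge step selects the global top-k from them. `splitk_topk` is a hypothetical name, the split count in the usage lines is tiny compared with the 256-way split in the kernel, and nothing here reflects how the Triton kernel lays out its work.

```python
import heapq


def splitk_topk(block_scores, k, num_splits):
    """Top-k block indices via per-split top-k followed by a merge."""
    n = len(block_scores)
    chunk = max(1, -(-n // num_splits))  # ceil division, at least 1
    rank = lambda t: (t[0], -t[1])       # higher score, then lower index

    candidates = []
    for start in range(0, n, chunk):
        part = [(block_scores[i], i)
                for i in range(start, min(start + chunk, n))]
        # Each split keeps at most k candidates.
        candidates.extend(heapq.nlargest(k, part, key=rank))

    merged = heapq.nlargest(k, candidates, key=rank)
    return [i for _, i in merged]


scores = [0.2, 0.9, 0.1, 0.7, 0.8, 0.3, 0.6, 0.5]
split_result = splitk_topk(scores, k=2, num_splits=2)
single_pass = splitk_topk(scores, k=2, num_splits=1)
```

The merge is exact because any element of the global top-k is also in the top-k of its own split.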

This drops the persistent fmha decode plan buffers, the num_kv_splits
estimate, and the now-unused VLLM_M3_INDEXER_CONTEXT_LEN env.

Add fp8 (e4m3) index KV cache support to the Triton decode score kernel:
make the QK MMA accumulate in fp32 (out_dtype=tl.float32) so the per-block
max score is exact for the e4m3 cache. Top-k is invariant to positive
scalar scaling, so fp8 needs no scale tuning.

Tests: tests/kernels/attention/test_minimax_m3.py (44 passed), including a
new test_decode_index_topk_fp8 validating fp8 decode top-k vs a
dequantized-fp32 reference. mypy + ruff clean. AI assistance was used.

Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Yongye Zhu <yongye@inferact.ai>

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
The MSA builder/impl were written against the pre-vllm-project#45743 indexer decode
API. After rebasing onto main, MiniMaxM3IndexerDecodeMetadata gained a
required max_decode_query_len field and minimax_m3_index_decode dropped
its sm_scale parameter, so the MSA path crashed at engine init during
cudagraph profiling (missing max_decode_query_len).

- Pass max_decode_query_len when building decode metadata.
- Drop the stale sm_scale arg and pass max_decode_query_len in the decode
  kernel call, matching the Triton builder/impl.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
@zyongye zyongye force-pushed the feat/fp8_indexer branch from 1b2e995 to 0b70ba3 Compare June 17, 2026 05:57
@mergify mergify Bot removed the needs-rebase label Jun 17, 2026
gau-nernst and others added 3 commits June 17, 2026 16:16
PR vllm-project#93 integrated the MSA indexer but did not update the cmake install
rules / setup.py package_data to vendor the new files. Add install rules
for csrc/, cutlass/include, cutlass/tools/util/include and the matching
package_data globs.

(cherry picked from commit ea59ac42ab8e4a68f3bc362034de5199b59246df)

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Fixes mypy call-arg errors: two fp8 call sites pass sm_scale (by
keyword and positionally) that the reference signature lacked. Add
sm_scale (default 1.0) and apply it to the score; top-k selection is
invariant to a positive scalar, so existing callers are unchanged.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
@zyongye zyongye added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 18, 2026
@mergify mergify Bot added the ci/build label Jun 18, 2026
@zyongye zyongye mentioned this pull request Jun 18, 2026
8 tasks
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
@zyongye zyongye force-pushed the feat/fp8_indexer branch from e581190 to 0c5311e Compare June 22, 2026 04:56
@WoosukKwon WoosukKwon merged commit 6691f08 into vllm-project:main Jun 23, 2026
205 of 212 checks passed
@zyongye zyongye added this to the v0.24.0 cherrypick milestone Jun 23, 2026
khluu pushed a commit that referenced this pull request Jun 24, 2026
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-authored-by: Thien Tran <gau.nernst@yahoo.com.sg>
(cherry picked from commit 6691f08)
nkzhenhua pushed a commit to nkzhenhua/vllm that referenced this pull request Jun 24, 2026
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-authored-by: Thien Tran <gau.nernst@yahoo.com.sg>
qli88 pushed a commit to qli88/vllm that referenced this pull request Jun 26, 2026
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-authored-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Qiang Li <qiang.li2@amd.com>
wincent8 pushed a commit to wincent8/vllm that referenced this pull request Jun 29, 2026
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-authored-by: Thien Tran <gau.nernst@yahoo.com.sg>