[Kernel] Add FlashInferCutedslMxfp8LinearKernel (cute-dsl mm_mxfp8) #46393

Merged: vllm-bot merged 1 commit into vllm-project:main from zyongye:feat/mxfp8_cutedsl_linear on Jun 23, 2026

zyongye (Member) commented Jun 22, 2026

Add an MXFP8 W8A8 linear GEMM that drives FlashInfer's mm_mxfp8(..., backend="cute-dsl"), sibling to the existing CUTLASS kernel. The cute-dsl backend consumes the same 1D swizzled F8_128x4 scales the CUTLASS path already produces, so weight/activation prep is identical and output is bit-identical; only the backend string and support gate differ.

Gate to sm_100/sm_103 (matching FlashInfer's supported_compute_capability [100, 103]) plus has_flashinfer_cutedsl(); auto-selects on those archs and falls through to CUTLASS elsewhere. Reachable explicitly via --linear-backend flashinfer_cutedsl.
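The gating and fall-through behaviour described above can be sketched roughly as follows. This is a minimal illustration, not vLLM's actual selection code: `KernelCandidate`, `select_kernel`, the `"cutlass"` name, the always-true CUTLASS gate, and the error raised for an unsupported forced backend are all assumptions made for the example. Only the `flashinfer_cutedsl` name and the (100, 103) capability list come from the description.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Sequence


@dataclass(frozen=True)
class KernelCandidate:
    """Hypothetical record: a backend name plus its support gate."""
    name: str
    is_supported: Callable[[int], bool]  # capability as an int, e.g. 100 for sm_100


def select_kernel(candidates: Sequence[KernelCandidate],
                  capability: int,
                  forced: Optional[str] = None) -> str:
    """Return the first supported backend, or the forced one if it is usable."""
    if forced is not None:
        for cand in candidates:
            if cand.name == forced:
                if not cand.is_supported(capability):
                    # Assumption: an explicit but unsupported choice is an error.
                    raise ValueError(f"{forced} is not supported on sm_{capability}")
                return cand.name
        raise ValueError(f"unknown backend: {forced}")
    # Candidates are ordered by priority; fall through to the next on failure.
    for cand in candidates:
        if cand.is_supported(capability):
            return cand.name
    raise RuntimeError("no supported MXFP8 linear backend")


# Assumption: the optional cute-dsl dependency is installed in this toy setup.
HAS_CUTEDSL = True

CANDIDATES = [
    KernelCandidate("flashinfer_cutedsl",
                    lambda cap: HAS_CUTEDSL and cap in (100, 103)),
    # Assumption: the CUTLASS gate is simplified to "always available" here.
    KernelCandidate("cutlass", lambda cap: True),
]

print(select_kernel(CANDIDATES, 100))
print(select_kernel(CANDIDATES, 120))
```

Ordering the candidates by priority keeps the fallback implicit: the cute-dsl entry is tried first and the CUTLASS entry is reached only when its gate fails.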

Purpose

Performance comparison across the available MXFP8 linear backends.

MLP grid — H=6144, I=3072 (TP shards intermediate)

up_proj, TP=1: a[m,6144] @ b[6144,3072] (TFLOP/s, kernel µs in parentheses)

| m | cutlass | cudnn | trtllm | cute-dsl |
|---|---|---|---|---|
| 1 | 3 (11.9) | 3 (10.9) | 3 (11.3) | 3 (11.7) |
| 4 | 13 (11.9) | 12 (13.1) | 13 (11.2) | 13 (11.7) |
| 16 | 52 (11.7) | 46 (13.1) | 53 (11.3) | 62 (9.8) |
| 32 | 103 (11.7) | 91 (13.2) | 113 (10.7) | 118 (10.3) |
| 64 | 206 (11.7) | 181 (13.3) | 221 (10.9) | 215 (11.2) |
| 128 | 405 (11.9) | 359 (13.5) | 427 (11.3) | 425 (11.4) |
| 256 | 821 (11.8) | 729 (13.2) | 765 (12.6) | 821 (11.8) |
| 512 | 1,288 (15.0) | 1,529 (12.6) | 851 (22.7) | 1,477 (13.1) |
| 1024 | 1,585 (24.4) | 1,948 (19.8) | 1,139 (34.0) | 2,233 (17.3) |
| 2048 | 2,196 (35.2) | 2,307 (33.5) | 1,210 (63.9) | 2,601 (29.7) |
| 4096 | 2,069 (74.7) | 2,521 (61.3) | 1,359 (113.8) | 2,528 (61.2) |
| 8192 | 2,734 (113.1) | 2,602 (118.8) | 1,496 (206.8) | 2,760 (112.0) |
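The throughput numbers appear consistent with the usual GEMM convention of 2·m·n·k floating-point operations divided by the kernel time. The PR does not state the formula, so the helper below (`gemm_tflops`, a name chosen for this example) is an assumption that happens to reproduce the reported values.

```python
def gemm_tflops(m: int, n: int, k: int, kernel_us: float) -> float:
    """TFLOP/s for a[m,k] @ b[k,n], assuming 2*m*n*k FLOPs per GEMM."""
    flops = 2 * m * n * k
    seconds = kernel_us * 1e-6
    return flops / seconds / 1e12


# up_proj, TP=1, cutlass column: m=8192, k=6144, n=3072, 113.1 µs
print(round(gemm_tflops(8192, 3072, 6144, 113.1)))  # 2734
# Same shape at m=1, 11.9 µs
print(round(gemm_tflops(1, 3072, 6144, 11.9)))  # 3
```

At small m the kernel time stays near a fixed floor of roughly 10 µs, so TFLOP/s grows almost linearly with m until the GEMM becomes compute-bound.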

up_proj, TP=2: a[m,6144] @ b[6144,1536] (TFLOP/s, kernel µs in parentheses)

| m | cutlass | cudnn | trtllm | cute-dsl |
|---|---|---|---|---|
| 1 | 2 (10.8) | 2 (10.4) | 2 (10.8) | 2 (11.1) |
| 4 | 7 (10.8) | 7 (10.2) | 7 (10.8) | 7 (11.0) |
| 16 | 28 (10.8) | 30 (10.2) | 28 (10.7) | 33 (9.3) |
| 32 | 57 (10.5) | 57 (10.7) | 60 (10.0) | 62 (9.7) |
| 64 | 115 (10.5) | 112 (10.8) | 114 (10.6) | 115 (10.5) |
| 128 | 227 (10.6) | 214 (11.3) | 232 (10.4) | 231 (10.5) |
| 256 | 451 (10.7) | 386 (12.5) | 443 (10.9) | 456 (10.6) |
| 512 | 846 (11.4) | 759 (12.7) | 795 (12.2) | 863 (11.2) |
| 1024 | 1,336 (14.5) | 1,565 (12.4) | 868 (22.3) | 1,438 (13.4) |
| 2048 | 1,882 (20.5) | 1,974 (19.6) | 1,154 (33.5) | 2,237 (17.3) |
| 4096 | 2,235 (34.6) | 2,310 (33.5) | 1,212 (63.8) | 2,554 (30.3) |
| 8192 | 2,445 (63.2) | 2,546 (60.7) | 1,344 (115.0) | 2,426 (63.7) |

up_proj, TP=4: a[m,6144] @ b[6144,768] (TFLOP/s, kernel µs in parentheses)

| m | cutlass | cudnn | trtllm | cute-dsl |
|---|---|---|---|---|
| 1 | 1 (10.0) | 1 (10.3) | 1 (10.1) | 1 (10.4) |
| 4 | 4 (9.9) | 4 (10.7) | 4 (9.6) | 4 (9.9) |
| 16 | 15 (10.0) | 15 (10.3) | 15 (9.8) | 18 (8.4) |
| 32 | 30 (9.9) | 29 (10.3) | 31 (9.7) | 34 (8.9) |
| 64 | 62 (9.8) | 56 (10.8) | 63 (9.5) | 61 (9.9) |
| 128 | 121 (10.0) | 111 (10.9) | 124 (9.8) | 126 (9.6) |
| 256 | 244 (9.9) | 214 (11.3) | 238 (10.1) | 243 (10.0) |
| 512 | 466 (10.4) | 391 (12.4) | 453 (10.7) | 447 (10.8) |
| 1024 | 890 (10.9) | 751 (12.9) | 801 (12.1) | 841 (11.5) |
| 2048 | 1,330 (14.5) | 1,549 (12.5) | 859 (22.5) | 1,428 (13.5) |
| 4096 | 1,579 (24.5) | 1,939 (19.9) | 1,144 (33.8) | 2,068 (18.7) |
| 8192 | 2,042 (37.9) | 2,072 (37.3) | 1,176 (65.8) | 2,290 (33.8) |

up_proj, TP=8: a[m,6144] @ b[6144,384] (TFLOP/s, kernel µs in parentheses)

| m | cutlass | cudnn | trtllm | cute-dsl |
|---|---|---|---|---|
| 1 | 0 (9.5) | 0 (10.1) | 0 (9.8) | 0 (9.9) |
| 4 | 2 (9.5) | 2 (10.0) | 2 (9.8) | 2 (9.4) |
| 16 | 8 (9.6) | 8 (9.9) | 8 (9.5) | 9 (8.1) |
| 32 | 16 (9.4) | 15 (10.1) | 16 (9.3) | 17 (8.6) |
| 64 | 32 (9.5) | 29 (10.4) | 32 (9.4) | 32 (9.4) |
| 128 | 62 (9.8) | 56 (10.8) | 63 (9.6) | 64 (9.4) |
| 256 | 126 (9.6) | 115 (10.5) | 123 (9.9) | 126 (9.6) |
| 512 | 244 (9.9) | 218 (11.1) | 240 (10.1) | 248 (9.7) |
| 1024 | 470 (10.3) | 384 (12.6) | 449 (10.8) | 476 (10.1) |
| 2048 | 816 (11.8) | 761 (12.7) | 768 (12.6) | 830 (11.6) |
| 4096 | 1,261 (15.3) | 1,357 (14.2) | 831 (23.3) | 1,354 (14.3) |
| 8192 | 1,379 (28.0) | 1,535 (25.2) | 1,056 (36.6) | 1,974 (19.6) |

down_proj, TP=1: a[m,3072] @ b[3072,6144] (TFLOP/s, kernel µs in parentheses)

| m | cutlass | cudnn | trtllm | cute-dsl |
|---|---|---|---|---|
| 1 | 4 (9.1) | 4 (9.6) | 4 (8.4) | 4 (9.0) |
| 4 | 17 (9.1) | 16 (9.6) | 18 (8.2) | 17 (9.0) |
| 16 | 66 (9.1) | 63 (9.6) | 74 (8.2) | 76 (8.0) |
| 32 | 132 (9.2) | 125 (9.7) | 145 (8.4) | 150 (8.0) |
| 64 | 260 (9.3) | 224 (10.8) | 282 (8.6) | 277 (8.7) |
| 128 | 515 (9.4) | 440 (11.0) | 515 (9.4) | 532 (9.1) |
| 256 | 893 (10.8) | 1,038 (9.3) | 640 (15.1) | 950 (10.2) |
| 512 | 1,256 (15.4) | 1,418 (13.6) | 956 (20.2) | 1,594 (12.1) |
| 1024 | 1,873 (20.6) | 1,961 (19.7) | 1,116 (34.6) | 2,130 (18.1) |
| 2048 | 2,463 (31.4) | 2,609 (29.6) | 1,264 (61.2) | 2,438 (31.7) |
| 4096 | 2,542 (60.8) | 2,523 (61.3) | 1,376 (112.4) | 2,739 (56.4) |
| 8192 | 2,830 (109.3) | 2,734 (113.1) | 1,492 (207.3) | 2,811 (110.0) |

down_proj, TP=2: a[m,1536] @ b[1536,6144] (TFLOP/s, kernel µs in parentheses)

| m | cutlass | cudnn | trtllm | cute-dsl |
|---|---|---|---|---|
| 1 | 3 (6.8) | 3 (7.1) | 3 (6.2) | 3 (7.4) |
| 4 | 11 (6.7) | 11 (6.8) | 12 (6.2) | 12 (6.6) |
| 16 | 45 (6.7) | 44 (6.8) | 49 (6.2) | 51 (6.0) |
| 32 | 90 (6.7) | 79 (7.6) | 97 (6.2) | 104 (5.8) |
| 64 | 179 (6.8) | 157 (7.7) | 186 (6.5) | 193 (6.3) |
| 128 | 345 (7.0) | 306 (7.9) | 353 (6.8) | 377 (6.4) |
| 256 | 553 (8.7) | 719 (6.7) | 475 (10.2) | 645 (7.5) |
| 512 | 912 (10.6) | 1,056 (9.2) | 746 (13.0) | 1,094 (8.8) |
| 1024 | 1,408 (13.7) | 1,525 (12.7) | 950 (20.4) | 1,577 (12.3) |
| 2048 | 2,007 (19.3) | 2,165 (17.9) | 1,146 (33.7) | 2,023 (19.1) |
| 4096 | 2,488 (31.1) | 2,341 (33.0) | 1,249 (61.9) | 2,483 (31.1) |
| 8192 | 2,661 (58.1) | 2,671 (57.9) | 1,323 (116.8) | 2,543 (60.8) |

down_proj, TP=4: a[m,768] @ b[768,6144] (TFLOP/s, kernel µs in parentheses)

| m | cutlass | cudnn | trtllm | cute-dsl |
|---|---|---|---|---|
| 1 | 2 (5.2) | 2 (5.3) | 2 (4.9) | 2 (5.1) |
| 4 | 7 (5.2) | 7 (5.2) | 8 (5.0) | 7 (5.1) |
| 16 | 29 (5.2) | 29 (5.3) | 30 (5.0) | 33 (4.6) |
| 32 | 59 (5.2) | 54 (5.6) | 60 (5.0) | 66 (4.6) |
| 64 | 117 (5.2) | 108 (5.6) | 116 (5.2) | 125 (4.8) |
| 128 | 219 (5.5) | 210 (5.8) | 219 (5.5) | 244 (5.0) |
| 256 | 393 (6.1) | 427 (5.7) | 331 (7.3) | 422 (5.7) |
| 512 | 654 (7.4) | 803 (6.0) | 526 (9.2) | 665 (7.3) |
| 1024 | 965 (10.0) | 1,198 (8.1) | 726 (13.3) | 1,067 (9.1) |
| 2048 | 1,411 (13.7) | 1,510 (12.8) | 936 (20.6) | 1,533 (12.6) |
| 4096 | 2,013 (19.2) | 1,876 (20.6) | 1,081 (35.7) | 1,993 (19.4) |
| 8192 | 2,450 (31.6) | 2,397 (32.3) | 1,148 (67.3) | 2,301 (33.6) |

down_proj, TP=8: a[m,384] @ b[384,6144] (TFLOP/s, kernel µs in parentheses)

| m | cutlass | cudnn | trtllm | cute-dsl |
|---|---|---|---|---|
| 1 | 1 (4.3) | 1 (4.3) | n/s | 1 (4.1) |
| 4 | 4 (4.3) | 4 (4.3) | n/s | 5 (4.1) |
| 16 | 18 (4.3) | 18 (4.3) | n/s | 20 (3.7) |
| 32 | 35 (4.3) | 35 (4.3) | n/s | 39 (3.8) |
| 64 | 70 (4.3) | 70 (4.3) | n/s | 77 (3.9) |
| 128 | 140 (4.3) | 133 (4.5) | n/s | 145 (4.2) |
| 256 | 248 (4.9) | 264 (4.6) | n/s | 257 (4.7) |
| 512 | 419 (5.8) | 463 (5.2) | n/s | 452 (5.3) |
| 1024 | 602 (8.0) | 759 (6.4) | n/s | 748 (6.5) |
| 2048 | 944 (10.2) | 1,049 (9.2) | n/s | 1,003 (9.6) |
| 4096 | 1,240 (15.6) | 1,351 (14.3) | n/s | 1,291 (15.0) |
| 8192 | 1,577 (24.5) | 1,624 (23.8) | n/s | 1,550 (24.9) |

TP=4 num_tokens sweep (m = 1…4096)

up_proj, TP=4: a[m,6144] @ b[6144,768] (TFLOP/s, kernel µs in parentheses)

| m | cutlass | cudnn | trtllm | cute-dsl |
|---|---|---|---|---|
| 1 | 1 (10.0) | 1 (10.4) | 1 (9.6) | 1 (10.5) |
| 2 | 2 (10.2) | 2 (10.3) | 2 (10.1) | 2 (10.4) |
| 4 | 4 (10.1) | 4 (10.4) | 4 (9.6) | 4 (9.9) |
| 8 | 8 (10.0) | 7 (10.4) | 7 (10.2) | 9 (8.2) |
| 16 | 15 (10.0) | 15 (10.3) | 16 (9.6) | 18 (8.4) |
| 32 | 30 (10.1) | 29 (10.3) | 31 (9.8) | 34 (8.8) |
| 64 | 60 (10.1) | 58 (10.4) | 62 (9.8) | 62 (9.8) |
| 128 | 121 (10.0) | 114 (10.6) | 124 (9.7) | 115 (10.5) |
| 256 | 238 (10.1) | 209 (11.6) | 238 (10.1) | 227 (10.7) |
| 512 | 460 (10.5) | 391 (12.4) | 452 (10.7) | 472 (10.2) |
| 1024 | 893 (10.8) | 757 (12.8) | 791 (12.2) | 868 (11.1) |
| 2048 | 1,330 (14.5) | 1,525 (12.7) | 859 (22.5) | 1,411 (13.7) |
| 4096 | 1,830 (21.1) | 1,908 (20.3) | 1,145 (33.8) | 2,051 (18.8) |

down_proj, TP=4: a[m,768] @ b[768,6144] (TFLOP/s, kernel µs in parentheses)

| m | cutlass | cudnn | trtllm | cute-dsl |
|---|---|---|---|---|
| 1 | 2 (5.2) | 2 (5.3) | 2 (5.0) | 2 (5.1) |
| 2 | 4 (5.2) | 4 (4.9) | 4 (5.0) | 4 (5.1) |
| 4 | 7 (5.2) | 7 (5.3) | 8 (5.0) | 7 (5.1) |
| 8 | 15 (5.2) | 14 (5.3) | 15 (5.0) | 16 (4.6) |
| 16 | 29 (5.2) | 29 (5.3) | 30 (5.0) | 33 (4.5) |
| 32 | 59 (5.2) | 54 (5.6) | 60 (5.0) | 64 (4.7) |
| 64 | 117 (5.2) | 108 (5.6) | 117 (5.2) | 121 (5.0) |
| 128 | 219 (5.5) | 209 (5.8) | 219 (5.5) | 244 (5.0) |
| 256 | 391 (6.2) | 424 (5.7) | 333 (7.3) | 422 (5.7) |
| 512 | 640 (7.6) | 803 (6.0) | 519 (9.3) | 668 (7.2) |
| 1024 | 974 (9.9) | 1,203 (8.0) | 721 (13.4) | 1,144 (8.4) |
| 2048 | 1,448 (13.3) | 1,611 (12.0) | 934 (20.7) | 1,502 (12.9) |
| 4096 | 1,977 (19.6) | 1,879 (20.6) | 1,091 (35.4) | 1,958 (19.7) |

Single GEMM — n=8960, k=6144 (no TP split)

n8960_k6144, TP=1: a[m,6144] @ b[6144,8960] (TFLOP/s, kernel µs in parentheses)

| m | cutlass | cudnn | trtllm | cute-dsl |
|---|---|---|---|---|
| 1 | 5 (20.4) | 6 (19.8) | 7 (15.8) | 6 (19.2) |
| 2 | 11 (20.2) | 11 (19.5) | 14 (15.8) | 12 (18.9) |
| 4 | 22 (20.2) | 22 (19.6) | 28 (15.7) | 23 (18.9) |
| 8 | 44 (20.2) | 46 (19.4) | 56 (15.8) | 59 (14.8) |
| 16 | 87 (20.2) | 90 (19.6) | 115 (15.3) | 117 (15.0) |
| 32 | 216 (16.3) | 167 (21.1) | 228 (15.4) | 232 (15.2) |
| 64 | 430 (16.4) | 383 (18.4) | 433 (16.3) | 446 (15.8) |
| 128 | 847 (16.6) | 686 (20.5) | 775 (18.2) | 858 (16.4) |
| 256 | 1,498 (18.8) | 1,671 (16.9) | 934 (30.2) | 1,604 (17.6) |
| 512 | 2,297 (24.5) | 2,495 (22.6) | 1,096 (51.4) | 2,279 (24.7) |
| 1024 | 2,698 (41.8) | 2,803 (40.2) | 1,234 (91.4) | 2,710 (41.6) |
| 2048 | 2,738 (82.4) | 2,666 (84.6) | 1,430 (157.7) | 2,719 (82.9) |
| 4096 | 2,949 (152.9) | 2,840 (158.8) | 1,520 (296.7) | 2,978 (151.4) |

Test Plan

Verified the GSM8K score on minimax-m3.

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
@mergify mergify Bot added the nvidia label Jun 22, 2026
@zyongye zyongye added the ready label (ONLY add when PR is ready to merge/full CI is needed) Jun 22, 2026
mgoin (Member) commented Jun 22, 2026

@zyongye do we need to consider any warmup for the cutedsl kernel, or will it be handled by FlashInfer autotuning?

Comment on lines +159 to +164
if not weight.is_contiguous():
    weight = weight.contiguous()

output = vllm_flashinfer.mm_mxfp8(
    input_mxfp8,
    weight.t(),

Member

Might as well move the weight = weight.contiguous() and the final weight.t() into process_weights_after_loading above

Member Author

Done — moved contiguous() + transpose into process_weights_after_loading (weight is now stored column-major [K, N]), so apply_weights passes it straight to mm_mxfp8 with no per-forward work.
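The pattern behind this change, doing layout work once at load time so the forward path does none, can be sketched in plain Python. This is a toy model under stated assumptions: `ToyLinearKernel` is hypothetical, nested lists stand in for tensors, and the transpose stands in for the real `contiguous()` plus transpose step. Only the method names `process_weights_after_loading` and `apply_weights` come from the discussion.

```python
class ToyLinearKernel:
    """Toy stand-in: weights are nested lists, not real tensors."""

    def __init__(self, weight):
        # weight arrives as N rows of length K, i.e. shape [N, K]
        self.weight = weight
        self.prep_calls = 0

    def process_weights_after_loading(self):
        # One-time layout change: store the weight as [K, N] so the
        # forward path can consume it directly.
        self.weight = [list(col) for col in zip(*self.weight)]
        self.prep_calls += 1

    def apply_weights(self, x):
        # x has shape [M, K]; no transpose or copy happens here.
        columns = list(zip(*self.weight))  # N columns, each of length K
        return [[sum(a * b for a, b in zip(row, col)) for col in columns]
                for row in x]


kernel = ToyLinearKernel([[1, 2], [3, 4], [5, 6]])  # N=3, K=2
kernel.process_weights_after_loading()
for _ in range(3):
    out = kernel.apply_weights([[1, 1]])
print(out, kernel.prep_calls)  # [[3, 7, 11]] 1
```

The layout step runs once regardless of how many forward calls follow, which is the point of moving it out of the per-forward path.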

Comment on lines +104 to +105
cap = current_platform.get_device_capability()
if cap is None or cap.to_int() not in (100, 103):

Member

nit: we should check cuda, could be just current_platform.is_cuda() and current_platform.is_device_capability_family(100)

Member Author

Done: switched to current_platform.is_cuda() and current_platform.is_device_capability_family(100). The family check is (cap // 10) == 10, so it matches sm_100/sm_103, the architectures cute-dsl supports, and excludes sm_110/120/121.
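The arithmetic of the family check can be illustrated with a standalone function. `is_capability_family` is a hypothetical helper written for this example, not the vLLM method itself; it assumes capabilities are encoded as integers such as 100 for sm_100, as in the snippet under review.

```python
def is_capability_family(capability: int, family: int) -> bool:
    """True when both values share the same leading digits (cap // 10)."""
    return capability // 10 == family // 10


for cap in (100, 103, 110, 120, 121):
    print(cap, is_capability_family(cap, 100))
```

Integer division by 10 drops the minor revision, so 100 and 103 both map to 10 while 110, 120 and 121 map to 11 or 12 and are rejected.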

mergify Bot (Contributor) commented Jun 22, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @zyongye.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Jun 22, 2026
zyongye (Member, Author) commented Jun 23, 2026

> @zyongye do we need to consider any warmup for the cutedsl kernel, or will it be handled by FlashInfer autotuning?

I think FI autotune handles it.

Add an MXFP8 W8A8 linear GEMM that drives FlashInfer's
mm_mxfp8(..., backend="cute-dsl"), sibling to the existing CUTLASS kernel.
The cute-dsl backend consumes the same 1D swizzled F8_128x4 scales the
CUTLASS path already produces, so weight/activation prep is identical and
output is bit-identical; only the backend string and support gate differ.

Gate to sm_100/sm_103 (matching FlashInfer's supported_compute_capability
[100, 103]) plus has_flashinfer_cutedsl(); auto-selects on those archs and
falls through to CUTLASS elsewhere. Reachable explicitly via
--linear-backend flashinfer_cutedsl.

Verified on SM100: bit-identical parity vs the CUTLASS kernel, correct
backend dispatch, and tests/models/quantization/test_mxfp8.py
test_mxfp8_generation[dense] passing (cute-dsl by default).

AI assistance (Claude Code) was used for this change.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Yongye Zhu <yongye@inferact.ai>

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
@zyongye zyongye force-pushed the feat/mxfp8_cutedsl_linear branch from fc7d7d5 to 2bd8a0c Compare June 23, 2026 04:52
@mergify mergify Bot removed the needs-rebase label Jun 23, 2026
@zyongye zyongye added this to the v0.24.0 cherrypick milestone Jun 23, 2026
@github-project-automation github-project-automation Bot moved this to Ready in NVIDIA Jun 23, 2026
@vllm-bot vllm-bot merged commit 11b56b2 into vllm-project:main Jun 23, 2026
86 of 87 checks passed
@github-project-automation github-project-automation Bot moved this from Ready to Done in NVIDIA Jun 23, 2026
khluu pushed a commit that referenced this pull request Jun 24, 2026
…46393)

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
(cherry picked from commit 11b56b2)
nkzhenhua pushed a commit to nkzhenhua/vllm that referenced this pull request Jun 24, 2026
qli88 pushed a commit to qli88/vllm that referenced this pull request Jun 26, 2026
…llm-project#46393)

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Qiang Li <qiang.li2@amd.com>
wincent8 pushed a commit to wincent8/vllm that referenced this pull request Jun 29, 2026

Labels

nvidia, ready (ONLY add when PR is ready to merge/full CI is needed)

3 participants