[Kernel] Add FlashInferCutedslMxfp8LinearKernel (cute-dsl mm_mxfp8)#46393
Conversation
|
@zyongye do we need to consider any warmup for cutedsl kernel, or it will be handled by flashinfer autotuning? |
| if not weight.is_contiguous(): | ||
| weight = weight.contiguous() | ||
|
|
||
| output = vllm_flashinfer.mm_mxfp8( | ||
| input_mxfp8, | ||
| weight.t(), |
There was a problem hiding this comment.
Might as well move the weight = weight.contiguous() and the final weight.t() into process_weights_after_loading above
There was a problem hiding this comment.
Done — moved contiguous() + transpose into process_weights_after_loading (weight is now stored column-major [K, N]), so apply_weights passes it straight to mm_mxfp8 with no per-forward work.
| cap = current_platform.get_device_capability() | ||
| if cap is None or cap.to_int() not in (100, 103): |
There was a problem hiding this comment.
nit: we should check cuda, could be just current_platform.is_cuda() and current_platform.is_device_capability_family(100)
There was a problem hiding this comment.
Done — switched to current_platform.is_cuda() and current_platform.is_device_capability_family(100). The family check is (cap // 10) == 10, so it matches sm_100/sm_103 and excludes sm_110/120/121, which is what cute-dsl supports.
|
This pull request has merge conflicts that must be resolved before it can be |
I think FI autotune handles it. |
Add an MXFP8 W8A8 linear GEMM that drives FlashInfer's mm_mxfp8(..., backend="cute-dsl"), sibling to the existing CUTLASS kernel. The cute-dsl backend consumes the same 1D swizzled F8_128x4 scales the CUTLASS path already produces, so weight/activation prep is identical and output is bit-identical; only the backend string and support gate differ. Gate to sm_100/sm_103 (matching FlashInfer's supported_compute_capability [100, 103]) plus has_flashinfer_cutedsl(); auto-selects on those archs and falls through to CUTLASS elsewhere. Reachable explicitly via --linear-backend flashinfer_cutedsl. Verified on SM100: bit-identical parity vs the CUTLASS kernel, correct backend dispatch, and tests/models/quantization/test_mxfp8.py test_mxfp8_generation[dense] passing (cute-dsl by default). AI assistance (Claude Code) was used for this change. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Yongye Zhu <yongye@inferact.ai> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
fc7d7d5 to
2bd8a0c
Compare
…llm-project#46393) Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
…llm-project#46393) Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Signed-off-by: Qiang Li <qiang.li2@amd.com>
…llm-project#46393) Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Add an MXFP8 W8A8 linear GEMM that drives FlashInfer's mm_mxfp8(..., backend="cute-dsl"), sibling to the existing CUTLASS kernel. The cute-dsl backend consumes the same 1D swizzled F8_128x4 scales the CUTLASS path already produces, so weight/activation prep is identical and output is bit-identical; only the backend string and support gate differ.
Gate to sm_100/sm_103 (matching FlashInfer's supported_compute_capability [100, 103]) plus has_flashinfer_cutedsl(); auto-selects on those archs and falls through to CUTLASS elsewhere. Reachable explicitly via --linear-backend flashinfer_cutedsl.
Purpose
Perf comparison between various mxfp8 linear backend
MLP grid — H=6144, I=3072 (TP shards intermediate)
up_proj, TP=1 —
a[m,6144] @ b[6144,3072](TFLOP/s, kernel µs in parens; bold = fastest)up_proj, TP=2 —
a[m,6144] @ b[6144,1536](TFLOP/s, kernel µs in parens; bold = fastest)up_proj, TP=4 —
a[m,6144] @ b[6144,768](TFLOP/s, kernel µs in parens; bold = fastest)up_proj, TP=8 —
a[m,6144] @ b[6144,384](TFLOP/s, kernel µs in parens; bold = fastest)down_proj, TP=1 —
a[m,3072] @ b[3072,6144](TFLOP/s, kernel µs in parens; bold = fastest)down_proj, TP=2 —
a[m,1536] @ b[1536,6144](TFLOP/s, kernel µs in parens; bold = fastest)down_proj, TP=4 —
a[m,768] @ b[768,6144](TFLOP/s, kernel µs in parens; bold = fastest)down_proj, TP=8 —
a[m,384] @ b[384,6144](TFLOP/s, kernel µs in parens; bold = fastest)TP=4 num_tokens sweep (m = 1…4096)
up_proj, TP=4 —
a[m,6144] @ b[6144,768](TFLOP/s, kernel µs in parens; bold = fastest)down_proj, TP=4 —
a[m,768] @ b[768,6144](TFLOP/s, kernel µs in parens; bold = fastest)Single GEMM — n=8960, k=6144 (no TP split)
n8960_k6144, TP=1 —
a[m,6144] @ b[6144,8960](TFLOP/s, kernel µs in parens; bold = fastest)Test Plan
Verified gsm8k score on minimax-m3
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.