
[ROCm][Compile] Fuse AR + RMSNorm + per-group FP8 quant (+ DSv3.2 indexer fan-out)#42864

Merged
tjtanaa merged 15 commits into
vllm-project:mainfrom
maeehart:mahartik/rocm-aiter-ar-rms-quant-fusion
Jun 9, 2026

Conversation

@maeehart

@maeehart maeehart commented May 17, 2026

Contributor

Summary

Three new patterns in RocmAiterAllReduceFusionPass that fuse the all_reduce -> RMSNorm[+add] -> per-group FP8 quant [-> bf16 indexer GEMM] chain into one AITER call, eliminating ~535 us / decode step of standalone triton_per_token_group_quant_fp8 launches on DeepSeek V3.2 MI355X TP4.

This is the AR-side analogue of #41825 (which fixed the quant-only half) and the ROCm port of the flashinfer AllReduceFusedRMSNormStaticQuantFP8Pattern family already in this file for NVIDIA.

Background

DSv3.2 (and other FP8-blockwise models) end every transformer block with

all_reduce -> [fused_add_]rms_norm -> rocm_aiter_group_fp8_quant -> fp8_gemm

PR #41825 fixed the quant-side half (RocmAiterRMSNormQuantFusionPass) so a plain rms_norm -> group_fp8_quant pair is rewritten as rocm_aiter_rmsnorm_fp8_group_quant. But that pass cannot fire once RocmAiterAllReduceFusionPass has already absorbed the rms_norm consumer of the all_reduce: at that point the FX graph just sees an opaque rocm_aiter_fused_allreduce_rmsnorm producing bf16 and the standalone rocm_aiter_group_fp8_quant consumer stays unfused.
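To make the chain concrete, here is a minimal pure-Python sketch of what the three unfused steps compute for a single token. It is an illustration of the semantics only, not the AITER kernel: the helper names are hypothetical, `FP8_MAX = 448.0` assumes the OCP e4m3 format, the `eps` default is arbitrary, and real FP8 rounding is not modelled (quantized values stay as scaled floats).

```python
import math

FP8_MAX = 448.0  # assumption: OCP e4m3 finite max; other FP8 flavours differ


def all_reduce_sum(per_rank):
    """Element-wise sum of one vector per rank (stand-in for all_reduce)."""
    return [sum(vals) for vals in zip(*per_rank)]


def fused_add_rms_norm(x, residual, weight, eps=1e-6):
    """Return (normed, new_residual) where new_residual = x + residual."""
    res = [a + b for a, b in zip(x, residual)]
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in res) / len(res) + eps)
    return [v * inv_rms * w for v, w in zip(res, weight)], res


def per_group_quant(row, group_size):
    """Scale each group so its largest magnitude maps to FP8_MAX.

    Values are left as scaled floats; real FP8 rounding is not modelled.
    """
    q, scales = [], []
    for start in range(0, len(row), group_size):
        group = row[start:start + group_size]
        scale = max(max(abs(v) for v in group) / FP8_MAX, 1e-10)
        scales.append(scale)
        q.extend(max(-FP8_MAX, min(FP8_MAX, v / scale)) for v in group)
    return q, scales


def unfused_chain(per_rank, residual, weight, group_size):
    """Run the three steps separately, as three kernel launches would."""
    ar = all_reduce_sum(per_rank)
    normed, res_out = fused_add_rms_norm(ar, residual, weight)
    q, scales = per_group_quant(normed, group_size)
    return q, scales, res_out, normed
```

The fused op discussed in this PR is meant to return the same quantities from one launch; the sketch only shows which intermediate values are involved.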

On DSv3.2 TP4 MI355X that leaves the standalone per-token-group quant launches at ~535 us / decode step (61 layers x ~8.8 us each).

What this PR does

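Commit 1 ([ROCm][Compile] Fuse trailing per-group FP8 quant into AITER AR+RMSNorm): adds two patterns (AiterAllreduceFusedRMSNormGroupQuantFP8Pattern and its fused_add sibling) that match AR -> RMSNorm[+add] -> group_fp8_quant and rewrite it into a single rocm_aiter_fused_allreduce_rmsnorm_quant_per_group op backed by AITER's fused_ar_rms_per_group_quant launcher.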

Commit 2 ([ROCm][Compile] Fuse AR+RMSNorm+FP8 quant + DSv3.2 indexer fan-out):

The above only fires when the RMSNorm output has a single consumer. In DSv3.2 it has two: the FP8 group-quant (for fused_qkv_a_proj) and a separate bf16 rocm_unquantized_gemm (for the indexer wk_weights_proj). With the fan-out present, the pattern matcher falls back to the no-quant fusion and the quant kernel survives.

This commit closes that gap:

  • Adds rocm_aiter_fused_allreduce_rmsnorm_quant_per_group_with_bf16_norm -- the same AITER launcher with emit_bf16=True, returning a 4-tuple (FP8 quant, residual, per-group scale, bf16 normed activations). The launcher already supported this; this just exposes it as a custom op.

  • Adds AiterAllreduceFusedAddRMSNormGroupQuantWithIndexerPattern. Matches AR -> fused_add_rms_norm.default -> (per-group FP8 quant + rocm_unquantized_gemm) directly from the actual DSv3.2 post-grad FX dump (note: .default, not .maybe_inplace; the AR fusion pass runs before VllmIRInplaceFunctionalizationPass). The replacement reuses the fused op's emitted bf16 norm output for the indexer GEMM.

  • The RMSNorm output is also a graph output (it feeds the next compiled chunk as the residual carry), so the pattern returns it as a fifth value and the replacement substitutes the fused op's bf16 output in the same position. The pattern matcher then rewires all external uses automatically.

  • Registered before the single-consumer quant pattern so the larger subgraph is preferred when an indexer GEMM is present.
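The interaction between fan-out and pattern choice can be sketched with a toy model. This is not the FX pattern matcher: pattern names and the consumer-set representation are hypothetical, and the only behaviour modelled is that a quant-absorbing pattern needs the RMSNorm output to have exactly the consumers it was written for, while the AR+RMS-only fallback accepts anything.

```python
# Toy model: each pattern lists the ops that must consume the RMSNorm output.
PATTERNS = [
    # (name, required consumers of the norm output), larger subgraph first
    ("ar_rms_quant_with_indexer", {"group_quant_fp8", "unquantized_gemm"}),
    ("ar_rms_quant", {"group_quant_fp8"}),
    ("ar_rms_only", set()),  # catch-all fallback, leaves consumers untouched
]


def pick_pattern(consumers, patterns=PATTERNS):
    """Return the first registered pattern that fits this consumer set.

    Quant-absorbing patterns require an exact consumer match because the
    fused op replaces the whole fan-out; the fallback matches any site.
    """
    consumers = set(consumers)
    for name, required in patterns:
        if not required or consumers == required:
            return name
    return None


fan_out = {"group_quant_fp8", "unquantized_gemm"}
print(pick_pattern(fan_out))                # with the indexer pattern
print(pick_pattern(fan_out, PATTERNS[1:]))  # without it: fallback only
```

In this toy, order matters only relative to the catch-all fallback; the real matcher's preference rules may be subtler.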

Commit 3 ([Format] ruff format AR+RMSNorm fusion files): mechanical -- collapses a few unnecessary parenthesized line wraps in the new code so the files pass ruff format --check. No behavior change.

Commit 4 ([Review] Also register AITER-form variant of indexer fan-out pattern): the indexer-fan-out pattern is now registered twice via a use_triton_quant flag. The Triton-quant form matches what DSv3.2's fused_qkv_a_proj actually emits at the fan-out site today (via QuantFP8.forward_hip); the MatcherQuantFP8 form matches sites where the producer routes through rocm_aiter_group_fp8_quant. Without both, FX dumps showed the new pattern firing at only one variant of the call site. Added at the Gemini reviewer's request, for consistency with the sibling AR+RMS+QUANT patterns.

Commit 5 ([Review] Tighten indexer fan-out pattern docs and registration): docstring/comment cleanup only -- trim the new pattern's docstring to match the sibling AR+RMS+QUANT patterns' style and replace the two near-identical register(...) calls with a small loop. No behavior change.

Validation

DSv3.2 TP4 on MI355X, vllm/vllm-openai-rocm:nightly-32b7177909, AITER with ROCm/aiter#2823 wrapper patches.

FX dump diff at the dominant post-attn site (__compiled_fn_1.post_grad.1.rocm_aiter_allreduce_fusion_pass.after.1.py):

| op | main | this PR |
| --- | --- | --- |
| standalone triton_per_token_group_quant_fp8 | 4 | 0 |
| fallback rocm_aiter_fused_allreduce_rmsnorm.default | 3 | 0 |
| new ..._quant_per_group_with_bf16_norm | 0 | 5 |
| rocm_unquantized_gemm.default (indexer) | 3 | 3 (now reads bf16_norm from fused op) |

Same picture in chunk 2 (4 stranded quants -> 0, +5 new fused ops). Chunks 0 and 3 unchanged because they don't have the fused_add + quant + indexer site.

Serving perf (vllm bench serve, ISL=1000, OSL=100, 3-4 seeds each). "main" = upstream main with no patterns from this PR; "patched, commit 1 only" = the AR+RMS+QUANT half (single-consumer pattern, no indexer fan-out); "patched, full PR" = all five commits.

| config | main | patched, commit 1 only | patched, full PR |
| --- | --- | --- | --- |
| mc=4 np=32 mean TPOT | 17.69 +/- 0.39 ms | 17.93 +/- 0.29 ms | 16.96 +/- 0.33 ms |
| mc=8 np=64 mean TPOT | 21.42 +/- 0.34 ms | 20.60 +/- 0.38 ms | 20.80 +/- 0.27 ms |
| mc=16 np=64 mean TPOT | 28.04 +/- 0.52 ms | 28.81 +/- 0.93 ms | 26.59 +/- 0.50 ms |
| mc=16 np=64 throughput | 4644 tok/s | 4537 tok/s | 4945 tok/s |

i.e. -7.7% TPOT, +9.0% throughput at mc=16 for the full PR vs commit 1 alone (-5.2% TPOT, +6.5% throughput vs main), and the indexer-fan-out commit (commit 2 onwards) is doing the lion's share of the work -- commit 1 alone shifts perf only marginally (and at mc=4/mc=16 not in a useful direction). End-to-end output remains coherent (capital of France, simple arithmetic chain).
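The relative changes can be recomputed directly from the mc=16 rows of the table. The sketch below does that against both baselines ("main" and "commit 1 only"); the dictionary keys are shorthand introduced here, and the figures are copied from the table.

```python
def pct_change(new, old):
    """Relative change of new versus old, in percent."""
    return (new - old) / old * 100.0


# mc=16 np=64 figures copied from the serving-perf table.
tpot_ms = {"main": 28.04, "commit1": 28.81, "full": 26.59}
tput_tok_s = {"main": 4644, "commit1": 4537, "full": 4945}

for base in ("main", "commit1"):
    d_tpot = pct_change(tpot_ms["full"], tpot_ms[base])
    d_tput = pct_change(tput_tok_s["full"], tput_tok_s[base])
    print(f"full PR vs {base}: TPOT {d_tpot:+.1f}%, throughput {d_tput:+.1f}%")
```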

Notes / scope

  • The fusion pass only fires when rocm_aiter_ops.has_fused_allreduce_rmsnorm_quant_per_group() is true (i.e. AITER ships the fused_ar_rms_per_group_quant launcher); older AITER builds fall back silently to the existing AR+RMS-only fusion.

  • Chunk 0 of the DSv3.2 compiled graph still has 4 standalone Triton quants -- that's the first-layer (no-fused_add) indexer fan-out site (~1/61 of the per-step savings). A sibling AiterAllreduceFusedRMSNormGroupQuantWithIndexerPattern (no-Add variant) is a natural follow-up but not in this PR.

  • This PR depends on ROCm/aiter#2823 ("Add fused AR + RMSNorm + per-group FP8 quant: optional bf16 side-output") being available for the Python wrappers; without it the new ops silently fall back at the has_* probe.
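The gating described in these notes amounts to a feature probe followed by conditional registration. The sketch below shows that shape under stated assumptions: the real probe is `rocm_aiter_ops.has_fused_allreduce_rmsnorm_quant_per_group()`, whose implementation is not shown in this PR description, so the attribute check, the function names, and the pattern labels here are all hypothetical.

```python
from types import SimpleNamespace


def has_per_group_quant_launcher(aiter_module):
    """Hypothetical probe: True if the module exposes the fused launcher."""
    return callable(getattr(aiter_module, "fused_ar_rms_per_group_quant", None))


def select_patterns(aiter_module):
    """Pick which pattern families to register, largest subgraph first."""
    patterns = []
    if has_per_group_quant_launcher(aiter_module):
        patterns += ["ar_rms_quant_with_indexer", "ar_rms_quant"]
    patterns.append("ar_rms_only")  # existing fusion, always registered
    return patterns


# Stand-ins for an older and a newer AITER build.
old_aiter = SimpleNamespace()
new_aiter = SimpleNamespace(fused_ar_rms_per_group_quant=lambda *a, **k: None)

print(select_patterns(old_aiter))
print(select_patterns(new_aiter))
```

Probing once at registration time keeps the fallback silent: an older build simply never sees the new patterns.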

Test plan

  • CI green on the ROCm path
  • Maintainer confirms the FX trace pattern from the description matches their compiled DSv3.2 graph (or shares a counter-example we can extend the pattern for)
  • One more independent decode-step perf measurement on a different MI355X box

Made with Cursor

@mergify mergify Bot added the rocm Related to AMD ROCm label May 17, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD May 17, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces several new fusion patterns to the RocmAiterAllReduceFusionPass that combine AllReduce, RMSNorm, and per-group FP8 quantization into single AITER kernels, specifically targeting performance optimizations for DeepSeek V3.2. It also implements the necessary ROCm AITER operation bindings and registration logic. The review feedback recommends refactoring the AiterAllreduceFusedAddRMSNormGroupQuantWithIndexerPattern to utilize MatcherQuantFP8 for quantization matching, ensuring consistency with other patterns and improving robustness against varying quantization op representations.

Comment thread vllm/compilation/passes/fusion/allreduce_rms_fusion.py
Comment on lines +1226 to +1246
    def pattern(self):
        eps = self.epsilon
        gs = self.group_size

        def _pattern(
            residual: torch.Tensor,
            input_: torch.Tensor,
            norm_weight: torch.Tensor,
            indexer_weight: torch.Tensor,
        ) -> tuple[
            torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
        ]:
            ar_out = tensor_model_parallel_all_reduce(input_)
            rms, res_out = vllm.ir.ops.fused_add_rms_norm(
                ar_out, residual, norm_weight, eps
            )
            q, s = torch.ops.vllm.triton_per_token_group_quant_fp8(rms, gs)
            idx = torch.ops.vllm.rocm_unquantized_gemm(rms, indexer_weight)
            return q, s, res_out, idx, rms

        return _pattern

Contributor

high

Use the quant_matcher initialized in __init__ to make the pattern more flexible and consistent with AiterAllreduceFusedAddRMSNormGroupQuantFP8Pattern.

    @property
    def pattern(self):
        eps = self.epsilon

        def _pattern(
            residual: torch.Tensor,
            input_: torch.Tensor,
            norm_weight: torch.Tensor,
            indexer_weight: torch.Tensor,
        ) -> tuple[
            torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
        ]:
            ar_out = tensor_model_parallel_all_reduce(input_)
            rms, res_out = vllm.ir.ops.fused_add_rms_norm(
                ar_out, residual, norm_weight, eps
            )
            q, s = self.quant_matcher(rms)
            idx = torch.ops.vllm.rocm_unquantized_gemm(rms, indexer_weight)
            return q, s, res_out, idx, rms

        return _pattern

Contributor Author

See the reply on the __init__ thread -- using self.quant_matcher(rms) here empirically takes the fusion from 5 matches per chunk to 0 because the matcher only traces the AITER form and the actual DSv3.2 site emits the Triton form. Addressed by registering both Triton-form and AITER-form variants of the pattern (e4aab1e).
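The "5 matches to 0" effect can be reproduced with a toy matcher in which a pattern is just the exact op sequence it was traced with. This is a deliberately simplified stand-in for the FX pattern matcher; `make_pattern` and `count_matches` are hypothetical helpers, and only the two op names come from this thread.

```python
TRITON = "triton_per_token_group_quant_fp8"
AITER = "rocm_aiter_group_fp8_quant"


def make_pattern(quant_op):
    """A pattern is the exact op sequence it was traced with."""
    return ("all_reduce", "fused_add_rms_norm", quant_op)


def count_matches(sites, patterns):
    """Count call sites whose op sequence equals any registered pattern."""
    registered = set(patterns)
    return sum(1 for site in sites if site in registered)


# Five sites that all emit the Triton form, as reported for one DSv3.2 chunk.
sites = [make_pattern(TRITON)] * 5

print(count_matches(sites, [make_pattern(AITER)]))                        # AITER form only
print(count_matches(sites, [make_pattern(TRITON), make_pattern(AITER)]))  # both variants
```

Registering both variants costs one extra pattern; a variant that never matches on a given model simply does not fire.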

maeehart added a commit to maeehart/vllm that referenced this pull request May 17, 2026
Per gemini-code-assist suggestion on vllm-project#42864:

The sibling `AiterAllreduceFusedAddRMSNormGroupQuantFP8Pattern` uses
`MatcherQuantFP8` and matches the `rocm_aiter_group_fp8_quant` form of
the FP8 group-quant op. The original `WithIndexer` pattern only matched
the `triton_per_token_group_quant_fp8` form, which is what DSv3.2's
`fused_qkv_a_proj` actually emits today (verified by FX dump).

Empirically: replacing the direct triton-op call with `MatcherQuantFP8`
takes the indexer fusion from 5 matches per chunk to 0 -- the matcher
only registers one form per instance, and `match_rocm_aiter=True`
picks the AITER form. So a single-variant pattern can't cover both.

Solution: keep the Triton variant (which catches today's DSv3.2 sites)
and additionally register an AITER-form variant via `MatcherQuantFP8`
for sites that route through `vllm.rocm_aiter_group_fp8_quant`. Both
lower to the same 4-output fused op + standalone indexer GEMM, so the
replacement is identical.

`use_triton_quant: bool` toggles between the two; registered with both
values in `RocmAiterAllReduceFusionPass`.

Validation on the same DSv3.2 TP4 MI355X rig:
  Chunks 1 + 2 still each show 0 standalone triton quants and 5 new
  `..._with_bf16_norm` fused ops (same as the single-variant version).
  Bench at mc=4 np=32 (n=6 seeds): 17.19 +/- 0.36 ms TPOT,
  2154 tok/s -- within noise of the single-variant 16.96 +/- 0.33 ms,
  2161 tok/s. The AITER variant adds zero runtime cost on this model
  because it doesn't fire; it's safety for models that route through
  AITER's group-quant path.

Co-authored-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
@maeehart maeehart force-pushed the mahartik/rocm-aiter-ar-rms-quant-fusion branch from e4aab1e to 036383c Compare May 17, 2026 09:03
@Rohan138

Contributor

@maeehart Can you add a unit test to https://github.com/vllm-project/vllm/blob/main/tests/compile/passes/distributed/test_fusion_all_reduce.py similar to the existing ones for AR+RMS+FP8 static quant?

@maeehart

Contributor Author

@Rohan138 Done in ac15067.

Added test_rocm_aiter_all_reduce_rmsnorm_group_quant_fp8_fusion_pass_replace to tests/compile/passes/distributed/test_fusion_all_reduce.py, modeled on the existing TestAllReduceRMSNormStaticQuantFP8Model test you pointed at.

It validates all three new VllmPatternReplacement classes this PR registers on RocmAiterAllReduceFusionPass:

  • AiterAllreduceFusedRMSNormGroupQuantFP8Pattern (no-residual)
  • AiterAllreduceFusedAddRMSNormGroupQuantFP8Pattern (with-residual, single rms consumer)
  • AiterAllreduceFusedAddRMSNormGroupQuantWithIndexerPattern (with-residual, DSv3.2 indexer fan-out)

The new TestAiterAllReduceRMSNormGroupQuantFP8Model has four rms_norm sites laid out so that together they exercise all three patterns:

  • norm[0]: no residual -> simpler no-add pattern
  • norm[1]: with residual, single rms consumer -> simpler add pattern
  • norm[2..3]: with residual + a second rocm_unquantized_gemm consumer -> indexer-fan-out pattern

Asserts matched_count == 4, that all_reduce.default + the chosen group-quant op disappear, and that both rocm_aiter_fused_allreduce_rmsnorm_quant_per_group(.default) and ..._with_bf16_norm(.default) appear post-fusion. Also numerical parity via torch.testing.assert_close(..., atol=1e-2, rtol=1e-2) (same tolerance as the sibling static-FP8 test).

Parametrized over both triton_per_token_group_quant_fp8 and rocm_aiter_group_fp8_quant producers (the two use_triton_quant=True/False registrations added in commit 4) and over +rms_norm custom op on/off.

Validated on 2x MI355X (TP=2) with the PR 42864 docker image + ROCm/aiter#2823 wrapper patches: all 4 combos PASS in ~32 s. (CI green for these on AMD will require the build to ship a fused_ar_rms_per_group_quant-capable aiter; the test pytest.skips gracefully otherwise.)
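Two pieces of the test description lend themselves to a small sketch: the closeness criterion and the parametrization grid. The inequality below is the one `torch.testing.assert_close` documents for floating-point values; the pure-Python `is_close` helper and the constant names are illustrative stand-ins, not the test's actual code.

```python
from itertools import product


def is_close(actual, expected, atol=1e-2, rtol=1e-2):
    """Element-wise check of |actual - expected| <= atol + rtol * |expected|."""
    return all(
        abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected)
    )


# The two parametrization axes described above: quant producer x rms_norm custom op.
QUANT_PRODUCERS = ("triton_per_token_group_quant_fp8", "rocm_aiter_group_fp8_quant")
RMS_NORM_CUSTOM_OP = (True, False)
combos = list(product(QUANT_PRODUCERS, RMS_NORM_CUSTOM_OP))

for producer, custom_op in combos:
    print(f"quant producer={producer}, +rms_norm custom op={custom_op}")
```

With two producers and two custom-op settings the grid has four entries, which lines up with the "all 4 combos" count in the validation note.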

@dllehr-amd dllehr-amd added the ready ONLY add when PR is ready to merge/full CI is needed label May 20, 2026
@maeehart

maeehart commented May 21, 2026

Contributor Author

Picked up Tres's dtype fix (5da2f75) + parity for the 3 new patterns

Cherry-picked tpopp/aiter-ar-rmsnorm-weight-dtype into this PR with credit, and extended the same weight.to(input.dtype) cast to the three new AR+RMS+per-group-FP8-quant patterns this PR adds (AiterAllreduceFusedRMSNormGroupQuantFP8Pattern, AiterAllreduceFusedAddRMSNormGroupQuantFP8Pattern, AiterAllreduceFusedAddRMSNormGroupQuantWithIndexerPattern) — they all forward weight straight into the same aiter_ar.fused_ar_rms_per_group_quant kernel and inherit the same dtype constraint, so it would be a regression waiting to happen on any future model whose RMSNorm weight loads as fp32 with bf16 activations (the Qwen3-Next case).

Validated on an MI355X (gfx950) box at TP=2 (the failing config from the reproduction) with --compilation-config '{"cudagraph_mode":"FULL_AND_PIECEWISE","custom_ops":["-rms_norm","-silu_and_mul","-quant_fp8"],"pass_config":{"fuse_norm_quant":true,"fuse_allreduce_rms":...}}', patched file bind-mounted over the in-image copy.

Accuracy (lm_eval gsm8k 5-shot, --limit 200, Qwen3-Next-80B-A3B-Instruct-FP8)

| Config | strict-match | flexible-extract |
| --- | --- | --- |
| AR+RMS fusion ON, with dtype fix | 0.925 ±0.019 | 0.965 ±0.013 |
| AR+RMS fusion OFF (reference baseline) | 0.925 ±0.019 | 0.955 ±0.015 |

Strict-match is identical between fusion-on and fusion-off; flexible-extract differs by 0.010, which is within one stderr. The previously-broken TP=2 + AR+RMS-on path is now indistinguishable from the AR+RMS-off baseline.
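As a sanity check on the error bars, the sketch below computes the standard error of a 0/1 accuracy score over n samples and expresses the flexible-extract gap in units of that error. It assumes the reported ± values are the sample standard error with an n - 1 denominator and that n = 200 (from `--limit 200`); that assumption is not stated in the thread, though it does reproduce the reported digits.

```python
import math


def sample_stderr(p, n):
    """Standard error of the mean of n 0/1 outcomes with success rate p,
    using the sample (n - 1) variance."""
    return math.sqrt(p * (1.0 - p) / (n - 1))


N = 200  # from --limit 200
on, off = 0.965, 0.955  # flexible-extract, fusion ON vs OFF
gap = abs(on - off)

for label, p in (("ON", on), ("OFF", off)):
    se = sample_stderr(p, N)
    print(f"fusion {label}: stderr = {se:.3f}, gap / stderr = {gap / se:.2f}")
```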

Perf (vllm bench serve, random, ISL=1000, OSL=100, TP=2)

| Conc | Total Tput (AR+RMS ON, with fix) | Total Tput (AR+RMS OFF) | ON vs OFF |
| --- | --- | --- | --- |
| 16 | 13773.0 | 13494.2 | +2.1 % |
| 32 | 21685.7 | 21711.1 | -0.1 % |
| 64 | 29987.3 | 30683.1 | -2.3 % |

The cast is weight.to(input.dtype) on a graph-input weight parameter, so inductor constant-folds it at compile time — no runtime cost. The ON-vs-OFF perf delta is in single-run noise.

I'm leaving the indexer-fan-out pattern in place (it was an earlier suspect before the dtype diagnosis came in). It's guarded by rms being returned as a pattern output and was not the proximate cause of the Qwen3-Next regression; on DSv3.2 it still earns its ~535 µs/step.

Comment thread vllm/compilation/passes/fusion/allreduce_rms_fusion.py
maeehart added a commit to maeehart/vllm that referenced this pull request May 22, 2026
…erQuantFP8

Per @ProExpertProg's review on vllm-project#42864: ``MatcherQuantFP8`` now traces
both ``QuantFP8.forward_hip`` and ``forward_native`` paths, so the
indexer-fan-out pattern no longer needs to register two variants of
itself (one matching ``triton_per_token_group_quant_fp8`` directly and
one matching ``rocm_aiter_group_fp8_quant``).

Drop the ``use_triton_quant`` constructor flag and the corresponding
double-registration in ``RocmAiterAllReduceFusionPass``. The single
``MatcherQuantFP8(group_size=128, match_rocm_aiter=True)`` instance now
matches both call sites; ``test_rocm_aiter_all_reduce_rmsnorm_group_quant_fp8_fusion_pass_replace``
already parametrizes over ``use_triton_quant`` on the synthetic model
side, so unit-test coverage is unchanged.

Co-authored-by: Cursor <cursoragent@cursor.com>
@maeehart maeehart force-pushed the mahartik/rocm-aiter-ar-rms-quant-fusion branch from 5da2f75 to 2e1c457 Compare May 22, 2026 08:08
@maeehart

Contributor Author

@ProExpertProg good call -- I rebased on main (specifically commit 1fe330398 from vllm/vllm-openai-rocm:nightly-1fe3303983e1829fae25edfb0b93e8cbcfad96e6) and dropped the dual use_triton_quant=True/False registration on the indexer-fan-out pattern. The single MatcherQuantFP8(group_shape=(1, 128), match_rocm_aiter=True) instance now matches both triton_per_token_group_quant_fp8 and rocm_aiter_group_fp8_quant, so the refactor halves the registration in RocmAiterAllReduceFusionPass for the indexer fan-out.

New commit on top of the existing series (HEAD = 2e1c457923):

[Review] Drop use_triton_quant flag in indexer pattern; rely on MatcherQuantFP8

Validation

The same tests/compile/passes/distributed/test_fusion_all_reduce.py::test_rocm_aiter_all_reduce_rmsnorm_group_quant_fp8_fusion_pass_replace[use_triton_quant=True/False] parametrization exercises both call sites against the synthetic graph, and both pass. End-to-end I ran on AMD Instinct (TP=2, public nightly + this PR overlaid):

| Mode | gsm8k flexible-extract | gsm8k strict-match |
| --- | --- | --- |
| Fusion on (fuse_norm_quant=true, fuse_allreduce_rms=true) | 0.960 ± 0.0139 | 0.925 ± 0.0187 |
| Fusion off (fuse_allreduce_rms=false) | 0.960 ± 0.0139 | 0.930 ± 0.0181 |

Identical within stderr -- the matcher consolidation does not change which patterns fire on Qwen3-Next-80B FP8.

vllm bench serve (random ISL=1000, OSL=100, Qwen3-Next-80B-FP8, TP=2) shows the configurations within ~3% of each other across conc=16/32/64, which is at the noise floor for the MoE-dominated forward pass on this image. The AITER fused_ar_rms_per_group_quant kernel doesn't ship in this public-image aiter build, so the new indexer-fan-out fusion correctly self-disables (logged by the existing warning at allreduce_rms_fusion.py:1359) and only the AR+RMS+per-tensor fusion fires for this model. The DSv3.2 numbers for the per-group-quant indexer path remain those reported earlier in this PR's testing log.

@maeehart

Contributor Author

Quick correction on the previous comment: those DSv3.2 numbers I alluded to ("remain those reported earlier in this PR's testing log") are the ones already in the PR description, not numbers I re-measured against the rebased + refactored branch. I should have been explicit -- I did not re-run DSv3.2 in this session.

Reasoning for not re-running:

So how I'd frame the validation now:

  • Qwen3-Next-80B (this session): non-regression sanity check that the rebase + MatcherQuantFP8 consolidation preserves accuracy and the existing per-tensor AR+RMS fusion (gsm8k 0.960/0.925 vs 0.960/0.930). On this model the indexer fan-out pattern is not present and the per-group fusion is gated off, so this only exercises the matcher rewiring, not the new-op replacement.
  • DSv3.2 TP4 MI355X (PR description): the -7.7% TPOT / +9.0% throughput numbers stand. The refactor in 2e1c457923 only deduplicates registration of the indexer-fan-out pattern -- the matched and replaced ops are identical to the pre-refactor branch -- so the DSv3.2 fusion behaviour is structurally unchanged. Those numbers will be re-confirmable on a public image once ROCm/aiter#2823 ("Add fused AR + RMSNorm + per-group FP8 quant: optional bf16 side-output") ships in aiter and a vLLM nightly bumps to it.
maeehart and others added 7 commits May 25, 2026 10:12
DSv3.2 (and other FP8-blockwise models) end every transformer block with

    all_reduce -> [fused_add_]rms_norm -> rocm_aiter_group_fp8_quant -> fp8_gemm

PR vllm-project#41825 fixed the quant-side half of this (`RocmAiterRMSNormQuantFusionPass`)
so a plain `rms_norm -> group_fp8_quant` pair is rewritten as
`rocm_aiter_rmsnorm_fp8_group_quant` (one HIP call). But that pass cannot fire
when the `rms_norm` consumer of the all_reduce has already been absorbed by
`RocmAiterAllReduceFusionPass`: at that point the FX graph just sees an opaque
`rocm_aiter_fused_allreduce_rmsnorm` producing bf16 and the standalone
`rocm_aiter_group_fp8_quant` consumer stays unfused. On a DSv3.2 TP4 decode
trace that leaves ~535us / step of `dynamic_per_group_scaled_quant` launches
(122 calls per step at ~4.4us each).

This change extends `RocmAiterAllReduceFusionPass` with two new patterns that
match the full `AR -> RMSNorm[+add] -> group_fp8_quant` chain and rewrite it
into a single `rocm_aiter_fused_allreduce_rmsnorm_quant_per_group` op backed
by AITER's `fused_ar_rms_per_group_quant` launcher (ROCm/aiter#2823).

Mechanics:

- `vllm/_aiter_ops.py`: add the custom op binding plus the
  `AiterCustomAllreduceProto` member, a fake impl returning the FP8 quant
  tensor + per-group scale (`(M, hidden/group_size)` float32), and a feature-
  probe `has_fused_allreduce_rmsnorm_quant_per_group` so callers can degrade
  cleanly on older aiter builds.

- `vllm/compilation/passes/fusion/allreduce_rms_fusion.py`: two new
  `VllmPatternReplacement` patterns (`AiterAllreduceFusedRMSNormGroupQuantFP8Pattern`
  and the `fused_add` sibling) wired into `RocmAiterAllReduceFusionPass`.
  Both reuse `MatcherQuantFP8` to be insensitive to whether `quant_fp8` is
  enabled as a custom op (same approach as PR vllm-project#41825). The quant variants
  register before the non-quant variants so the matcher prefers them whenever
  a downstream group quant exists; the existing AR+RMS-only patterns still
  match for the AR sites that lack a trailing quant (e.g. final block).

- The pass keys off `rocm_aiter_ops.has_fused_allreduce_rmsnorm_quant_per_group()`
  so an aiter build without ROCm/aiter#2823 silently falls back to the existing
  AR+RMS-only fusion (correct but slower).
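For the fake impl mentioned in the first bullet, the shape bookkeeping is simple enough to sketch. The helper below is hypothetical and covers only the two outputs the bullet names (the FP8 tensor and the per-group scale); the default `group_size=128` is taken from elsewhere in this PR, and the divisibility check is an assumption of this sketch.

```python
def fused_quant_output_shapes(num_tokens, hidden, group_size=128):
    """Shapes a fake (meta) impl would report; no data is computed."""
    if hidden % group_size:
        raise ValueError("hidden must be a multiple of group_size")
    return {
        "quant_fp8": (num_tokens, hidden),                  # same shape as the input
        "scale_fp32": (num_tokens, hidden // group_size),   # one scale per group
    }


# Example sizes chosen for illustration only.
print(fused_quant_output_shapes(num_tokens=4, hidden=7168))
```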

This is the AR-side analogue of PR vllm-project#41825 and the ROCm port of the
flashinfer `AllReduceFusedRMSNormStaticQuantFP8Pattern` family that already
exists in the same file for the NVIDIA path.

Validation pending: needs DSv3.2 TP4 bf16-rms-bf16-input FP8-blockwise smoke
on MI355X to confirm pattern count and end-to-end serving accuracy parity.

Co-authored-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
DeepSeek V3.2 ends every transformer block but the first with

    all_reduce -> fused_add_rms_norm -> {
        per-token-group FP8 quant -> fp8 block-scaled gemm   (fused_qkv_a_proj),
        rocm_unquantized_gemm (bf16)                          (indexer wk_weights_proj),
    }

The bf16 indexer GEMM is the blocker: the existing
`AiterAllreduceFusedAddRMSNormGroupQuantFP8Pattern` only matches when the
RMSNorm output is consumed by a single group-quant node, so the trailing
quant kernel survives as a standalone `triton_per_token_group_quant_fp8`
launch (~535 us / decode step on DSv3.2 MI355X TP4).

This change closes that gap.

Mechanics:

- `vllm/_aiter_ops.py`: register
  `rocm_aiter_fused_allreduce_rmsnorm_quant_per_group_with_bf16_norm` which
  wraps the existing AITER `fused_ar_rms_per_group_quant` launcher with
  `emit_bf16=True`. The launcher already supports the four-output variant;
  this just exposes it as a custom op (impl + fake + helper accessor)
  alongside the three-output op added by the prior commit.

- `vllm/compilation/passes/fusion/allreduce_rms_fusion.py`: new
  `AiterAllreduceFusedAddRMSNormGroupQuantWithIndexerPattern`. Matches the
  full `AR -> fused_add_rms_norm.default -> (triton_per_token_group_quant_fp8
  + rocm_unquantized_gemm)` subgraph from the actual DSv3.2 post-grad FX
  dump (note: `.default`, not `.maybe_inplace`; the AR fusion pass runs
  before `VllmIRInplaceFunctionalizationPass`).

  The RMSNorm output is also used outside the matched subgraph (it is a
  graph output that feeds the next compiled chunk as the residual carry),
  so the pattern returns it as a fifth output and the replacement emits
  the fused op's `bf16_norm` output in the same position. Pattern
  matcher's substitution then rewires all external uses to the emitted
  bf16 tensor automatically.

  Registered before the existing AR+RMS+QUANT and AR+RMS-only patterns
  so the larger subgraph is preferred when the indexer fan-out is
  present.

Validation (DSv3.2 TP4 MI355X, vllm-openai-rocm:nightly-32b7177909):

  FX dump (`__compiled_fn_1.post_grad.1.rocm_aiter_allreduce_fusion_pass.
  after.1.py`):
    standalone triton_per_token_group_quant_fp8: 4 -> 0
    fallback rocm_aiter_fused_allreduce_rmsnorm.default: 3 -> 0
    rocm_aiter_fused_allreduce_rmsnorm_quant_per_group_with_bf16_norm: 0 -> 5
  Same picture in chunk 2 (4 -> 0 stranded quants, +5 new fused ops).
  Chunks 0 and 3 unchanged (no `fused_add + quant + indexer` site).

  Serving perf (`vllm bench serve`, ISL=1000, OSL=100, 3-4 seeds each):
                     a1-baseline    a1-patched (prior)   this commit
    mc=4  np=32     17.69 ms TPOT   17.93 ms TPOT        16.96 ms TPOT
    mc=8  np=64     21.42 ms TPOT   20.60 ms TPOT        20.80 ms TPOT
    mc=16 np=64     28.04 ms TPOT   28.81 ms TPOT        26.59 ms TPOT

    mc=16 throughput: 4644 -> 4537 -> 4945 tok/s
  i.e. -7.7% TPOT / +9.0% throughput at mc=16 vs the prior commit.

  End-to-end output remains coherent (capital of France, simple arithmetic).

Co-authored-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
The two AR+RMSNorm+FP8 quant fusion commits introduced a handful of
unnecessary parenthesized line wraps that `ruff format` collapses back
into single-line form. No behavior change.

Co-authored-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Per gemini-code-assist suggestion on vllm-project#42864:

The sibling `AiterAllreduceFusedAddRMSNormGroupQuantFP8Pattern` uses
`MatcherQuantFP8` and matches the `rocm_aiter_group_fp8_quant` form of
the FP8 group-quant op. The original `WithIndexer` pattern only matched
the `triton_per_token_group_quant_fp8` form, which is what DSv3.2's
`fused_qkv_a_proj` actually emits today (verified by FX dump).

Empirically: replacing the direct triton-op call with `MatcherQuantFP8`
takes the indexer fusion from 5 matches per chunk to 0 -- the matcher
only registers one form per instance, and `match_rocm_aiter=True`
picks the AITER form. So a single-variant pattern can't cover both.

Solution: keep the Triton variant (which catches today's DSv3.2 sites)
and additionally register an AITER-form variant via `MatcherQuantFP8`
for sites that route through `vllm.rocm_aiter_group_fp8_quant`. Both
lower to the same 4-output fused op + standalone indexer GEMM, so the
replacement is identical.

`use_triton_quant: bool` toggles between the two; registered with both
values in `RocmAiterAllReduceFusionPass`.
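
As a rough illustration of why two registrations are needed, the sketch below models a pattern that can match only one quant-op form per instance and registers it once per flag value. `IndexerFusionPattern` and `FusionPass` are hypothetical stand-ins written for this example; they are not the vLLM classes and do not use the real pattern-matcher API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class IndexerFusionPattern:
    """Hypothetical stand-in for the WithIndexer pattern class."""

    use_triton_quant: bool

    @property
    def quant_op(self) -> str:
        # Each instance matches exactly one producer form of the quant op.
        if self.use_triton_quant:
            return "triton_per_token_group_quant_fp8"
        return "rocm_aiter_group_fp8_quant"


class FusionPass:
    """Hypothetical pass that keeps a list of registered patterns."""

    def __init__(self) -> None:
        self.patterns: list[IndexerFusionPattern] = []

    def register(self, pattern: IndexerFusionPattern) -> None:
        self.patterns.append(pattern)

    def matches(self, op_name: str) -> bool:
        return any(p.quant_op == op_name for p in self.patterns)


fusion_pass = FusionPass()
# Register the same pattern once per flag value so both forms are covered.
for use_triton_quant in (True, False):
    fusion_pass.register(IndexerFusionPattern(use_triton_quant))
```

With only one of the two registrations, `matches` would return `False` for the other op name, which mirrors the 5-to-0 match drop described above.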

Validation on the same DSv3.2 TP4 MI355X rig:
  Chunks 1 + 2 still each show 0 standalone triton quants and 5 new
  `..._with_bf16_norm` fused ops (same as the single-variant version).
  Bench at mc=4 np=32 (n=6 seeds): 17.19 +/- 0.36 ms TPOT,
  2154 tok/s -- within noise of the single-variant 16.96 +/- 0.33 ms,
  2161 tok/s. The AITER variant adds zero runtime cost on this model
  because it doesn't fire; it's safety for models that route through
  AITER's group-quant path.

Co-authored-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>

- Trim AiterAllreduceFusedAddRMSNormGroupQuantWithIndexerPattern docstring
  to match the sibling AR+RMS+QUANT patterns' style: keep the why, drop
  the FX-form excerpt and the cross-segment caveat that belongs in a
  passes-level note. Two-variant rationale stays in the docstring.

- Collapse the two near-identical register(...) calls in
  RocmAiterAllReduceFusionPass into a small loop over the boolean flag;
  delete the inline comment that duplicates the docstring.

No behavior change.

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>

…erns

Adds ``test_rocm_aiter_all_reduce_rmsnorm_group_quant_fp8_fusion_pass_replace``
plus a ``TestAiterAllReduceRMSNormGroupQuantFP8Model`` to validate the three
new ``VllmPatternReplacement`` classes this PR registers on
``RocmAiterAllReduceFusionPass``:

* ``AiterAllreduceFusedRMSNormGroupQuantFP8Pattern`` (no-residual)
* ``AiterAllreduceFusedAddRMSNormGroupQuantFP8Pattern`` (with-residual,
  single ``rms`` consumer)
* ``AiterAllreduceFusedAddRMSNormGroupQuantWithIndexerPattern`` (with-
  residual, DSv3.2 indexer fan-out -- parametrized over both
  ``triton_per_token_group_quant_fp8`` and ``rocm_aiter_group_fp8_quant``
  producers, matching the two registrations added in commit 4)

The model uses four ``rms_norm`` sites that together exercise all three
patterns (matched_count == 4), and chains the FP8 quant output back into
the next AllReduce by manually dequantizing through the per-group scale
so we do not depend on a real FP8 block-scaled GEMM kernel.
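
The quantize-then-dequantize round trip can be sketched in plain Python. This is an emulation under stated assumptions, not the test model's code: it assumes the OCP e4m3 format (largest finite value 448), approximates FP8 rounding by keeping 4 significant bits, ignores subnormals, and uses a group size of 4 purely for readability. All function names are illustrative.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite value, assuming the OCP e4m3 format


def round_to_e4m3(x: float) -> float:
    """Approximate e4m3 rounding: keep 4 significant bits, then clamp.

    Subnormals are ignored in this emulation.
    """
    if x == 0.0:
        return 0.0
    # x == mantissa * 2**exponent with 0.5 <= |mantissa| < 1
    mantissa, exponent = math.frexp(x)
    rounded = math.ldexp(round(mantissa * 16) / 16, exponent)
    return max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, rounded))


def group_quant(row: list[float], group_size: int):
    """Quantize one token row per group; return (values, per-group scales)."""
    quantized, scales = [], []
    for start in range(0, len(row), group_size):
        group = row[start:start + group_size]
        amax = max(abs(v) for v in group)
        scale = amax / FP8_E4M3_MAX if amax > 0 else 1.0
        scales.append(scale)
        quantized.extend(round_to_e4m3(v / scale) for v in group)
    return quantized, scales


def group_dequant(quantized: list[float], scales: list[float], group_size: int):
    """Multiply each value by the scale of the group it belongs to."""
    return [q * scales[i // group_size] for i, q in enumerate(quantized)]


row = [0.01 * (i - 8) for i in range(16)]
q, s = group_quant(row, group_size=4)
restored = group_dequant(q, s, group_size=4)
```

The restored row is a higher-precision tensor again, so it can feed the next all-reduce without an FP8 GEMM in between, at the cost of the rounding error introduced by the quant step.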

Mirrors the structure and assertions of the sibling
``TestAllReduceRMSNormStaticQuantFP8Model`` test as Rohan138 requested
in review.

Validated on 2-GPU MI355X (TP=2, ``smci355-1b112-b13-19``) with the
PR 42864 docker image (aiter v0.1.14-rc0 + ROCm/aiter#2823 wrapper
patches via aiterPR3075): all 4 (``enable_rms_norm_custom_op`` x
``use_triton_quant``) combos PASS in ~32s.

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>

The AITER fused all-reduce kernels (fused_ar_rms and
fused_ar_rms_per_group_quant) take a single opaque pointer for the
norm weight and require its element dtype to equal the activation
dtype. Models such as Qwen3-Next register the RMSNorm weight as
float32 even when the activations are bf16, so the existing
replacements -- both the AR+RMS-only patterns and the three new
AR+RMS+per-group-FP8-quant patterns from this PR -- handed the kernel
a mismatched (fp32 weight, bf16 input) pair that silently corrupted
the normed activations. The downstream effect is poor GSM8K / MTP
accuracy on Qwen3-Next TP2/TP4; TP1 was unaffected only because the
AR+RMS fusion never fires in that case.

Tres Popp identified the root cause and fixed the two existing
patterns in `tpopp/aiter-ar-rmsnorm-weight-dtype`. This commit:

* Picks up Tres's two-line `weight.to(input.dtype)` fix in
  `AiterAllreduceFusedRMSNormPattern` and
  `AiterAllreduceFusedAddRMSNormPattern`, and
* Extends the same cast defensively to the three new AR+RMS+per-group-
  FP8-quant patterns introduced earlier in this PR
  (`AiterAllreduceFusedRMSNormGroupQuantFP8Pattern`,
  `AiterAllreduceFusedAddRMSNormGroupQuantFP8Pattern`, and
  `AiterAllreduceFusedAddRMSNormGroupQuantWithIndexerPattern`), all
  of which forward weight straight into the same
  `aiter_ar.fused_ar_rms_per_group_quant` kernel and thus inherit the
  same dtype constraint.

When weight already matches input.dtype the `.to()` is a no-op (and
constant-folds at compile time for graph-input parameters), so other
models pay nothing.
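
The commit message does not spell out how the mismatch corrupts values. One plausible mechanism, assuming the kernel simply reads the weight buffer as if it held activation-dtype elements, can be shown with the standard library alone. The helpers below are illustrative, and the bf16 conversion truncates instead of rounding to nearest as a real cast would.

```python
import struct


def bf16_bits_to_float(bits: int) -> float:
    """Decode a 16-bit bfloat16 pattern (bf16 is the top half of an fp32)."""
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]


def float_to_bf16_bits(x: float) -> int:
    """Convert to bf16 by dropping the low 16 bits (truncation, for brevity)."""
    return struct.unpack("<I", struct.pack("<f", x))[0] >> 16


def read_as_bf16(buffer: bytes, count: int) -> list[float]:
    """Read `count` bf16 elements from a raw little-endian buffer."""
    raw = struct.unpack(f"<{count}H", buffer[: 2 * count])
    return [bf16_bits_to_float(b) for b in raw]


weight = [1.0, 2.0]

# fp32 storage handed to a reader that assumes bf16 elements.
fp32_buffer = struct.pack("<2f", *weight)
misread = read_as_bf16(fp32_buffer, count=2)

# Same weights cast to bf16 first, then read as bf16.
bf16_buffer = struct.pack("<2H", *(float_to_bf16_bits(w) for w in weight))
correct = read_as_bf16(bf16_buffer, count=2)

print(misread, correct)
# prints "[0.0, 1.0] [1.0, 2.0]"
```

Under that assumption the reader sees the low and high halves of the first fp32 value as two separate weights, which is why casting the weight to the input dtype before the call matters.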

Diagnosis credit: Tres Popp, Nico Holmberg (reproduction +
confirmation on Qwen3-Next-80B-A3B-Instruct-FP8 TP2).

Co-authored-by: Tres Popp <tres.popp@amd.com>
Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>

allreduce = self.FUSED_AR_RMSNORM_OP(
    input_=input,
    residual=residual,
    weight=weight,

@Rohan138 Rohan138 May 28, 2026

Contributor

Why do we need the type casts here? Edit: just saw Tres' commit adding the cast because aiter expects the rmsnorm weight and input to have the same dtype. Can you check if this cast creates an extra kernel, and if that kernel is fused into the preceding one?

Contributor

@tpopp I believe this comment relates to a different PR, so I am going to mark it as resolved.

Comment thread vllm/_aiter_ops.py
return torch.ops.vllm.rocm_aiter_fused_allreduce_rmsnorm_quant_per_group_with_bf16_norm.default # noqa: E501

@classmethod
def has_fused_allreduce_rmsnorm_quant_per_group(cls) -> bool:

Contributor

Can you add a TODO comment to remove this once we bump to aiter 0.1.14 (should be next week)?

Contributor

Added in 2b71459.

… probe

Per Rohan138 on vllm-project#42864: mark the feature probe for removal once vLLM bumps its pinned AITER past v0.1.13.post1 (ROCm/aiter#2823 is expected to ship in v0.1.14).
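
For readers unfamiliar with the idea, a feature probe of this kind can be sketched as a guarded attribute lookup. This is a generic illustration, not the body of `has_fused_allreduce_rmsnorm_quant_per_group`; the helper name `has_module_attr` is hypothetical, and the standard-library `math` module stands in for the optional dependency.

```python
import importlib
import importlib.util


def has_module_attr(module_name: str, attr: str) -> bool:
    """Return True if `module_name` is importable and exposes `attr`."""
    try:
        if importlib.util.find_spec(module_name) is None:
            return False
    except ModuleNotFoundError:
        # Raised for dotted names whose parent package is missing.
        return False
    module = importlib.import_module(module_name)
    return hasattr(module, attr)


# TODO: remove the probe once the pinned dependency always ships the symbol.
HAS_SQRT = has_module_attr("math", "sqrt")            # stand-in: present
HAS_MISSING = has_module_attr("math", "no_such_fn")   # stand-in: absent
```

Callers can then gate pattern registration on the probe result, and the TODO marks the check as temporary until the version bump makes it always true.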

Signed-off-by: Frida Andersson <fanderss@amd.com>
@zou3519 zou3519 removed their request for review June 2, 2026 14:37

@dllehr-amd dllehr-amd left a comment

Collaborator

Thanks @maeehart and @frida-andersson! I appreciate the work on this one!

@tjtanaa tjtanaa left a comment

Member

LGTM as well.

@tjtanaa tjtanaa enabled auto-merge (squash) June 9, 2026 07:55
@tjtanaa tjtanaa merged commit 80e2c44 into vllm-project:main Jun 9, 2026
67 checks passed
@github-project-automation github-project-automation Bot moved this from Todo to Done in AMD Jun 9, 2026
ekagra-ranjan pushed a commit to ekagra-ranjan/vllm that referenced this pull request Jun 9, 2026
…exer fan-out) (vllm-project#42864)

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Frida Andersson <fanderss@amd.com>
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
waqahmed-amd-fi pushed a commit to waqahmed-amd-fi/vllm that referenced this pull request Jun 10, 2026
…exer fan-out) (vllm-project#42864)

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Frida Andersson <fanderss@amd.com>
Signed-off-by: Waqar Ahmed <waqar.ahmed@amd.com>
Saddss pushed a commit to Saddss/vllm that referenced this pull request Jun 14, 2026
…exer fan-out) (vllm-project#42864)

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Frida Andersson <fanderss@amd.com>
vivek8123 pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Jun 18, 2026
…exer fan-out) (vllm-project#42864)

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Frida Andersson <fanderss@amd.com>
divineearthly pushed a commit to divineearthly/vllm that referenced this pull request Jun 19, 2026
…exer fan-out) (vllm-project#42864)

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Frida Andersson <fanderss@amd.com>
Signed-off-by: divineearthly <divineearthly@gmail.com>
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Jun 22, 2026
…exer fan-out) (vllm-project#42864)

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Frida Andersson <fanderss@amd.com>
nkzhenhua pushed a commit to nkzhenhua/vllm that referenced this pull request Jun 24, 2026
…exer fan-out) (vllm-project#42864)

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Frida Andersson <fanderss@amd.com>
ohsono pushed a commit to ohsono/vllm that referenced this pull request Jul 3, 2026
…exer fan-out) (vllm-project#42864)

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Frida Andersson <fanderss@amd.com>

Labels

ready: ONLY add when PR is ready to merge/full CI is needed
rocm: Related to AMD ROCm

8 participants