[ROCm][Compile] Fuse AR + RMSNorm + per-group FP8 quant (+ DSv3.2 indexer fan-out)#42864
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces several new fusion patterns to the RocmAiterAllReduceFusionPass that combine AllReduce, RMSNorm, and per-group FP8 quantization into single AITER kernels, specifically targeting performance optimizations for DeepSeek V3.2. It also implements the necessary ROCm AITER operation bindings and registration logic. The review feedback recommends refactoring the AiterAllreduceFusedAddRMSNormGroupQuantWithIndexerPattern to utilize MatcherQuantFP8 for quantization matching, ensuring consistency with other patterns and improving robustness against varying quantization op representations.
| def pattern(self): | ||
| eps = self.epsilon | ||
| gs = self.group_size | ||
|
|
||
| def _pattern( | ||
| residual: torch.Tensor, | ||
| input_: torch.Tensor, | ||
| norm_weight: torch.Tensor, | ||
| indexer_weight: torch.Tensor, | ||
| ) -> tuple[ | ||
| torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | ||
| ]: | ||
| ar_out = tensor_model_parallel_all_reduce(input_) | ||
| rms, res_out = vllm.ir.ops.fused_add_rms_norm( | ||
| ar_out, residual, norm_weight, eps | ||
| ) | ||
| q, s = torch.ops.vllm.triton_per_token_group_quant_fp8(rms, gs) | ||
| idx = torch.ops.vllm.rocm_unquantized_gemm(rms, indexer_weight) | ||
| return q, s, res_out, idx, rms | ||
|
|
||
| return _pattern |
There was a problem hiding this comment.
Use the quant_matcher initialized in __init__ to make the pattern more flexible and consistent with AiterAllreduceFusedAddRMSNormGroupQuantFP8Pattern.
@property
def pattern(self):
eps = self.epsilon
def _pattern(
residual: torch.Tensor,
input_: torch.Tensor,
norm_weight: torch.Tensor,
indexer_weight: torch.Tensor,
) -> tuple[
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
]:
ar_out = tensor_model_parallel_all_reduce(input_)
rms, res_out = vllm.ir.ops.fused_add_rms_norm(
ar_out, residual, norm_weight, eps
)
q, s = self.quant_matcher(rms)
idx = torch.ops.vllm.rocm_unquantized_gemm(rms, indexer_weight)
return q, s, res_out, idx, rms
return _patternThere was a problem hiding this comment.
See the reply on the __init__ thread -- using self.quant_matcher(rms) here empirically takes the fusion from 5 matches per chunk to 0 because the matcher only traces the AITER form and the actual DSv3.2 site emits the Triton form. Addressed by registering both Triton-form and AITER-form variants of the pattern (e4aab1e).
Per gemini-code-assist suggestion on vllm-project#42864: The sibling `AiterAllreduceFusedAddRMSNormGroupQuantFP8Pattern` uses `MatcherQuantFP8` and matches the `rocm_aiter_group_fp8_quant` form of the FP8 group-quant op. The original `WithIndexer` pattern only matched the `triton_per_token_group_quant_fp8` form, which is what DSv3.2's `fused_qkv_a_proj` actually emits today (verified by FX dump). Empirically: replacing the direct triton-op call with `MatcherQuantFP8` takes the indexer fusion from 5 matches per chunk to 0 -- the matcher only registers one form per instance, and `match_rocm_aiter=True` picks the AITER form. So a single-variant pattern can't cover both. Solution: keep the Triton variant (which catches today's DSv3.2 sites) and additionally register an AITER-form variant via `MatcherQuantFP8` for sites that route through `vllm.rocm_aiter_group_fp8_quant`. Both lower to the same 4-output fused op + standalone indexer GEMM, so the replacement is identical. `use_triton_quant: bool` toggles between the two; registered with both values in `RocmAiterAllReduceFusionPass`. Validation on the same DSv3.2 TP4 MI355X rig: Chunks 1 + 2 still each show 0 standalone triton quants and 5 new `..._with_bf16_norm` fused ops (same as the single-variant version). Bench at mc=4 np=32 (n=6 seeds): 17.19 +/- 0.36 ms TPOT, 2154 tok/s -- within noise of the single-variant 16.96 +/- 0.33 ms, 2161 tok/s. The AITER variant adds zero runtime cost on this model because it doesn't fire; it's safety for models that route through AITER's group-quant path. Co-authored-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com>
Per gemini-code-assist suggestion on vllm-project#42864: The sibling `AiterAllreduceFusedAddRMSNormGroupQuantFP8Pattern` uses `MatcherQuantFP8` and matches the `rocm_aiter_group_fp8_quant` form of the FP8 group-quant op. The original `WithIndexer` pattern only matched the `triton_per_token_group_quant_fp8` form, which is what DSv3.2's `fused_qkv_a_proj` actually emits today (verified by FX dump). Empirically: replacing the direct triton-op call with `MatcherQuantFP8` takes the indexer fusion from 5 matches per chunk to 0 -- the matcher only registers one form per instance, and `match_rocm_aiter=True` picks the AITER form. So a single-variant pattern can't cover both. Solution: keep the Triton variant (which catches today's DSv3.2 sites) and additionally register an AITER-form variant via `MatcherQuantFP8` for sites that route through `vllm.rocm_aiter_group_fp8_quant`. Both lower to the same 4-output fused op + standalone indexer GEMM, so the replacement is identical. `use_triton_quant: bool` toggles between the two; registered with both values in `RocmAiterAllReduceFusionPass`. Validation on the same DSv3.2 TP4 MI355X rig: Chunks 1 + 2 still each show 0 standalone triton quants and 5 new `..._with_bf16_norm` fused ops (same as the single-variant version). Bench at mc=4 np=32 (n=6 seeds): 17.19 +/- 0.36 ms TPOT, 2154 tok/s -- within noise of the single-variant 16.96 +/- 0.33 ms, 2161 tok/s. The AITER variant adds zero runtime cost on this model because it doesn't fire; it's safety for models that route through AITER's group-quant path. Co-authored-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
e4aab1e to
036383c
Compare
|
@maeehart Can you add a unit test to https://github.com/vllm-project/vllm/blob/main/tests/compile/passes/distributed/test_fusion_all_reduce.py similar to the existing ones for AR+RMS+FP8 static quant? |
|
Added It validates all three new
The new
Asserts Parametrized over both Validated on 2x MI355X (TP=2) with the PR 42864 docker image + ROCm/aiter#2823 wrapper patches: all 4 combos PASS in ~32 s. (CI green for these on AMD will require the build to ship a |
Picked up Tres's dtype fix (5da2f75) + parity for the 3 new patternsCherry-picked Validated on an MI355X (gfx950) box at TP=2 (the failing config from the reproduction) with Accuracy (lm_eval gsm8k 5-shot, --limit 200, Qwen3-Next-80B-A3B-Instruct-FP8)
Strict-match is bit-identical between fusion-on and fusion-off; flexible-extract is within half a stderr. The previously-broken TP=2 + AR+RMS-on path is now indistinguishable from the AR+RMS-off baseline. Perf (
|
| Conc | Total Tput (AR+RMS ON, with fix) | Total Tput (AR+RMS OFF) | ON vs OFF |
|---|---|---|---|
| 16 | 13773.0 | 13494.2 | +2.1 % |
| 32 | 21685.7 | 21711.1 | -0.1 % |
| 64 | 29987.3 | 30683.1 | -2.3 % |
The cast is weight.to(input.dtype) on a graph-input weight parameter, so inductor constant-folds it at compile time — no runtime cost. The ON-vs-OFF perf delta is in single-run noise.
I'm leaving the indexer-fan-out pattern in place (it was an earlier suspect before the dtype diagnosis came in). It's guarded by rms being returned as a pattern output and was not the proximate cause of the Qwen3-Next regression; on DSv3.2 it still earns its ~535 µs/step.
Per gemini-code-assist suggestion on vllm-project#42864: The sibling `AiterAllreduceFusedAddRMSNormGroupQuantFP8Pattern` uses `MatcherQuantFP8` and matches the `rocm_aiter_group_fp8_quant` form of the FP8 group-quant op. The original `WithIndexer` pattern only matched the `triton_per_token_group_quant_fp8` form, which is what DSv3.2's `fused_qkv_a_proj` actually emits today (verified by FX dump). Empirically: replacing the direct triton-op call with `MatcherQuantFP8` takes the indexer fusion from 5 matches per chunk to 0 -- the matcher only registers one form per instance, and `match_rocm_aiter=True` picks the AITER form. So a single-variant pattern can't cover both. Solution: keep the Triton variant (which catches today's DSv3.2 sites) and additionally register an AITER-form variant via `MatcherQuantFP8` for sites that route through `vllm.rocm_aiter_group_fp8_quant`. Both lower to the same 4-output fused op + standalone indexer GEMM, so the replacement is identical. `use_triton_quant: bool` toggles between the two; registered with both values in `RocmAiterAllReduceFusionPass`. Validation on the same DSv3.2 TP4 MI355X rig: Chunks 1 + 2 still each show 0 standalone triton quants and 5 new `..._with_bf16_norm` fused ops (same as the single-variant version). Bench at mc=4 np=32 (n=6 seeds): 17.19 +/- 0.36 ms TPOT, 2154 tok/s -- within noise of the single-variant 16.96 +/- 0.33 ms, 2161 tok/s. The AITER variant adds zero runtime cost on this model because it doesn't fire; it's safety for models that route through AITER's group-quant path. Co-authored-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
…erQuantFP8 Per @ProExpertProg's review on vllm-project#42864: ``MatcherQuantFP8`` now traces both ``QuantFP8.forward_hip`` and ``forward_native`` paths, so the indexer-fan-out pattern no longer needs to register two variants of itself (one matching ``triton_per_token_group_quant_fp8`` directly and one matching ``rocm_aiter_group_fp8_quant``). Drop the ``use_triton_quant`` constructor flag and the corresponding double-registration in ``RocmAiterAllReduceFusionPass``. The single ``MatcherQuantFP8(group_size=128, match_rocm_aiter=True)`` instance now matches both call sites; ``test_rocm_aiter_all_reduce_rmsnorm_group_quant_fp8_fusion_pass_replace`` already parametrizes over ``use_triton_quant`` on the synthetic model side, so unit-test coverage is unchanged. Co-authored-by: Cursor <cursoragent@cursor.com>
5da2f75 to
2e1c457
Compare
|
@ProExpertProg good call -- I rebased on New commit on top of the existing series (HEAD =
ValidationThe same
Identical within stderr -- the matcher consolidation does not change which patterns fire on Qwen3-Next-80B FP8.
|
|
Quick correction on the previous comment: those DSv3.2 numbers I alluded to ("remain those reported earlier in this PR's testing log") are the ones already in the PR description, not numbers I re-measured against the rebased + refactored branch. I should have been explicit -- I did not re-run DSv3.2 in this session. Reasoning for not re-running:
So how I'd frame the validation now:
|
DSv3.2 (and other FP8-blockwise models) end every transformer block with
all_reduce -> [fused_add_]rms_norm -> rocm_aiter_group_fp8_quant -> fp8_gemm
PR vllm-project#41825 fixed the quant-side half of this (`RocmAiterRMSNormQuantFusionPass`)
so a plain `rms_norm -> group_fp8_quant` pair is rewritten as
`rocm_aiter_rmsnorm_fp8_group_quant` (one HIP call). But that pass cannot fire
when the `rms_norm` consumer of the all_reduce has already been absorbed by
`RocmAiterAllReduceFusionPass`: at that point the FX graph just sees an opaque
`rocm_aiter_fused_allreduce_rmsnorm` producing bf16 and the standalone
`rocm_aiter_group_fp8_quant` consumer stays unfused. On a DSv3.2 TP4 decode
trace that leaves ~535us / step of `dynamic_per_group_scaled_quant` launches
(122 calls per step at ~4.4us each).
This change extends `RocmAiterAllReduceFusionPass` with two new patterns that
match the full `AR -> RMSNorm[+add] -> group_fp8_quant` chain and rewrite it
into a single `rocm_aiter_fused_allreduce_rmsnorm_quant_per_group` op backed
by AITER's `fused_ar_rms_per_group_quant` launcher (ROCm/aiter PR vllm-project#2823).
Mechanics:
- `vllm/_aiter_ops.py`: add the custom op binding plus the
`AiterCustomAllreduceProto` member, a fake impl returning the FP8 quant
tensor + per-group scale (`(M, hidden/group_size)` float32), and a feature-
probe `has_fused_allreduce_rmsnorm_quant_per_group` so callers can degrade
cleanly on older aiter builds.
- `vllm/compilation/passes/fusion/allreduce_rms_fusion.py`: two new
`VllmPatternReplacement` patterns (`AiterAllreduceFusedRMSNormGroupQuantFP8Pattern`
and the `fused_add` sibling) wired into `RocmAiterAllReduceFusionPass`.
Both reuse `MatcherQuantFP8` to be insensitive to whether `quant_fp8` is
enabled as a custom op (same approach as PR vllm-project#41825). The quant variants
register before the non-quant variants so the matcher prefers them whenever
a downstream group quant exists; the existing AR+RMS-only patterns still
match for the AR sites that lack a trailing quant (e.g. final block).
- The pass keys off `rocm_aiter_ops.has_fused_allreduce_rmsnorm_quant_per_group()`
so an aiter build without PR vllm-project#2823 silently falls back to the existing
AR+RMS-only fusion (correct but slower).
This is the AR-side analogue of PR vllm-project#41825 and the ROCm port of the
flashinfer `AllReduceFusedRMSNormStaticQuantFP8Pattern` family that already
exists in the same file for the NVIDIA path.
Validation pending: needs DSv3.2 TP4 bf16-rms-bf16-input FP8-blockwise smoke
on MI355X to confirm pattern count and end-to-end serving accuracy parity.
Co-authored-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
DeepSeek V3.2 ends every transformer block but the first with
all_reduce -> fused_add_rms_norm -> {
per-token-group FP8 quant -> fp8 block-scaled gemm (fused_qkv_a_proj),
rocm_unquantized_gemm (bf16) (indexer wk_weights_proj),
}
The bf16 indexer GEMM is the blocker: the existing
`AiterAllreduceFusedAddRMSNormGroupQuantFP8Pattern` only matches when the
RMSNorm output is consumed by a single group-quant node, so the trailing
quant kernel survives as a standalone `triton_per_token_group_quant_fp8`
launch (~535 us / decode step on DSv3.2 MI355X TP4).
This change closes that gap.
Mechanics:
- `vllm/_aiter_ops.py`: register
`rocm_aiter_fused_allreduce_rmsnorm_quant_per_group_with_bf16_norm` which
wraps the existing AITER `fused_ar_rms_per_group_quant` launcher with
`emit_bf16=True`. The launcher already supports the four-output variant;
this just exposes it as a custom op (impl + fake + helper accessor)
alongside the three-output op added by the prior commit.
- `vllm/compilation/passes/fusion/allreduce_rms_fusion.py`: new
`AiterAllreduceFusedAddRMSNormGroupQuantWithIndexerPattern`. Matches the
full `AR -> fused_add_rms_norm.default -> (triton_per_token_group_quant_fp8
+ rocm_unquantized_gemm)` subgraph from the actual DSv3.2 post-grad FX
dump (note: `.default`, not `.maybe_inplace`; the AR fusion pass runs
before `VllmIRInplaceFunctionalizationPass`).
The RMSNorm output is also used outside the matched subgraph (it is a
graph output that feeds the next compiled chunk as the residual carry),
so the pattern returns it as a fifth output and the replacement emits
the fused op's `bf16_norm` output in the same position. Pattern
matcher's substitution then rewires all external uses to the emitted
bf16 tensor automatically.
Registered before the existing AR+RMS+QUANT and AR+RMS-only patterns
so the larger subgraph is preferred when the indexer fan-out is
present.
Validation (DSv3.2 TP4 MI355X, vllm-openai-rocm:nightly-32b7177909):
FX dump (`__compiled_fn_1.post_grad.1.rocm_aiter_allreduce_fusion_pass.
after.1.py`):
standalone triton_per_token_group_quant_fp8: 4 -> 0
fallback rocm_aiter_fused_allreduce_rmsnorm.default: 3 -> 0
rocm_aiter_fused_allreduce_rmsnorm_quant_per_group_with_bf16_norm: 0 -> 5
Same picture in chunk 2 (4 -> 0 stranded quants, +5 new fused ops).
Chunks 0 and 3 unchanged (no `fused_add + quant + indexer` site).
Serving perf (`vllm bench serve`, ISL=1000, OSL=100, 3-4 seeds each):
a1-baseline a1-patched (prior) this commit
mc=4 np=32 17.69 ms TPOT 17.93 ms TPOT 16.96 ms TPOT
mc=8 np=64 21.42 ms TPOT 20.60 ms TPOT 20.80 ms TPOT
mc=16 np=64 28.04 ms TPOT 28.81 ms TPOT 26.59 ms TPOT
mc=16 throughput: 4644 -> 4537 -> 4945 tok/s
i.e. -7.7% TPOT / +9.0% throughput at mc=16 vs the prior commit.
End-to-end output remains coherent (capital of France, simple arithmetic).
Co-authored-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
The two AR+RMSNorm+FP8 quant fusion commits introduced a handful of unnecessary parenthesized line wraps that `ruff format` collapses back into single-line form. No behavior change. Co-authored-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Per gemini-code-assist suggestion on vllm-project#42864: The sibling `AiterAllreduceFusedAddRMSNormGroupQuantFP8Pattern` uses `MatcherQuantFP8` and matches the `rocm_aiter_group_fp8_quant` form of the FP8 group-quant op. The original `WithIndexer` pattern only matched the `triton_per_token_group_quant_fp8` form, which is what DSv3.2's `fused_qkv_a_proj` actually emits today (verified by FX dump). Empirically: replacing the direct triton-op call with `MatcherQuantFP8` takes the indexer fusion from 5 matches per chunk to 0 -- the matcher only registers one form per instance, and `match_rocm_aiter=True` picks the AITER form. So a single-variant pattern can't cover both. Solution: keep the Triton variant (which catches today's DSv3.2 sites) and additionally register an AITER-form variant via `MatcherQuantFP8` for sites that route through `vllm.rocm_aiter_group_fp8_quant`. Both lower to the same 4-output fused op + standalone indexer GEMM, so the replacement is identical. `use_triton_quant: bool` toggles between the two; registered with both values in `RocmAiterAllReduceFusionPass`. Validation on the same DSv3.2 TP4 MI355X rig: Chunks 1 + 2 still each show 0 standalone triton quants and 5 new `..._with_bf16_norm` fused ops (same as the single-variant version). Bench at mc=4 np=32 (n=6 seeds): 17.19 +/- 0.36 ms TPOT, 2154 tok/s -- within noise of the single-variant 16.96 +/- 0.33 ms, 2161 tok/s. The AITER variant adds zero runtime cost on this model because it doesn't fire; it's safety for models that route through AITER's group-quant path. Co-authored-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
- Trim AiterAllreduceFusedAddRMSNormGroupQuantWithIndexerPattern docstring to match the sibling AR+RMS+QUANT patterns' style: keep the why, drop the FX-form excerpt and the cross-segment caveat that belongs in a passes-level note. Two-variant rationale stays in the docstring. - Collapse the two near-identical register(...) calls in RocmAiterAllReduceFusionPass into a small loop over the boolean flag; delete the inline comment that duplicates the docstring. No behavior change. Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com>
…erns Adds ``test_rocm_aiter_all_reduce_rmsnorm_group_quant_fp8_fusion_pass_replace`` plus a ``TestAiterAllReduceRMSNormGroupQuantFP8Model`` to validate the three new ``VllmPatternReplacement`` classes this PR registers on ``RocmAiterAllReduceFusionPass``: * ``AiterAllreduceFusedRMSNormGroupQuantFP8Pattern`` (no-residual) * ``AiterAllreduceFusedAddRMSNormGroupQuantFP8Pattern`` (with-residual, single ``rms`` consumer) * ``AiterAllreduceFusedAddRMSNormGroupQuantWithIndexerPattern`` (with- residual, DSv3.2 indexer fan-out -- parametrized over both ``triton_per_token_group_quant_fp8`` and ``rocm_aiter_group_fp8_quant`` producers, matching the two registrations added in commit 4) The model uses four ``rms_norm`` sites that together exercise all three patterns (matched_count == 4), and chains the FP8 quant output back into the next AllReduce by manually dequantizing through the per-group scale so we do not depend on a real FP8 block-scaled GEMM kernel. Mirrors the structure and assertions of the sibling ``TestAllReduceRMSNormStaticQuantFP8Model`` test as Rohan138 requested in review. Validated on 2-GPU MI355X (TP=2, ``smci355-1b112-b13-19``) with the PR 42864 docker image (aiter v0.1.14-rc0 + ROCm/aiter#2823 wrapper patches via aiterPR3075): all 4 (``enable_rms_norm_custom_op`` x ``use_triton_quant``) combos PASS in ~32s. Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com>
The AITER fused all-reduce kernels (fused_ar_rms and fused_ar_rms_per_group_quant) take a single opaque pointer for the norm weight and require its element dtype to equal the activation dtype. Models such as Qwen3-Next register the RMSNorm weight as float32 even when the activations are bf16, so the existing replacements -- both the AR+RMS-only patterns and the three new AR+RMS+per-group-FP8-quant patterns from this PR -- handed the kernel a mismatched (fp32 weight, bf16 input) pair that silently corrupted the normed activations. The downstream effect is poor GSM8K / MTP accuracy on Qwen3-Next TP2/TP4; TP1 was unaffected only because the AR+RMS fusion never fires in that case. Tres Popp identified the root cause and fixed the two existing patterns in `tpopp/aiter-ar-rmsnorm-weight-dtype`. This commit: * Picks up Tres's two-line `weight.to(input.dtype)` fix in `AiterAllreduceFusedRMSNormPattern` and `AiterAllreduceFusedAddRMSNormPattern`, and * Extends the same cast defensively to the three new AR+RMS+per-group- FP8-quant patterns introduced earlier in this PR (`AiterAllreduceFusedRMSNormGroupQuantFP8Pattern`, `AiterAllreduceFusedAddRMSNormGroupQuantFP8Pattern`, and `AiterAllreduceFusedAddRMSNormGroupQuantWithIndexerPattern`), all of which forward weight straight into the same `aiter_ar.fused_ar_rms_per_group_quant` kernel and thus inherit the same dtype constraint. When weight already matches input.dtype the `.to()` is a no-op (and constant-folds at compile time for graph-input parameters), so other models pay nothing. Diagnosis credit: Tres Popp, Nico Holmberg (reproduction + confirmation on Qwen3-Next-80B-A3B-Instruct-FP8 TP2). Co-authored-by: Tres Popp <tres.popp@amd.com> Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com>
| allreduce = self.FUSED_AR_RMSNORM_OP( | ||
| input_=input, | ||
| residual=residual, | ||
| weight=weight, |
There was a problem hiding this comment.
Why do we need the type casts here? Edit: just saw Tres' commit adding the cast because aiter expects the rmsnorm weight and input to have the same dtype. Can you check if this cast creates an extra kernel, and if that kernel is fused into the preceding one?
There was a problem hiding this comment.
I believe this comment @tpopp is related to a different PR so I am going to mark it as resolve?
| return torch.ops.vllm.rocm_aiter_fused_allreduce_rmsnorm_quant_per_group_with_bf16_norm.default # noqa: E501 | ||
|
|
||
| @classmethod | ||
| def has_fused_allreduce_rmsnorm_quant_per_group(cls) -> bool: |
There was a problem hiding this comment.
Can you add a TODO comment to remove this once we bump to aiter 0.1.14 (should be next week)?
… probe Per Rohan138 on vllm-project#42864: mark the feature probe for removal once vLLM bumps its pinned AITER past v0.1.13.post1 (ROCm/aiter#2823 is expected to ship in v0.1.14). Signed-off-by: Frida Andersson <fanderss@amd.com>
dllehr-amd
left a comment
There was a problem hiding this comment.
Thanks @maeehart and @frida-andersson ! I appreciate the work on this one!
…exer fan-out) (vllm-project#42864) Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Frida Andersson <fanderss@amd.com> Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
…exer fan-out) (vllm-project#42864) Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Frida Andersson <fanderss@amd.com> Signed-off-by: Waqar Ahmed <waqar.ahmed@amd.com>
…exer fan-out) (vllm-project#42864) Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Frida Andersson <fanderss@amd.com>
…exer fan-out) (vllm-project#42864) Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Frida Andersson <fanderss@amd.com>
…exer fan-out) (vllm-project#42864) Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Frida Andersson <fanderss@amd.com> Signed-off-by: divineearthly <divineearthly@gmail.com>
…exer fan-out) (vllm-project#42864) Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Frida Andersson <fanderss@amd.com>
…exer fan-out) (vllm-project#42864) Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Frida Andersson <fanderss@amd.com>
…exer fan-out) (vllm-project#42864) Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Frida Andersson <fanderss@amd.com>
Summary
Two new patterns in
RocmAiterAllReduceFusionPassthat fuse theall_reduce -> RMSNorm[+add] -> per-group FP8 quant [-> bf16 indexer GEMM]chain into one AITER call, eliminating ~535 us / decode step of standalonetriton_per_token_group_quant_fp8launches on DeepSeek V3.2 MI355X TP4.This is the AR-side analogue of #41825 (which fixed the quant-only half) and the ROCm port of the flashinfer
AllReduceFusedRMSNormStaticQuantFP8Patternfamily already in this file for NVIDIA.Background
DSv3.2 (and other FP8-blockwise models) end every transformer block with
PR #41825 fixed the quant-side half (
RocmAiterRMSNormQuantFusionPass) so a plainrms_norm -> group_fp8_quantpair is rewritten asrocm_aiter_rmsnorm_fp8_group_quant. But that pass cannot fire onceRocmAiterAllReduceFusionPasshas already absorbed therms_normconsumer of the all_reduce: at that point the FX graph just sees an opaquerocm_aiter_fused_allreduce_rmsnormproducing bf16 and the standalonerocm_aiter_group_fp8_quantconsumer stays unfused.On DSv3.2 TP4 MI355X that leaves the standalone per-token-group quant launches at ~535 us / decode step (61 layers x ~8.8 us each).
What this PR does
Commit 1 (
[ROCm][Compile] Fuse trailing per-group FP8 quant into AITER AR+RMSNorm):Adds
rocm_aiter_fused_allreduce_rmsnorm_quant_per_groupcustom op invllm/_aiter_ops.py(wraps Add fused AR + RMSNorm + per-group FP8 quant: optional bf16 side-output ROCm/aiter#2823fused_ar_rms_per_group_quant) with impl, fake impl, andhas_fused_allreduce_rmsnorm_quant_per_group()feature probe.Adds two
VllmPatternReplacementpatterns:AiterAllreduceFusedRMSNormGroupQuantFP8Pattern:AR -> rms_norm -> group_fp8_quant(no residual, first layer).AiterAllreduceFusedAddRMSNormGroupQuantFP8Pattern:AR -> fused_add_rms_norm -> group_fp8_quant(with residual, layers 1-N).Both reuse
MatcherQuantFP8to be insensitive to whetherquant_fp8is enabled as a custom op (same approach as [ROCm][Perf] Fix RMSNorm+Quant fusion for gfx950 (non-fnuz) #41825). Registered before the existing no-quant patterns so the larger subgraph wins.Falls back cleanly to the existing AR+RMS-only fusion on AITER builds without Add fused AR + RMSNorm + per-group FP8 quant: optional bf16 side-output ROCm/aiter#2823.
Commit 2 (
[ROCm][Compile] Fuse AR+RMSNorm+FP8 quant + DSv3.2 indexer fan-out):The above only fires when the RMSNorm output has a single consumer. In DSv3.2 it has two: the FP8 group-quant (for
fused_qkv_a_proj) and a separate bf16rocm_unquantized_gemm(for the indexerwk_weights_proj). With the fan-out present, the pattern matcher falls back to the no-quant fusion and the quant kernel survives.This commit closes that gap:
Adds
rocm_aiter_fused_allreduce_rmsnorm_quant_per_group_with_bf16_norm-- the same AITER launcher withemit_bf16=True, returning a 4-tuple (FP8 quant, residual, per-group scale, bf16 normed activations). The launcher already supported this; this just exposes it as a custom op.Adds
AiterAllreduceFusedAddRMSNormGroupQuantWithIndexerPattern. MatchesAR -> fused_add_rms_norm.default -> (per-group FP8 quant + rocm_unquantized_gemm)directly from the actual DSv3.2 post-grad FX dump (note:.default, not.maybe_inplace; the AR fusion pass runs beforeVllmIRInplaceFunctionalizationPass). The replacement reuses the fused op's emitted bf16 norm output for the indexer GEMM.The RMSNorm output is also a graph output (it feeds the next compiled chunk as the residual carry), so the pattern returns it as a fifth value and the replacement substitutes the fused op's bf16 output in the same position. The pattern matcher then rewires all external uses automatically.
Registered before the single-consumer quant pattern so the larger subgraph is preferred when an indexer GEMM is present.
Commit 3 (
[Format] ruff format AR+RMSNorm fusion files): mechanical -- collapses a few unnecessary parenthesized line wraps in the new code so the files passruff format --check. No behavior change.Commit 4 (
[Review] Also register AITER-form variant of indexer fan-out pattern): the indexer-fan-out pattern is now registered twice via ause_triton_quantflag. The Triton-quant form matches what DSv3.2'sfused_qkv_a_projactually emits at the fan-out site today (viaQuantFP8.forward_hip); theMatcherQuantFP8form matches sites where the producer routes throughrocm_aiter_group_fp8_quant. Without both, FX dumps showed the new pattern firing at only one variant of the call site. Per Gemini reviewer ask for consistency with the sibling AR+RMS+QUANT patterns.Commit 5 (
[Review] Tighten indexer fan-out pattern docs and registration): docstring/comment cleanup only -- trim the new pattern's docstring to match the sibling AR+RMS+QUANT patterns' style and replace the two near-identicalregister(...)calls with a small loop. No behavior change.Validation
DSv3.2 TP4 on MI355X,
vllm/vllm-openai-rocm:nightly-32b7177909, AITER with ROCm/aiter#2823 wrapper patches.FX dump diff at the dominant post-attn site (
__compiled_fn_1.post_grad.1.rocm_aiter_allreduce_fusion_pass.after.1.py):triton_per_token_group_quant_fp8rocm_aiter_fused_allreduce_rmsnorm.default..._quant_per_group_with_bf16_normrocm_unquantized_gemm.default(indexer)bf16_normfrom fused op)Same picture in chunk 2 (4 stranded quants -> 0, +5 new fused ops). Chunks 0 and 3 unchanged because they don't have the
fused_add + quant + indexersite.Serving perf (
vllm bench serve, ISL=1000, OSL=100, 3-4 seeds each). "main" = upstreammainwith no patterns from this PR; "patched, commit 1 only" = the AR+RMS+QUANT half (single-consumer pattern, no indexer fan-out); "patched, full PR" = all five commits.i.e. -7.7% TPOT, +9.0% throughput at mc=16 for the full PR vs main, and the indexer-fan-out commit (commit 2 onwards) is doing the lion's share of the work -- commit 1 alone shifts perf only marginally (and at mc=4/mc=16 not in a useful direction). End-to-end output remains coherent (capital of France, simple arithmetic chain).
Notes / scope
The fusion pass only fires when
rocm_aiter_ops.has_fused_allreduce_rmsnorm_quant_per_group()is true (i.e. AITER ships thefused_ar_rms_per_group_quantlauncher); older AITER builds fall back silently to the existing AR+RMS-only fusion.Chunk 0 of the DSv3.2 compiled graph still has 4 standalone Triton quants -- that's the first-layer (no-
fused_add) indexer fan-out site (~1/61 of the per-step savings). A siblingAiterAllreduceFusedRMSNormGroupQuantWithIndexerPattern(no-Add variant) is a natural follow-up but not in this PR.This PR depends on Add fused AR + RMSNorm + per-group FP8 quant: optional bf16 side-output ROCm/aiter#2823 being available for the Python wrappers; without it the new ops silently fall back at the
has_*probe.Test plan
Made with Cursor