[Core][Model] Gemma4: Unified FA4 for all layers + FlashAttention mm_prefix support#42175
Code Review
This pull request introduces support for bidirectional attention within multimodal token ranges for FlashAttention, primarily to support Gemma4 models. It updates the configuration to handle heterogeneous head dimensions and implements a correction mechanism in the forward pass to merge causal and bidirectional attention results. However, the current implementation has significant performance bottlenecks due to synchronous CPU-GPU transfers and nested loops in the forward pass. Additionally, the decomposition logic for merging attention states is mathematically incorrect, leading to double-counting of tokens and incorrect KV range indexing.
Hi @lucianommartins, the pre-commit checks have failed. Please run:
uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files
Then, commit the changes and push to your branch.
fyi @ywang96 @Isotr0py - a known limitation: mm_prefix attention overlap approximation.
The two-call decomposition for PrefixLM bidirectional attention produces an approximation, not an exact result. The causal call covers KV positions [0, p] for a query at position p, while the non-causal correction covers [r_start, r_end]. For query tokens inside an mm_prefix range where p > r_start, these ranges overlap in [r_start, p]. The LSE-based merge (
Impact scope:
Correct fix: FA4's
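The overlap described in this comment can be reproduced with a toy single-query example. This is a minimal sketch in plain Python, not the PR's kernels: `attend` and `merge` are hypothetical helpers written for illustration, and the scores and values are made up. It merges two partial softmax results by their log-sum-exp, once with the overlapping key ranges [0, p] and [r_start, r_end], and once with the disjoint ranges [0, p] and [p+1, r_end].

```python
import math

def attend(scores, values):
    """Softmax attention for one query; returns (output, log-sum-exp)."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    out = sum(w * v for w, v in zip(weights, values)) / total
    return out, m + math.log(total)

def merge(out_a, lse_a, out_b, lse_b):
    """Combine two partial attention results, weighting each by exp(lse)."""
    m = max(lse_a, lse_b)
    wa = math.exp(lse_a - m)
    wb = math.exp(lse_b - m)
    return (wa * out_a + wb * out_b) / (wa + wb)

# Made-up scores and scalar values for six KV positions.
scores = [0.1, 0.7, -0.3, 0.5, 0.2, 0.9]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
p, r_start, r_end = 3, 2, 5  # query position and inclusive bidirectional range

def subset(lo, hi):
    """Scores and values for the inclusive KV range [lo, hi]."""
    return scores[lo:hi + 1], values[lo:hi + 1]

# Reference: causal OR bidirectional means the query sees keys [0, r_end].
exact, _ = attend(*subset(0, r_end))

causal = attend(*subset(0, p))
overlapping = merge(*causal, *attend(*subset(r_start, r_end)))
disjoint = merge(*causal, *attend(*subset(p + 1, r_end)))

print(f"exact={exact:.6f} overlapping={overlapping:.6f} disjoint={disjoint:.6f}")
```

With disjoint ranges the merge reproduces attention over [0, r_end]; with overlapping ranges the keys in [r_start, p] contribute twice to both the numerator and the normalizer, so the result drifts from the reference.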
Force-pushed from b12a2b7 to 042d5ab.
Documentation preview: https://vllm--42175.org.readthedocs.build/en/42175/
def _precompute_mm_prefix_indices(
    self,
    metadata: "FlashAttentionMetadata",  # type: ignore[name-defined]
    req_doc_ranges: dict[int, list[tuple[int, int]]],
) -> None:
This is too FA-specific and should not be placed in the model runner.
Agreed; I'm not super familiar with MM attention backends, but it seems like we should have a common interface for Triton and FA and move this code there.
BTW, which
hey @Isotr0py - it is the
But I think the latest FA4 entrypoint should have removed this limitation: https://github.com/Dao-AILab/flash-attention/blob/ab66326aaa4fe3529fbc00f3156f3a762dd3141b/flash_attn/cute/interface.py#L588-L614 Perhaps we should update our FA fork? cc @LucasWilkinson
I was under the impression that FA4 does not yet support headdim 512 on Blackwell: Dao-AILab/flash-attention#2456. SM90 support landed with vllm-project/flash-attention#130; presumably that's what you used for these benchmarks? Edit: ah yeah, I see they were run on H100. You'll need to update this PR so that it only attempts FA4 for this head size on SM90.
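The gating requested here could look roughly like the sketch below. It is an illustration only: `should_try_fa4` is a hypothetical helper, not a vLLM function, and the thresholds are assumptions pieced together from this thread (FA4 handles head sizes up to 512, head sizes above 256 are only attempted on SM90, and smaller head sizes are attempted on Hopper and newer).

```python
def should_try_fa4(head_size: int, capability: tuple[int, int]) -> bool:
    """Hypothetical gate: decide whether to attempt FA4 for a head size.

    `capability` is a (major, minor) CUDA compute capability, e.g. (9, 0)
    for H100. All thresholds here are assumptions taken from the discussion.
    """
    major, _minor = capability
    if head_size > 512:
        # Beyond the largest head size discussed for FA4.
        return False
    if head_size > 256:
        # Large heads (e.g. Gemma4 global_head_dim=512): SM90 only for now.
        return major == 9
    # Smaller heads: Hopper (SM90) and newer.
    return major >= 9


for hs, cap in [(256, (9, 0)), (512, (9, 0)), (512, (10, 0)), (256, (8, 0))]:
    print(hs, cap, should_try_fa4(hs, cap))
```

Keeping the decision in one pure function makes it easy to revisit once large-head support lands on other architectures.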
hey @Isotr0py @LucasWilkinson - have you folks had a chance to look into it?
@MatthewBonanni Can we have Dao-AILab/flash-attention#2224 in our FA fork? I think this PR needs this upstream sync to allow proper bidirectional attention mask computation for Gemma4.
have you had a chance to take a look at it @MatthewBonanni @LucasWilkinson ?
This pull request has merge conflicts that must be resolved before it can be merged.
I've been kinda surprised by the Gemma 4 family's relatively poor perf on vLLM (as compared to Qwen, gpt-oss, etc.), even with tuning.
hey @Isotr0py @MatthewBonanni @LucasWilkinson - have you had a chance to ingest Dao-AILab/flash-attention#2224 into vLLM?
Force-pushed from 042d5ab to ee7907c.
…ndling
- Add mm_req_doc_ranges field to CommonAttentionMetadata so PrefixLM bidirectional ranges flow through the standard metadata build path instead of being monkey-patched onto per-layer metadata post-build
- FlashAttentionMetadataBuilder.build(): compute mm_prefix_range_tensor and precomputed correction indices directly from cm.mm_req_doc_ranges
- TritonAttentionMetadataBuilder.build(): compute mm_prefix_range and mm_prefix_range_tensor directly from cm.mm_req_doc_ranges
- FlexAttentionMetadataBuilder.build(): pass cm.mm_req_doc_ranges as mm_prefix_range to FlexAttentionMetadata constructor
- Move req_doc_ranges computation before CommonAttentionMetadata construction in gpu_model_runner so it is available during build()
- Move _precompute_mm_prefix_indices from gpu_model_runner to flash_attn.py as a standalone function called by the builder
- Remove _set_mm_prefix_range_for_metadata and _precompute_mm_prefix_indices from GPUModelRunner (net -17 lines)
Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Move compute_mm_prefix_range_tensor from a static method on TritonAttentionMetadata to a standalone utility in backends/utils.py so both Triton and FlashAttention backends can import it directly without cross-backend dependencies. Also guard supports_mm_prefix behind is_fa_version_supported(4).
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
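To illustrate what a backend-neutral range utility has to do, the sketch below turns the per-request dict[int, list[tuple[int, int]]] mapping into a dense, rectangular table that a kernel could index by request. It is not the actual compute_mm_prefix_range_tensor: the name `pad_doc_ranges`, the nested-list output (the real utility presumably returns a tensor), and the (-1, -1) padding sentinel are all assumptions made for this example.

```python
def pad_doc_ranges(
    req_doc_ranges: dict[int, list[tuple[int, int]]],
    num_reqs: int,
    pad: tuple[int, int] = (-1, -1),
) -> list[list[tuple[int, int]]]:
    """Hypothetical helper: build a [num_reqs][max_ranges] table of ranges.

    Requests with fewer ranges are padded with `pad`. With inclusive ranges
    and non-negative token positions, (-1, -1) never matches any position.
    """
    max_ranges = max((len(r) for r in req_doc_ranges.values()), default=0)
    table = []
    for req_idx in range(num_reqs):
        ranges = list(req_doc_ranges.get(req_idx, []))
        ranges += [pad] * (max_ranges - len(ranges))
        table.append(ranges)
    return table


table = pad_doc_ranges({0: [(2, 5)], 2: [(0, 3), (8, 9)]}, num_reqs=3)
print(table)
```

A rectangular layout lets every backend read the same structure, which is the point of moving the utility out of the Triton metadata class.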
Use 64×64 tiles for compute_block_sparsity to match the smallest FA4 kernel tile config across Gemma4 layer types (128×80 for hdim=256, 64×64 for hdim=512 on SM90). The previous 128×128 was conservative-safe but reduced block culling effectiveness.
Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
…uards
- Allow FA4 upgrade when backend is explicitly set to FLASH_ATTN, not only when auto-selected (config.py)
- Add use_mm_prefix param to supports_combination and validate that mm_prefix requires FA4 for the given head_size (flash_attn.py + base class + all MLA overrides)
- Add NotImplementedError guards for mask_mod/aux_tensors/block_sparse in FA2 and FA3 paths (flash_attn_interface.py)
Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
The compute_block_sparsity call was using hardcoded tile sizes that don't match FA4's kernel tile selection on all architectures. On SM90 with head_dim > 128, the kernel picks tile_n < 128, causing a ValueError in normalize_block_sparse_config.
Extract compute_block_sparsity into fa_utils.py with arch-dependent tile selection that calls FA4's _tile_size_fwd_sm90 for SM90 and uses the correct defaults for SM80/SM100/SM120.
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
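Why the tile size matters for block culling can be seen with a brute-force toy. The sketch below is plain Python and unrelated to the real compute_block_sparsity; `allowed` and `skippable_area` are hypothetical helpers, the tiles are square for simplicity (the real kernel tiles need not be), and the sequence length and range are made up. It counts how many (query, key) positions fall inside tiles that are entirely masked and could therefore be skipped.

```python
def allowed(q, k, ranges):
    """Causal, or bidirectional inside any inclusive (start, end) range."""
    return k <= q or any(s <= q <= e and s <= k <= e for s, e in ranges)

def skippable_area(seq_len, tile, ranges):
    """Count (q, k) positions covered by tiles with no allowed entry."""
    area = 0
    for q0 in range(0, seq_len, tile):
        for k0 in range(0, seq_len, tile):
            qs = range(q0, min(q0 + tile, seq_len))
            ks = range(k0, min(k0 + tile, seq_len))
            if not any(allowed(q, k, ranges) for q in qs for k in ks):
                area += len(qs) * len(ks)
    return area

ranges = [(100, 180)]  # one made-up bidirectional (multimodal) range
coarse = skippable_area(256, 128, ranges)
fine = skippable_area(256, 64, ranges)
print(f"128x128 tiles skip {coarse} positions, 64x64 tiles skip {fine}")
```

Because every fully masked 128×128 tile is made of fully masked 64×64 tiles, the finer grid never skips less, and it usually skips more near the causal diagonal and around bidirectional ranges.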
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
…parse
- Fix CuTE-DSL type mismatch in mm_prefix_mask_mod: use TensorSSA annotations and extract scalar via batch_idx[0] before tensor indexing, wrap lookups with scalar_to_ssa() for comparisons (matches upstream mask_mod_definitions.py pattern)
- Remove block_sparse_tensors from mm_prefix path: SM90's produce_block_sparse_loads passes raw TMA closures which are None when paged KV has page_size != tile_n (paged_kv_non_tma=True), causing load_K=None crash in block_sparse_utils.py; SM100 solved this via load_KV partials but SM90 has no equivalent fallback
- Remove compute_block_sparsity wrapper from flash_attn_interface.py
- Remove mm_prefix_block_sparse field from FlashAttentionMetadata
- Remove block_sparse_tensors param from flash_attn_varlen_func and associated FA2/FA3 guards
- mask_mod alone provides correct (causal OR bidirectional) masking; block_sparse is a skip-masked-blocks optimization that requires upstream flash-attention changes for paged KV on SM90
Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
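The "causal OR bidirectional" rule that mask_mod encodes can be written out as a plain predicate. The sketch below is ordinary Python, not the CuTe-DSL mm_prefix_mask_mod from this commit: the function names, the per-request `range_table` argument, and the inclusive range convention are assumptions made for illustration.

```python
def mm_prefix_mask(batch_idx, q_idx, kv_idx, range_table):
    """Return True if the query at q_idx may attend to the key at kv_idx.

    `range_table[batch_idx]` is a list of inclusive (start, end) ranges in
    which attention is bidirectional; everywhere else it is causal.
    """
    if kv_idx <= q_idx:
        return True  # ordinary causal attention
    for start, end in range_table[batch_idx]:
        if start <= q_idx <= end and start <= kv_idx <= end:
            return True  # both tokens sit in the same bidirectional range
    return False

def render(batch_idx, seq_len, range_table):
    """Draw the mask as text: one row per query, 'x' = visible key."""
    return [
        "".join(
            "x" if mm_prefix_mask(batch_idx, q, k, range_table) else "."
            for k in range(seq_len)
        )
        for q in range(seq_len)
    ]

rows = render(0, 6, [[(2, 4)]])
print("\n".join(rows))
```

Queries 2 to 4 see the whole range [2, 4] in a single pass, so no second attention call or merge step is needed for correctness.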
Force-pushed from c22b4e7 to 1822c24.
- Gate supports_mm_prefix() on attention_config.flash_attn_version == 4 instead of is_fa_version_supported(4)
- flash_attn_version is only non-None when explicitly set by a model config (currently only Gemma4) or by the user
- Prevents other PrefixLM models (Gemma3, PaliGemma, Moondream3) from entering the FA4 mask_mod path, which is untested for them
- Those models continue using Triton for mm_prefix as before this PR
- Fixes CI failure: test_gemma3_mm_gguf hitting mask_mod on FA4
Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Force-pushed from 1822c24 to 7817442.
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Merge of PR vllm-project#42175 duplicated FA4 parameters that already existed from the diffusion-staging merge. Remove the second occurrence. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…prefix support (vllm-project#42175) Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Waqar Ahmed <waqar.ahmed@amd.com>
…prefix support (vllm-project#42175) Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
…prefix support (vllm-project#42175) Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: divineearthly <divineearthly@gmail.com>
@lucianommartins Thanks for the great work! When I ran Gemma 4 31B I noticed that it still falls back to the Triton attention backend. I think this is because fa_utils.py falls back to FA2 due to the unsupported head size, and flash_attn.py then rejects FA2, causing the fallback to Triton. Is this expected? Thanks!
Observed Log:
@Davids048 I ran into the same thing. This PR body mentions […] We might need to wait until that PR is merged and pulled into vLLM's FA fork.
Purpose
Enable Flash Attention 4 (FA4) as the default attention backend for Gemma4 models on Hopper (SM90) and Blackwell (SM100+) GPUs, and add multimodal bidirectional attention (mm_prefix / PrefixLM) support to the FlashAttention backend.
Gemma4 uses heterogeneous head dimensions across layers: head_dim=256 for sliding-window attention and global_head_dim=512 for full attention. Previously, the Gemma4Config gate detected this mismatch and forced TRITON_ATTN as the only backend that could handle both sizes. With FA4 supporting head_dim up to 512, this restriction is no longer necessary on FA4-capable hardware.
Problem 1: Mixed FA3+FA4 penalty. When FLASH_ATTN was manually selected, the per-layer FA version dispatch assigned FA3 (the Hopper default) to sliding layers (head_dim=256) and FA4 to full-attention layers (head_dim=512). Benchmarking showed this mixed execution is ~8% slower than uniform FA4 for all layers, because FA4 has benchmarked tile configurations for head_dim<=256 that perform comparably to FA3.
Problem 2: mm_prefix not supported by FlashAttention. Gemma4 (and Gemma3, PaliGemma, Molmo2, etc.) use bidirectional attention for multimodal tokens (use_bidirectional_attention="vision"). FlashAttentionBackend.supports_mm_prefix() returned False, forcing these models to Triton or FlexAttention. This blocked FA4 from being used at longer context lengths where the multimodal validation activates.
Changes
vllm/model_executor/models/config.py (Gemma4Config):
- max_head_dim <= 512: set flash_attn_version=4 for all layers (uniform FA4, no mixed FA3+FA4)
- […] TRITON_ATTN (preserves existing safety behavior)
- […] flash_attn_version override
vllm/v1/attention/backend.py (CommonAttentionMetadata):
- Add mm_req_doc_ranges field so PrefixLM bidirectional ranges flow through the standard metadata build path instead of post-build monkey-patching
vllm/v1/attention/backends/flash_attn.py (FlashAttention backend):
- supports_mm_prefix() -> True
- Add mm_prefix_range_tensor and precomputed correction index fields to FlashAttentionMetadata
- FlashAttentionMetadataBuilder.build(): compute mm_prefix tensors and precomputed indices from cm.mm_req_doc_ranges
- _apply_mm_prefix_correction():
  - […] flash_attn_varlen_func call (produces correct results for text tokens)
  - […] merge_attn_states using LSE rescaling
- _precompute_mm_prefix_indices(): standalone function that computes correction indices on CPU during build() to avoid GPU-tensor .item() calls in the forward pass
- […] mm_prefix_range_tensor is not None)
vllm/v1/attention/backends/triton_attn.py (Triton backend):
- TritonAttentionMetadataBuilder.build(): compute mm_prefix_range and mm_prefix_range_tensor from cm.mm_req_doc_ranges
vllm/v1/attention/backends/flex_attention.py (FlexAttention backend):
- FlexAttentionMetadataBuilder.build(): pass cm.mm_req_doc_ranges as mm_prefix_range to metadata constructor
vllm/v1/worker/gpu_model_runner.py:
- Move req_doc_ranges computation before CommonAttentionMetadata construction and set via mm_req_doc_ranges field
- Remove _set_mm_prefix_range_for_metadata() and _precompute_mm_prefix_indices() methods (replaced by native builder handling)
Impact on other models
- If mm_req_doc_ranges is None, all builders skip mm_prefix handling.
- Gemma4Config changes only activate for Gemma4ForCausalLM and Gemma4ForConditionalGeneration architectures via MODELS_CONFIG_MAP.
Test Plan
Test Result
Functional tests (H100 SXM 80GB)
FA4 vs Triton — Throughput (tokens/sec, higher is better)
Benchmarked on 8x H100 SXM 80GB with TP=2 (31B) and TP=1 (E2B/E4B). All Gemma4 model sizes tested.
FA4 vs Triton — Latency (P50, lower is better)
Key findings