Skip to content

[Core][Model] Gemma4: Unified FA4 for all layers + FlashAttention mm_prefix support#42175

Merged
ywang96 merged 12 commits into
vllm-project:mainfrom
lucianommartins:lucianommartins/gemma4-fa4
Jun 10, 2026
Merged

[Core][Model] Gemma4: Unified FA4 for all layers + FlashAttention mm_prefix support#42175
ywang96 merged 12 commits into
vllm-project:mainfrom
lucianommartins:lucianommartins/gemma4-fa4

Conversation

@lucianommartins

@lucianommartins lucianommartins commented May 9, 2026

Copy link
Copy Markdown
Contributor

Purpose

Enable Flash Attention 4 (FA4) as the default attention backend for Gemma4 models on Hopper (SM90) and Blackwell (SM100+) GPUs, and add multimodal bidirectional attention (mm_prefix / PrefixLM) support to the FlashAttention backend.

Gemma4 uses heterogeneous head dimensions across layers: head_dim=256 for sliding-window attention and global_head_dim=512 for full attention. Previously, the Gemma4Config gate detected this mismatch and forced TRITON_ATTN as the only backend that could handle both sizes. With FA4 supporting head_dim up to 512, this restriction is no longer necessary on FA4-capable hardware.

Problem 1: Mixed FA3+FA4 penalty. When FLASH_ATTN was manually selected, the per-layer FA version dispatch assigned FA3 (the Hopper default) to sliding layers (head_dim=256) and FA4 to full-attention layers (head_dim=512). Benchmarking showed this mixed execution is ~8% slower than uniform FA4 for all layers, because FA4 has benchmarked tile configurations for head_dim<=256 that perform comparably to FA3.

Problem 2: mm_prefix not supported by FlashAttention. Gemma4 (and Gemma3, PaliGemma, Molmo2, etc.) use bidirectional attention for multimodal tokens (use_bidirectional_attention="vision"). The FlashAttentionBackend.supports_mm_prefix() returned False, forcing these models to Triton or FlexAttention. This blocked FA4 from being used at longer context lengths where the multimodal validation activates.

Changes

vllm/model_executor/models/config.py — Gemma4Config:

  • When FA4 is available and max_head_dim <= 512: set flash_attn_version=4 for all layers (uniform FA4, no mixed FA3+FA4)
  • When FA4 is not available: fall back to TRITON_ATTN (preserves existing safety behavior)
  • Respects user-explicit flash_attn_version override

vllm/v1/attention/backend.py — CommonAttentionMetadata:

  • Add mm_req_doc_ranges field so PrefixLM bidirectional ranges flow through the standard metadata build path instead of post-build monkey-patching

vllm/v1/attention/backends/flash_attn.py — FlashAttention backend:

  • Add supports_mm_prefix() -> True
  • Add mm_prefix_range_tensor and precomputed correction index fields to FlashAttentionMetadata
  • FlashAttentionMetadataBuilder.build(): compute mm_prefix tensors and precomputed indices from cm.mm_req_doc_ranges
  • Implement two-call decomposition for mm_prefix correction in _apply_mm_prefix_correction():
    1. Main causal flash_attn_varlen_func call (produces correct results for text tokens)
    2. Non-causal call restricted to mm_prefix ranges (corrects multimodal tokens)
    3. Merge via merge_attn_states using LSE rescaling
  • _precompute_mm_prefix_indices(): standalone function that computes correction indices on CPU during build() to avoid GPU-tensor .item() calls in the forward pass
  • Zero overhead for text-only batches (the correction is gated on mm_prefix_range_tensor is not None)

vllm/v1/attention/backends/triton_attn.py — Triton backend:

  • TritonAttentionMetadataBuilder.build(): compute mm_prefix_range and mm_prefix_range_tensor from cm.mm_req_doc_ranges

vllm/v1/attention/backends/flex_attention.py — FlexAttention backend:

  • FlexAttentionMetadataBuilder.build(): pass cm.mm_req_doc_ranges as mm_prefix_range to metadata constructor

vllm/v1/worker/gpu_model_runner.py:

  • Move req_doc_ranges computation before CommonAttentionMetadata construction and set via mm_req_doc_ranges field
  • Remove _set_mm_prefix_range_for_metadata() and _precompute_mm_prefix_indices() methods (replaced by native builder handling)

Impact on other models

  • PrefixLM models (Gemma3, PaliGemma, Molmo2, Moondream3, Bagel): previously restricted to Triton or FlexAttention for mm_prefix. These models can now use FLASH_ATTN as a backend candidate. The two-call decomposition is model-agnostic and mathematically correct for any PrefixLM model.
  • Non-PrefixLM models: no change. mm_req_doc_ranges is None, all builders skip mm_prefix handling.
  • Non-Gemma4 models: the Gemma4Config changes only activate for Gemma4ForCausalLM and Gemma4ForConditionalGeneration architectures via MODELS_CONFIG_MAP.

Test Plan

# Verify FA4 auto-selection for Gemma4
python -c "
from vllm.config import VllmConfig, ModelConfig
vc = VllmConfig(model_config=ModelConfig(model='google/gemma-4-31B-it', trust_remote_code=True, max_model_len=8192))
assert vc.attention_config.flash_attn_version == 4
assert vc.attention_config.backend is None  # auto-selects FLASH_ATTN
print('PASS: FA4 auto-selected')
"

# Gemma4 image with Triton (verify Triton builder mm_prefix path)
python -c "
from vllm import LLM, SamplingParams
from PIL import Image
from transformers import AutoTokenizer
model = 'google/gemma-4-E2B-it'
tokenizer = AutoTokenizer.from_pretrained(model)
llm = LLM(model=model, dtype='auto', max_model_len=2048, trust_remote_code=True,
          gpu_memory_utilization=0.95, attention_backend='TRITON_ATTN',
          limit_mm_per_prompt={'image': 1})
img = Image.open('images/cat.jpg').convert('RGB')
messages = [{'role': 'user', 'content': [{'type': 'image'},
            {'type': 'text', 'text': 'Describe this image briefly.'}]}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
out = llm.generate([{'prompt': prompt, 'multi_modal_data': {'image': img}}],
                   SamplingParams(max_tokens=64, temperature=0))
print('OUTPUT:', out[0].outputs[0].text[:200])
"

Test Result

Functional tests (H100 SXM 80GB)

Test Backend Result
Gemma4 E2B text (4 tests) FA4 PASS
Gemma4 E2B image (4 tests) FA4 PASS
Gemma4 E2B image Triton PASS

FA4 vs Triton — Throughput (tokens/sec, higher is better)

Benchmarked on 8x H100 SXM 80GB with TP=2 (31B) and TP=1 (E2B/E4B). All Gemma4 model sizes tested.

Benchmark E2B (~2B) E4B (~4B) 26B-A4B (MoE) 31B
Prefill (4K input) +40% +24% +25% +33%
Long context (8K) +41% +28% +27% +28%
Very long context (15-16K) +70% +47% +36%
Mixed (1K/1K) +5% +5% +1% +4%
High batch (256 in/out) +2% +5% +1% +5%

FA4 vs Triton — Latency (P50, lower is better)

Benchmark E2B E4B 26B-A4B 31B
Prefill TTFT -31% -22% -24% -22%
Long ctx decode (8K) -27% -19% -20% -16%
Decode b=1 +4% -1% +3% +1%
Decode b=8 +1% +1% +1% -1%

Key findings

  • FA4 wins all throughput scenarios across all 4 model sizes
  • Prefill/long-context improvement: +25-70% throughput, 22-31% faster TTFT
  • Short-context decode: neutral (~+/-2%) — weight-loading dominated per Amdahl's law
  • Uniform FA4 beats mixed FA3+FA4 by ~8% (kernel path uniformity)
  • mm_prefix two-call decomposition adds zero overhead for text-only requests; ~3% for multimodal batches (kernel launch overhead only)

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify mergify Bot added the v1 label May 9, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for bidirectional attention within multimodal token ranges for FlashAttention, primarily to support Gemma4 models. It updates the configuration to handle heterogeneous head dimensions and implements a correction mechanism in the forward pass to merge causal and bidirectional attention results. However, the current implementation has significant performance bottlenecks due to synchronous CPU-GPU transfers and nested loops in the forward pass. Additionally, the decomposition logic for merging attention states is mathematically incorrect, leading to double-counting of tokens and incorrect KV range indexing.

Comment thread vllm/v1/attention/backends/flash_attn.py Outdated
Comment thread vllm/v1/attention/backends/flash_attn.py Outdated
Comment thread vllm/v1/attention/backends/flash_attn.py Outdated
@mergify

mergify Bot commented May 9, 2026

Copy link
Copy Markdown
Contributor

Hi @lucianommartins, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
@lucianommartins

Copy link
Copy Markdown
Contributor Author

fyi @ywang96 @Isotr0py - a known limitation: mm_prefix attention overlap approximation

The two-call decomposition for PrefixLM bidirectional attention produces an approximation, not an exact result. The causal call covers KV positions [0, p] for a query at position p, while the non-causal correction covers [r_start, r_end]. For query tokens inside an mm_prefix range where p > r_start, these ranges overlap in [r_start, p]. The LSE-based merge (merge_attn_states) treats the two calls as independent partial results, which over-weights keys in the overlap region compared to the exact (causal OR mm_prefix) mask that the Triton backend computes in a single kernel via compute_kv_seq_mask.

Impact scope:

  • Text-only requests: none (mm_prefix never activates)
  • Multimodal decode: none (query tokens are always past the mm_prefix ranges)
  • Multimodal prefill: overlap grows linearly with position within the range — zero for the first image token, full range for the last. Affects intra-range attention distribution but not text token outputs

Correct fix: FA4's mask_mod callable supports the exact (causal OR mm_prefix) mask, but is currently blocked by interface.py:564-568 (mask_mod with aux_tensors is not yet supported for varlen sequences). Once this upstream limitation is resolved, the two-call decomposition can be replaced with a single mask_mod-based call.

@lucianommartins lucianommartins force-pushed the lucianommartins/gemma4-fa4 branch from b12a2b7 to 042d5ab Compare May 9, 2026 17:03
@mergify

mergify Bot commented May 9, 2026

Copy link
Copy Markdown
Contributor
@mergify mergify Bot added the documentation Improvements or additions to documentation label May 9, 2026
@mergify

mergify Bot commented May 9, 2026

Copy link
Copy Markdown
Contributor

Hi @lucianommartins, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
Comment thread vllm/v1/worker/gpu_model_runner.py Outdated
Comment on lines +6578 to +6582
def _precompute_mm_prefix_indices(
self,
metadata: "FlashAttentionMetadata", # type: ignore[name-defined]
req_doc_ranges: dict[int, list[tuple[int, int]]],
) -> None:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is too FA specific, which should not be placed at model runner.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed; im not super familiar with MM attention backends but seems like we should have common interface for triton and FA and move this code into there

@Isotr0py

Isotr0py commented May 9, 2026

Copy link
Copy Markdown
Member

Correct fix: FA4's mask_mod callable supports the exact (causal OR mm_prefix) mask, but is currently blocked by interface.py:564-568 (mask_mod with aux_tensors is not yet supported for varlen sequences). Once this upstream limitation is resolved, the two-call decomposition can be replaced with a single mask_mod-based call.

BTW, which interface.py are you referring to?

Comment thread vllm/v1/attention/backends/flash_attn.py Outdated
@lucianommartins

Copy link
Copy Markdown
Contributor Author

Correct fix: FA4's mask_mod callable supports the exact (causal OR mm_prefix) mask, but is currently blocked by interface.py:564-568 (mask_mod with aux_tensors is not yet supported for varlen sequences). Once this upstream limitation is resolved, the two-call decomposition can be replaced with a single mask_mod-based call.

BTW, which interface.py are you referring to?

hey @Isotr0py - it is the vllm/vllm_flash_attn/cute/interface.py (the FA4 CuTE-DSL entry point), lines 564-568.

@Isotr0py

Copy link
Copy Markdown
Member

It is the vllm/vllm_flash_attn/cute/interface.py (the FA4 CuTE-DSL entry point), lines 564-568.

But I think the latest FA4 entrypoint should have removed this limitation: https://github.com/Dao-AILab/flash-attention/blob/ab66326aaa4fe3529fbc00f3156f3a762dd3141b/flash_attn/cute/interface.py#L588-L614

Perhaps we should update our FA fork? cc @LucasWilkinson

@MatthewBonanni

MatthewBonanni commented May 15, 2026

Copy link
Copy Markdown
Member

I was under the impression that FA4 does not yet support headdim 512 on blackwell: Dao-AILab/flash-attention#2456

SM90 support was landed with vllm-project/flash-attention#130, presumably that's what you used for these benchmarks?

edit: ah yeah, I see they were run on H100. You'll need to update this PR so that it only attempts FA4 for this head size on SM90

@lucianommartins

Copy link
Copy Markdown
Contributor Author

It is the vllm/vllm_flash_attn/cute/interface.py (the FA4 CuTE-DSL entry point), lines 564-568.

But I think the latest FA4 entrypoint should have removed this limitation: https://github.com/Dao-AILab/flash-attention/blob/ab66326aaa4fe3529fbc00f3156f3a762dd3141b/flash_attn/cute/interface.py#L588-L614

Perhaps we should update our FA fork? cc @LucasWilkinson

hey @Isotr0py @LucasWilkinson - have you folks had a chance to look into it?

@Isotr0py

Isotr0py commented May 20, 2026

Copy link
Copy Markdown
Member

@MatthewBonanni Can we have Dao-AILab/flash-attention#2224 in our FA fork? I think this PR needs this upstream sync to allow proper bidirectional attention mask computation for Gemma4.

@lucianommartins

Copy link
Copy Markdown
Contributor Author

have you had a chance to take a look at it @MatthewBonanni @LucasWilkinson ?

@mergify

mergify Bot commented May 23, 2026

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @lucianommartins.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 23, 2026
@coopslarhette

Copy link
Copy Markdown

I've been kinda surprised by gemma 4 family's relatively poor perf on vllm (as compared to qwen, gpt-oss etc), even with tuning max_num_seq, max_num_batched_tokens to our use case. perhaps this is why? would be really helpful to get this in if so!

@lucianommartins

Copy link
Copy Markdown
Contributor Author

hey @Isotr0py @MatthewBonanni @LucasWilkinson - have you had a chance to ingest Dao-AILab/flash-attention#2224 into vLLM?

lucianommartins and others added 7 commits June 9, 2026 12:02
…ndling

- Add mm_req_doc_ranges field to CommonAttentionMetadata so PrefixLM
  bidirectional ranges flow through the standard metadata build path
  instead of being monkey-patched onto per-layer metadata post-build
- FlashAttentionMetadataBuilder.build(): compute mm_prefix_range_tensor
  and precomputed correction indices directly from cm.mm_req_doc_ranges
- TritonAttentionMetadataBuilder.build(): compute mm_prefix_range and
  mm_prefix_range_tensor directly from cm.mm_req_doc_ranges
- FlexAttentionMetadataBuilder.build(): pass cm.mm_req_doc_ranges as
  mm_prefix_range to FlexAttentionMetadata constructor
- Move req_doc_ranges computation before CommonAttentionMetadata
  construction in gpu_model_runner so it is available during build()
- Move _precompute_mm_prefix_indices from gpu_model_runner to
  flash_attn.py as a standalone function called by the builder
- Remove _set_mm_prefix_range_for_metadata and
  _precompute_mm_prefix_indices from GPUModelRunner (net -17 lines)
)

Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Move compute_mm_prefix_range_tensor from a static method on
TritonAttentionMetadata to a standalone utility in backends/utils.py
so both Triton and FlashAttention backends can import it directly
without cross-backend dependencies.

Also guard supports_mm_prefix behind is_fa_version_supported(4).

Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Use 64x64 tiles for compute_block_sparsity to match the smallest
FA4 kernel tile config across Gemma4 layer types (128×80 for
hdim=256, 64×64 for hdim=512 on SM90). The previous 128×128 was
conservative-safe but reduced block culling effectiveness.
)

Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
…uards

- Allow FA4 upgrade when backend is explicitly set to FLASH_ATTN,
  not only when auto-selected (config.py)
- Add use_mm_prefix param to supports_combination and validate that
  mm_prefix requires FA4 for the given head_size (flash_attn.py + base
  class + all MLA overrides)
- Add NotImplementedError guards for mask_mod/aux_tensors/block_sparse
  in FA2 and FA3 paths (flash_attn_interface.py)
)

Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
The compute_block_sparsity call was using hardcoded tile sizes that
don't match FA4's kernel tile selection on all architectures. On SM90
with head_dim > 128, the kernel picks tile_n < 128, causing a
ValueError in normalize_block_sparse_config.

Extract compute_block_sparsity into fa_utils.py with arch-dependent
tile selection that calls FA4's _tile_size_fwd_sm90 for SM90 and uses
the correct defaults for SM80/SM100/SM120.

Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
…parse

- Fix CuTE-DSL type mismatch in mm_prefix_mask_mod: use TensorSSA
  annotations and extract scalar via batch_idx[0] before tensor
  indexing, wrap lookups with scalar_to_ssa() for comparisons
  (matches upstream mask_mod_definitions.py pattern)
- Remove block_sparse_tensors from mm_prefix path: SM90's
  produce_block_sparse_loads passes raw TMA closures which are None
  when paged KV has page_size != tile_n (paged_kv_non_tma=True),
  causing load_K=None crash in block_sparse_utils.py; SM100 solved
  this via load_KV partials but SM90 has no equivalent fallback
- Remove compute_block_sparsity wrapper from flash_attn_interface.py
- Remove mm_prefix_block_sparse field from FlashAttentionMetadata
- Remove block_sparse_tensors param from flash_attn_varlen_func
  and associated FA2/FA3 guards
- mask_mod alone provides correct (causal OR bidirectional) masking;
  block_sparse is a skip-masked-blocks optimization that requires
  upstream flash-attention changes for paged KV on SM90
)

Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
@lucianommartins lucianommartins force-pushed the lucianommartins/gemma4-fa4 branch from c22b4e7 to 1822c24 Compare June 9, 2026 12:03
- Gate supports_mm_prefix() on attention_config.flash_attn_version == 4
  instead of is_fa_version_supported(4)
- flash_attn_version is only non-None when explicitly set by a model
  config (currently only Gemma4) or by the user
- Prevents other PrefixLM models (Gemma3, PaliGemma, Moondream3) from
  entering the FA4 mask_mod path, which is untested for them
- Those models continue using Triton for mm_prefix as before this PR
- Fixes CI failure: test_gemma3_mm_gguf hitting mask_mod on FA4

Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
@lucianommartins lucianommartins force-pushed the lucianommartins/gemma4-fa4 branch from 1822c24 to 7817442 Compare June 9, 2026 12:05
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
@MatthewBonanni MatthewBonanni enabled auto-merge (squash) June 9, 2026 19:49
dougbtv added a commit to dougbtv/vllm that referenced this pull request Jun 9, 2026
Merge of PR vllm-project#42175 duplicated FA4 parameters that already existed from
the diffusion-staging merge. Remove the second occurrence.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@ywang96 ywang96 disabled auto-merge June 10, 2026 00:36
@ywang96 ywang96 merged commit 6deb05e into vllm-project:main Jun 10, 2026
88 of 92 checks passed
@github-project-automation github-project-automation Bot moved this from Ready to Done in NVIDIA Jun 10, 2026
@lucianommartins lucianommartins deleted the lucianommartins/gemma4-fa4 branch June 10, 2026 00:46
waqahmed-amd-fi pushed a commit to waqahmed-amd-fi/vllm that referenced this pull request Jun 10, 2026
…prefix support (vllm-project#42175)

Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Waqar Ahmed <waqar.ahmed@amd.com>
Saddss pushed a commit to Saddss/vllm that referenced this pull request Jun 14, 2026
…prefix support (vllm-project#42175)

Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
vivek8123 pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Jun 18, 2026
…prefix support (vllm-project#42175)

Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
divineearthly pushed a commit to divineearthly/vllm that referenced this pull request Jun 19, 2026
…prefix support (vllm-project#42175)

Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: divineearthly <divineearthly@gmail.com>
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Jun 22, 2026
…prefix support (vllm-project#42175)

Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
nkzhenhua pushed a commit to nkzhenhua/vllm that referenced this pull request Jun 24, 2026
…prefix support (vllm-project#42175)

Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
@Davids048

Copy link
Copy Markdown

@lucianommartins Thanks for the great work!

When I ran Gemma 4 31B I noticed that it is still falling back to Triton attention backend. I think it is because fa_util.py is falling back to FA2 due to unsupported head size. And at flash_attn.py it is rejecting FA2, causing fallback to Triton.

Is this expected? Thanks!



Observed Log:

INFO 06-25 21:12:48 [api_utils.py:273] non-default args: {'trust_remote_code': True, 'kv_cache_dtype': 'fp8', 'max_model_len': 10240, 'distributed_executor_backend': 'uni', 'gpu_memory_utilization': 0.9, 'max_num_batched_tokens': 311808, 'max_num_seqs': 46, 'disable_log_stats': True, 'limit_mm_per_prompt': {'image': 20}, 'mm_processor_kwargs': {'max_soft_tokens': 280}, 'model': '/mnt/zfs/home/junda.su/models/RedHatAI/gemma-4-31B-it-FP8-block'}
INFO 06-25 21:12:48 [model.py:598] Resolved architecture: Gemma4ForConditionalGeneration
INFO 06-25 21:12:48 [model.py:1725] Using max model len 10240
INFO 06-25 21:12:49 [cache.py:280] Using fp8 data type to store kv cache. It reduces the GPU memory footprint and boosts the performance. Meanwhile, it may cause accuracy drop without a proper scaling factor
INFO 06-25 21:12:49 [scheduler.py:252] Chunked prefill is enabled with max_num_batched_tokens=311808.
INFO 06-25 21:12:49 [config.py:90] Gemma4 model has heterogeneous head dimensions (head_dim=256, global_head_dim=512). Using FA4 for all layers to avoid mixed FA3/FA4 penalty.
INFO 06-25 21:12:49 [vllm.py:1006] Asynchronous scheduling is enabled.
INFO 06-25 21:12:49 [kernel.py:276] Final IR op priority after setting platform defaults: IrOpPriorityConfig(rms_norm=['native'], fused_add_rms_norm=['native'])
WARNING 06-25 21:12:49 [cuda.py:325] Forcing --disable_chunked_mm_input for models with multimodal-bidirectional attention.
INFO 06-25 21:12:49 [compilation.py:310] Enabled custom fusions: norm_quant, act_quant
(EngineCore pid=876902) INFO 06-25 21:13:32 [core.py:114] Initializing a V1 LLM engine (v0.23.1rc1.dev409+gd7ab9be77) with config: model='/mnt/zfs/home/junda.su/models/RedHatAI/gemma-4-31B-it-FP8-block', speculative_config=None, tokenizer='/mnt/zfs/home/junda.su/models/RedHatAI/gemma-4-31B-it-FP8-block', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=10240, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, decode_context_parallel_size=1, dcp_comm_backend=ag_rs, disable_custom_all_reduce=False, quantization=compressed-tensors, quantization_config=None, enforce_eager=False, enable_return_routed_experts=False, kv_cache_dtype=fp8, device_config=cuda, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, kv_cache_metrics=False, kv_cache_metrics_sample=0.01, cudagraph_metrics=False, enable_layerwise_nvtx_tracing=False, enable_mfu_metrics=False, enable_mm_processor_stats=False, enable_logging_iteration_details=False, jit_monitor_mode='warn', jit_monitor_verbose=False), seed=0, served_model_name=/mnt/zfs/home/junda.su/models/RedHatAI/gemma-4-31B-it-FP8-block, enable_prefix_caching=True, enable_chunked_prefill=True, pooler_config=None, compilation_config={'mode': <CompilationMode.VLLM_COMPILE: 3>, 'debug_dump_path': None, 'cache_dir': '', 'compile_cache_save_format': 'binary', 'backend': 'inductor', 'custom_ops': ['+quant_fp8', 'none', '+quant_fp8'], 'ir_enable_torch_wrap': True, 'splitting_ops': ['vllm::unified_attention_with_output', 'vllm::unified_mla_attention_with_output', 'vllm::mamba_mixer2', 'vllm::mamba_mixer', 'vllm::short_conv', 'vllm::linear_attention', 'vllm::plamo2_mamba_mixer', 'vllm::qwen_gdn_attention_core', 'vllm::gdn_attention_core_xpu', 'vllm::olmo_hybrid_gdn_full_forward', 'vllm::kda_attention', 'vllm::sparse_attn_indexer', 'vllm::rocm_aiter_sparse_attn_indexer', 'vllm::deepseek_v4_attention', 'vllm::unified_kv_cache_update', 'vllm::unified_mla_kv_cache_update'], 'compile_mm_encoder': False, 'cudagraph_mm_encoder': False, 'encoder_cudagraph_token_budgets': [], 'encoder_cudagraph_max_vision_items_per_batch': 0, 'encoder_cudagraph_max_frames_per_batch': None, 'compile_sizes': [], 'compile_ranges_endpoints': [311808], 'inductor_compile_config': {'enable_auto_functionalized_v2': False, 'size_asserts': False, 'alignment_asserts': False, 'scalar_asserts': False, 'combo_kernels': True, 'benchmark_combo_kernel': True}, 'inductor_passes': {}, 'cudagraph_mode': <CUDAGraphMode.FULL_AND_PIECEWISE: (2, 1)>, 'cudagraph_num_of_warmups': 1, 'cudagraph_capture_sizes': [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88], 'cudagraph_copy_inputs': False, 'cudagraph_specialize_lora': True, 'use_inductor_graph_partition': False, 'pass_config': {'fuse_norm_quant': True, 'fuse_act_quant': True, 'fuse_attn_quant': False, 'enable_sp': False, 'fuse_gemm_comms': False, 'fuse_allreduce_rms': False, 'fuse_rope_kvcache_cat_mla': False, 'fuse_act_padding': False}, 'max_cudagraph_capture_size': 88, 'dynamic_shapes_config': {'type': <DynamicShapesType.BACKED: 'backed'>, 'evaluate_guards': False, 'assume_32_bit_indexing': False}, 'local_cache_dir': None, 'fast_moe_cold_start': False, 'static_all_moe_layers': []}, kernel_config=KernelConfig(ir_op_priority=IrOpPriorityConfig(rms_norm=['native'], fused_add_rms_norm=['native']), enable_flashinfer_autotune=True, moe_backend='auto', linear_backend='auto')
(EngineCore pid=876902) INFO 06-25 21:13:40 [parallel_state.py:1588] world_size=1 rank=0 local_rank=0 distributed_init_method=tcp://10.114.120.123:39633 backend=nccl
(EngineCore pid=876902) INFO 06-25 21:13:40 [parallel_state.py:1923] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, PCP rank 0, TP rank 0, EP rank N/A, EPLB rank N/A
(EngineCore pid=876902) INFO 06-25 21:13:50 [topk_topp_sampler.py:55] Using FlashInfer for top-p & top-k sampling.
(EngineCore pid=876902) INFO 06-25 21:13:53 [gpu_model_runner.py:5160] Starting to load model /mnt/zfs/home/junda.su/models/RedHatAI/gemma-4-31B-it-FP8-block...
(EngineCore pid=876902) INFO 06-25 21:14:02 [vllm.py:1006] Asynchronous scheduling is enabled.
(EngineCore pid=876902) INFO 06-25 21:14:02 [kernel.py:276] Final IR op priority after setting platform defaults: IrOpPriorityConfig(rms_norm=['native'], fused_add_rms_norm=['native'])
(EngineCore pid=876902) INFO 06-25 21:14:02 [compilation.py:310] Enabled custom fusions: norm_quant, act_quant
(EngineCore pid=876902) INFO 06-25 21:14:02 [__init__.py:563] Selected DeepGemmFp8BlockScaledMMKernel for CompressedTensorsW8A8Fp8
(EngineCore pid=876902) INFO 06-25 21:14:02 [deep_gemm.py:175] deep_gemm not found in site-packages, trying vendored vllm.third_party.deep_gemm
(EngineCore pid=876902) INFO 06-25 21:14:02 [deep_gemm.py:202] DeepGEMM PDL enabled on vllm.third_party.deep_gemm.
(EngineCore pid=876902) INFO 06-25 21:14:02 [deep_gemm.py:120] DeepGEMM E8M0 enabled on current platform.
>>> (EngineCore pid=876902) WARNING 06-25 21:14:03 [fa_utils.py:169] FA4 on Blackwell does not support head_size=256 due to TMEM capacity limits, defaulting to FA version 2.
>>> (EngineCore pid=876902) INFO 06-25 21:14:03 [cuda.py:483] Using TRITON_ATTN attention backend out of potential backends: ['TRITON_ATTN'].
>>> (EngineCore pid=876902) WARNING 06-25 21:14:03 [fa_utils.py:169] FA4 on Blackwell does not support head_size=512 due to TMEM capacity limits, defaulting to FA version 2.
@fikrikarim

Copy link
Copy Markdown

@Davids048 I ran into the same thing.

This PR body mentions Blackwell (SM100+), but AFAIK Flash Attention itself doesn't support sm_120 like RTX 5090 and RTX PRO 6000. See Dao-AILab/flash-attention#1987 and Dao-AILab/flash-attention#2634.

We might need to wait until that PR is merged and it's pulled by FA vLLM fork.

ohsono pushed a commit to ohsono/vllm that referenced this pull request Jul 3, 2026
…prefix support (vllm-project#42175)

Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build documentation Improvements or additions to documentation kv-connector nvidia ready ONLY add when PR is ready to merge/full CI is needed speculative-decoding v1

9 participants