feat: experimental two-phase (head-chunked) Ulysses all-to-all #428

Open
csgoogle wants to merge 1 commit into main from sagarchapara/ulysses-two-phase

csgoogle (Collaborator) commented Jun 24, 2026

Add an opt-in `ULYSSES_ATTENTION_CHUNKS` env var that splits the Ulysses all-to-all into per-head-group passes, so XLA's async-collective scheduler can overlap one group's attention compute with the next group's all-to-all. It defaults to 1 (the current single-shot path, no behavior change). The result is numerically identical to the single-shot path because attention heads are computed independently of each other.

Notes:

  • The overlap only happens when the async-collective LIBTPU flags are enabled.
  • The gain is largest when the all-to-all is a meaningful fraction of attention time (high context parallelism or shorter sequences). At WAN 2.2 720p (sequence length ~75600) attention is compute-bound, so the win is small (~3% in a microbenchmark); at sequence length ~24k we observe ~10% gains.
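
The equivalence claim above (heads are independent, so chunking along the head axis does not change the result) can be checked with a small single-device sketch. This is an illustration in NumPy, not the PR's code: it omits the all-to-all and sharding entirely, and the function names `attention` and `chunked_attention` are made up for this example.

```python
import numpy as np


def attention(q, k, v):
    """Plain softmax attention; q, k, v have shape (heads, seq, head_dim)."""
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v


def chunked_attention(q, k, v, num_chunks):
    """Run attention one head group at a time, then concatenate along heads.

    Hypothetical stand-in for the chunked path: in the real kernel each head
    group would also get its own all-to-all, which is left out here.
    """
    outs = [
        attention(qc, kc, vc)
        for qc, kc, vc in zip(
            np.array_split(q, num_chunks, axis=0),
            np.array_split(k, num_chunks, axis=0),
            np.array_split(v, num_chunks, axis=0),
        )
    ]
    return np.concatenate(outs, axis=0)


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((8, 16, 4)) for _ in range(3))
single_shot = attention(q, k, v)
for n in (1, 2, 4):
    assert np.allclose(single_shot, chunked_attention(q, k, v, n))
```

Because no operation mixes values across the head axis, each head group's output depends only on that group's `q`, `k`, and `v`, which is why the chunk count affects scheduling but not the numbers.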
google-cla Bot commented Jun 24, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

csgoogle force-pushed the sagarchapara/ulysses-two-phase branch from 7240f50 to 0d936f8 on June 24, 2026 13:59
csgoogle requested a review from Perseus14 on June 24, 2026 14:05
# math is identical to the single-shot path (heads are independent); requires
# async-collective LIBTPU flags to actually overlap, and the per-chunk head
# count must still be shardable across the context axis.
num_chunks = int(os.environ.get("ULYSSES_ATTENTION_CHUNKS", "1"))

Collaborator

Let's move this to the config file so it can be used by any Ulysses-type kernel.

Collaborator Author

done
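
For readers following the thread, here is a minimal sketch of what a config-backed knob with validation might look like. The field name `ulysses_attention_chunks` comes from the later commit message; the `AttentionConfig` class and `config_from_env` helper are hypothetical and are not the project's actual config code.

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class AttentionConfig:  # hypothetical container, not the project's config class
    ulysses_attention_chunks: int = 1  # 1 keeps the single-shot path

    def __post_init__(self):
        # Reject values that cannot describe a valid number of passes.
        if self.ulysses_attention_chunks < 1:
            raise ValueError(
                "ulysses_attention_chunks must be >= 1, "
                f"got {self.ulysses_attention_chunks}."
            )


def config_from_env(env=os.environ):
    """Assumed fallback that mirrors the original env-var knob."""
    return AttentionConfig(int(env.get("ULYSSES_ATTENTION_CHUNKS", "1")))
```

Keeping the value in a config object rather than reading `os.environ` inside the kernel would let plain Ulysses and other Ulysses-style kernels share one validated setting.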

f"got heads={num_heads} and context_shards={num_shards}."
)

# EXPERIMENTAL: split the all-to-all into `num_chunks` head-groups so XLA's

Collaborator

Does this work with Ulysses + Ring as well?

Collaborator Author

Yes, updated the code.

Perseus14 requested a review from eltsai on June 24, 2026 20:14
csgoogle force-pushed the sagarchapara/ulysses-two-phase branch from 0d936f8 to 14eb661 on June 30, 2026 19:56
Add a ulysses_attention_chunks attention config to split the Ulysses all-to-all into head-group passes. The chunked path lets XLA overlap all-to-all collectives with head-parallel local attention compute while preserving the existing single-shot path by default.

Apply the same chunking to plain Ulysses and Ulysses+Ring, and allow the final chunk to carry the remainder when the requested chunk count does not divide the Ulysses head groups evenly.
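
One way the remainder handling described above could work is sketched below. It assumes a "head group" is a block of `num_shards` heads, so every chunk stays shardable across the context axis, and that a chunk count larger than the number of head groups is clamped. Both points are assumptions of this example, and `chunk_head_counts` is a hypothetical helper rather than a function from the PR.

```python
def chunk_head_counts(num_heads, num_shards, num_chunks):
    """Return the number of heads in each chunk, last chunk taking the remainder."""
    if num_heads % num_shards != 0:
        raise ValueError(
            "heads must be divisible by context shards, "
            f"got heads={num_heads} and context_shards={num_shards}."
        )
    head_groups = num_heads // num_shards
    # Assumption: never create more chunks than there are head groups.
    num_chunks = max(1, min(num_chunks, head_groups))
    base, remainder = divmod(head_groups, num_chunks)
    groups = [base] * num_chunks
    groups[-1] += remainder  # final chunk carries the leftover head groups
    return [g * num_shards for g in groups]


print(chunk_head_counts(40, 8, 2))  # [16, 24]
```

With 40 heads and 8 context shards there are 5 head groups, so two chunks cannot be equal; the split becomes 2 and 3 groups, that is 16 and 24 heads, and each chunk's head count is still a multiple of the shard count.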

Add mocked attention tests for numerical and layout equivalence across chunk counts.
csgoogle force-pushed the sagarchapara/ulysses-two-phase branch from 14eb661 to e09195e on June 30, 2026 20:02
csgoogle marked this pull request as ready for review on July 1, 2026 05:52
csgoogle requested a review from entrpn as a code owner on July 1, 2026 05:52