Skip to content

[train] Improve JaxTrainer TPU multi-slice fault tolerance and reservation ergonomics#62893

Merged
matthewdeng merged 10 commits into
ray-project:masterfrom
liulehui:jaxtrainer
May 5, 2026
Merged

[train] Improve JaxTrainer TPU multi-slice fault tolerance and reservation ergonomics#62893
matthewdeng merged 10 commits into
ray-project:masterfrom
liulehui:jaxtrainer

Conversation

@liulehui

@liulehui liulehui commented Apr 23, 2026

Copy link
Copy Markdown
Contributor

Description

  1. Fix JAX hangs after preemption recovery due to stale MEGASCALE_* env vars.

When one slice in an N-slice topology is preempted, the underlying TPU provider may inject stale MEGASCALE_NUM_SLICES / MEGASCALE_SLICE_ID / MEGASCALE_COORDINATOR_ADDRESS env vars on the replacement pods (e.g. reporting MEGASCALE_NUM_SLICES=3 for a 2-slice job because terminating pods were still counted). jax.distributed.initialize() then hangs forever waiting for the third slice that never appears.

Ray Train already has the authoritative view of the live worker group, so this PR makes _JaxBackend._setup_jax_distributed_environment always override the four MEGASCALE_* keys with values computed from the current worker group, regardless of what the provider stamped onto the pod. This decouples Ray Train from the provider's view and turns a previously fatal preemption recovery into a clean restart.

  1. TPU SlicePlacementGroup reservation failures are not retried
    The CPU/GPU and TPU paths in WorkerGroup._create_placement_group handle "placement group can't be satisfied right now" inconsistently:
Path Timeout outcome
CPU/GPU pg_handle.wait(timeout) raises WorkerGroupStartupTimeoutError → retryable; controller goes SCHEDULING -> RESCHEDULING
TPU (before this PR) SlicePlacementGroup(...) blocks synchronously in its constructor; on timeout the catch-all wraps it in a plain ValueErrornot retryable → run errors out

So if the autoscaler is still bringing up a TPU slice when the 100s head reservation deadline elapses, the run fails immediately instead of retrying. Thus, we translates TimeoutError from the TPU head reservation into the standard WorkerGroupStartupTimeoutError so it flows through the existing retry machinery, translates other unexpected exceptions into WorkerGroupStartupFailedError matching the precedent set by the worker-actor startup path
(RayActorError -> WorkerGroupStartupFailedError).

Test

Tested on Anyscale platform, see example logs: https://gist.github.com/liulehui/0bfb32d1db4d317e1694290fe1290850

@liulehui liulehui changed the title [train][jax] Jaxtrainer Apr 23, 2026
@liulehui liulehui changed the title [train][jax] Jaxtrainer multislice ft Apr 23, 2026
@liulehui liulehui added the go add ONLY when ready to merge, run all tests label Apr 23, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a configurable timeout for TPU slice reservations and ensures that JAX multi-slice environment variables are authoritatively overridden to prevent initialization hangs. It also includes a new test case to verify the environment variable overrides. Feedback was provided to move an inline import to the top level for better PEP 8 compliance.

Comment thread python/ray/_private/accelerators/tpu.py Outdated
Signed-off-by: Lehui Liu <lehui@anyscale.com>
Signed-off-by: Lehui Liu <lehui@anyscale.com>
Signed-off-by: Lehui Liu <lehui@anyscale.com>
@liulehui liulehui changed the title [train][jax] lehui/Jaxtrainer multislice ft Apr 27, 2026
Signed-off-by: Lehui Liu <lehui@anyscale.com>
@liulehui liulehui marked this pull request as ready for review April 27, 2026 22:06
@liulehui liulehui requested review from a team as code owners April 27, 2026 22:06
@liulehui liulehui requested a review from ryanaoleary April 27, 2026 22:07
Signed-off-by: Lehui Liu <lehui@anyscale.com>
@ray-gardener ray-gardener Bot added the train Ray Train Related Issue label Apr 28, 2026
Comment thread python/ray/train/v2/tests/test_jax_trainer.py
Comment thread python/ray/train/v2/jax/config.py Outdated
env_vars = {}
if num_slices > 1:
slice_id = min(i // workers_per_slice, num_slices - 1)
env_vars = get_tpu_coordinator_env_vars(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We current don't pass the master port through so it's always defaulting to 8081 now. Previously it would remain whatever value the user/TPU webhook set. Should this now be:

env_vars = get_tpu_coordinator_env_vars(
    coordinator_address=master_addr,
    num_slices=num_slices,
    slice_id=slice_id,
    coordinator_port=str(master_port), # use the dynamic value from controller
)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove from list, I think we will need a different port value given that this port is used by multislice DCN while the master port is for jax distributed coordinator.

@ryanaoleary

Copy link
Copy Markdown
Contributor

Left a couple comments related to how we handle MEGASCALE_PORT. I think using the coordinator calculated value should be fine since I believe it will guarantee a unique port. Everything else in the PR LGTM so will approve once the port comments are addressed.

Signed-off-by: Lehui Liu <lehui@anyscale.com>
Comment thread python/ray/train/v2/jax/config.py

@ryanaoleary ryanaoleary left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!!

Signed-off-by: Lehui Liu <lehui@anyscale.com>

@cursor cursor Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes and found 1 potential issue.

Fix All in Cursor

Reviewed by Cursor Bugbot for commit 6387910. Configure here.

liulehui added 3 commits May 4, 2026 12:54
Signed-off-by: Lehui Liu <lehui@anyscale.com>
Signed-off-by: Lehui Liu <lehui@anyscale.com>
@matthewdeng matthewdeng merged commit 2094b3e into ray-project:master May 5, 2026
6 checks passed
Lucas61000 pushed a commit to Lucas61000/ray that referenced this pull request May 15, 2026
…ation ergonomics (ray-project#62893)

## Description
1. Fix JAX hangs after preemption recovery due to stale `MEGASCALE_*`
env vars.

When one slice in an N-slice topology is preempted, the underlying TPU
provider may inject stale `MEGASCALE_NUM_SLICES` / `MEGASCALE_SLICE_ID`
/ `MEGASCALE_COORDINATOR_ADDRESS` env vars on the replacement pods (e.g.
reporting `MEGASCALE_NUM_SLICES=3` for a 2-slice job because terminating
pods were still counted). `jax.distributed.initialize()` then hangs
forever waiting for the third slice that never appears.

Ray Train already has the authoritative view of the live worker group,
so this PR makes `_JaxBackend._setup_jax_distributed_environment`
**always** override the four `MEGASCALE_*` keys with values computed
from the current worker group, regardless of what the provider stamped
onto the pod. This decouples Ray Train from the provider's view and
turns a previously fatal preemption recovery into a clean restart.

2. TPU SlicePlacementGroup reservation failures are not retried
The CPU/GPU and TPU paths in `WorkerGroup._create_placement_group`
handle "placement group can't be satisfied right now" inconsistently:

| Path | Timeout outcome |
|---|---|
| CPU/GPU | `pg_handle.wait(timeout)` raises
`WorkerGroupStartupTimeoutError` → retryable; controller goes
`SCHEDULING -> RESCHEDULING` |
| TPU (before this PR) | `SlicePlacementGroup(...)` blocks synchronously
in its constructor; on timeout the catch-all wraps it in a plain
`ValueError` → **not retryable** → run errors out |

So if the autoscaler is still bringing up a TPU slice when the 100s head
reservation deadline elapses, the run fails immediately instead of
retrying. Thus, we translates `TimeoutError` from the TPU head
reservation into the standard `WorkerGroupStartupTimeoutError` so it
flows through the existing retry machinery, translates other unexpected
exceptions into `WorkerGroupStartupFailedError` matching the precedent
set by the worker-actor startup path
  (`RayActorError -> WorkerGroupStartupFailedError`). 

## Test

Tested on Anyscale platform, see example logs:
https://gist.github.com/liulehui/0bfb32d1db4d317e1694290fe1290850

---------

Signed-off-by: Lehui Liu <lehui@anyscale.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

go add ONLY when ready to merge, run all tests train Ray Train Related Issue

4 participants