[train] Improve JaxTrainer TPU multi-slice fault tolerance and reservation ergonomics #62893
Code Review
This pull request introduces a configurable timeout for TPU slice reservations and ensures that JAX multi-slice environment variables are authoritatively overridden to prevent initialization hangs. It also includes a new test case to verify the environment variable overrides. Feedback was provided to move an inline import to the top level for better PEP 8 compliance.
Signed-off-by: Lehui Liu <lehui@anyscale.com>
env_vars = {}
if num_slices > 1:
    slice_id = min(i // workers_per_slice, num_slices - 1)
    env_vars = get_tpu_coordinator_env_vars(
We currently don't pass the master port through, so it now always defaults to 8081. Previously it kept whatever value the user or the TPU webhook set. Should this now be:
env_vars = get_tpu_coordinator_env_vars(
    coordinator_address=master_addr,
    num_slices=num_slices,
    slice_id=slice_id,
    coordinator_port=str(master_port),  # use the dynamic value from controller
)
Remove from list. I think we will need a different port value, given that this port is used by multislice DCN while the master port is for the JAX distributed coordinator.
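The override discussed in this thread can be sketched roughly as follows. This is a minimal illustration, not the code in the PR: `megascale_env_for_worker` and `apply_overrides` are hypothetical helpers, the even split of workers across slices is an assumption, and the port is deliberately left out because the thread does not settle which value it should carry. Only the three `MEGASCALE_*` names and the `slice_id` formula come from the PR itself.

```python
from typing import Dict


def megascale_env_for_worker(
    worker_index: int,
    num_workers: int,
    num_slices: int,
    coordinator_address: str,
) -> Dict[str, str]:
    """Hypothetical helper: derive multi-slice env vars from the live worker group."""
    if num_slices <= 1:
        # Single-slice jobs need no multi-slice env vars.
        return {}
    # Assumption: workers are spread evenly across slices.
    workers_per_slice = max(num_workers // num_slices, 1)
    # Same clamp as in the diff: never exceed the last valid slice id.
    slice_id = min(worker_index // workers_per_slice, num_slices - 1)
    return {
        "MEGASCALE_COORDINATOR_ADDRESS": coordinator_address,
        "MEGASCALE_NUM_SLICES": str(num_slices),
        "MEGASCALE_SLICE_ID": str(slice_id),
    }


def apply_overrides(pod_env: Dict[str, str], overrides: Dict[str, str]) -> Dict[str, str]:
    """Return a copy of pod_env where the computed values always win."""
    merged = dict(pod_env)
    merged.update(overrides)
    return merged


# A replacement pod that was stamped with stale values for a 2-slice job.
stale_pod_env = {
    "MEGASCALE_NUM_SLICES": "3",
    "MEGASCALE_SLICE_ID": "2",
    "PATH": "/usr/bin",
}
env = apply_overrides(
    stale_pod_env,
    megascale_env_for_worker(
        worker_index=3, num_workers=4, num_slices=2, coordinator_address="10.0.0.1"
    ),
)
```

Here the stale `MEGASCALE_NUM_SLICES=3` is replaced by the value computed from the current worker group, while unrelated variables such as `PATH` are left alone.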
Left a couple comments related to how we handle |
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Reviewed by Cursor Bugbot for commit 6387910.
…ation ergonomics (ray-project#62893)

## Description

1. Fix JAX hangs after preemption recovery due to stale `MEGASCALE_*` env vars.

   When one slice in an N-slice topology is preempted, the underlying TPU provider may inject stale `MEGASCALE_NUM_SLICES` / `MEGASCALE_SLICE_ID` / `MEGASCALE_COORDINATOR_ADDRESS` env vars on the replacement pods (e.g. reporting `MEGASCALE_NUM_SLICES=3` for a 2-slice job because terminating pods were still counted). `jax.distributed.initialize()` then hangs forever waiting for the third slice that never appears.

   Ray Train already has the authoritative view of the live worker group, so this PR makes `_JaxBackend._setup_jax_distributed_environment` **always** override the four `MEGASCALE_*` keys with values computed from the current worker group, regardless of what the provider stamped onto the pod. This decouples Ray Train from the provider's view and turns a previously fatal preemption recovery into a clean restart.

2. TPU SlicePlacementGroup reservation failures are not retried.

   The CPU/GPU and TPU paths in `WorkerGroup._create_placement_group` handle "placement group can't be satisfied right now" inconsistently:

   | Path | Timeout outcome |
   |---|---|
   | CPU/GPU | `pg_handle.wait(timeout)` raises `WorkerGroupStartupTimeoutError` → retryable; controller goes `SCHEDULING -> RESCHEDULING` |
   | TPU (before this PR) | `SlicePlacementGroup(...)` blocks synchronously in its constructor; on timeout the catch-all wraps it in a plain `ValueError` → **not retryable** → run errors out |

   So if the autoscaler is still bringing up a TPU slice when the 100s head reservation deadline elapses, the run fails immediately instead of retrying.

   This PR therefore translates `TimeoutError` from the TPU head reservation into the standard `WorkerGroupStartupTimeoutError`, so it flows through the existing retry machinery, and translates other unexpected exceptions into `WorkerGroupStartupFailedError`, matching the precedent set by the worker-actor startup path (`RayActorError -> WorkerGroupStartupFailedError`).

## Test

Tested on Anyscale platform, see example logs: https://gist.github.com/liulehui/0bfb32d1db4d317e1694290fe1290850

---------

Signed-off-by: Lehui Liu <lehui@anyscale.com>
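The exception translation described for the TPU path might look roughly like the sketch below. It is illustrative only: the two exception classes are local stand-ins that merely share names with the Ray Train errors mentioned above, and `reserve_with_translation`, `schedule_with_retries`, `reserve_fn`, and `max_attempts` are hypothetical names. Whether the real controller retries `WorkerGroupStartupFailedError` is not stated in the description, so this sketch simply lets it propagate.

```python
from typing import Callable, List


class WorkerGroupStartupTimeoutError(Exception):
    """Local stand-in for the retryable error named in the description."""


class WorkerGroupStartupFailedError(Exception):
    """Local stand-in for the startup-failure error named in the description."""


def reserve_with_translation(reserve_fn: Callable[[float], str], timeout_s: float) -> str:
    """Call a blocking reservation and map its failures onto the two error types."""
    try:
        return reserve_fn(timeout_s)
    except TimeoutError as exc:
        # A timeout means "not satisfiable yet", so surface it as retryable.
        raise WorkerGroupStartupTimeoutError(
            f"TPU head reservation timed out after {timeout_s}s"
        ) from exc
    except Exception as exc:
        # Anything else is reported as a startup failure rather than a bare ValueError.
        raise WorkerGroupStartupFailedError(str(exc)) from exc


def schedule_with_retries(
    reserve_fn: Callable[[float], str], timeout_s: float, max_attempts: int
) -> str:
    """Hypothetical retry loop: only the timeout error triggers another attempt."""
    for attempt in range(1, max_attempts + 1):
        try:
            return reserve_with_translation(reserve_fn, timeout_s)
        except WorkerGroupStartupTimeoutError:
            if attempt == max_attempts:
                raise
    raise ValueError("max_attempts must be at least 1")


# Simulated reservation: the slice becomes available on the third attempt.
attempts: List[float] = []


def flaky_reserve(timeout_s: float) -> str:
    attempts.append(timeout_s)
    if len(attempts) < 3:
        raise TimeoutError("slice not ready")
    return "placement-group-handle"


handle = schedule_with_retries(flaky_reserve, timeout_s=100.0, max_attempts=5)
```

With this shape, a slice that is still being provisioned when the deadline elapses leads to another scheduling attempt instead of ending the run.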
