[train][torchft] Ray Train manages replica group restarts#61475
Conversation
Signed-off-by: Timothy Seah <tseah@anyscale.com>
Signed-off-by: Timothy Seah <tseah@anyscale.com>
Signed-off-by: Timothy Seah <tseah@anyscale.com>
There was a problem hiding this comment.
Code Review
This pull request introduces support for replica group restarts, a key feature for fault tolerance with torchft. The changes are well-structured, introducing ExecutionGroup and ExecutionGroupCallback as base classes to share logic between WorkerGroup and the new ReplicaGroup concepts. The controller logic is updated to handle partial restarts of failing replica groups, and the WorkerGroup is enhanced with a replace_replica_group method. The refactoring is clean and the new functionality is supported by a comprehensive set of tests. I have one suggestion regarding the callback handling to make it more robust for custom callbacks.
Signed-off-by: Timothy Seah <tseah@anyscale.com>
Signed-off-by: Timothy Seah <tseah@anyscale.com>
There was a problem hiding this comment.
I think we should hold off on torchft-native elastic training in this PR. Consider the core and data dependencies below:
| Ray Data | Ray Core | |
|---|---|---|
| Fixed training | Easy: move data shard from worker x to its replacement | None |
| Elastic training | Hard: must reconfigure data shards | Easy option: 1 PlacementGroup per replica group (size 1 for DDP). This is an antipattern. Hard option: Resizable PlacementGroup |
Because it’s easy to get Ray Data + Ray Train + torchft to work for fixed training - and elastic training still works because we can just fall back to worker group restarts - I would suggest addressing this in a future PR.
Instead of having 1 WorkerGroup with many ReplicaGroups, should we just have multiple WorkerGroups? I prefer 1 WorkerGroup with many ReplicaGroups because:
- Single controller single layer: We want to simultaneously poll all the workers. Similarly, we want to have a single SyncActor across all the workers to enable a report barrier. Having multiple WorkerGroups could complicate this.
- WorkerGroup and ReplicaGroup have different callbacks. For example, WorkerGroup.on_start should set up one process group per replica group.
- 1 WorkerGroup also translates more naturally to the Ray Train dashboard - showing 1 WorkerGroup per replica group in the UI would add unnecessary detail and complexity.
- It makes sense to have ReplicaGroup as a first class concept because it is data parallel group and it is intuitive to shard 1 dataset across the data parallel groups of 1 workergroup.
|
Question: how will state transitions look? |
…plica groups but everything else is not Signed-off-by: Timothy Seah <tseah@anyscale.com>
|
Sanity check audit of places that use
|
justinvyu
left a comment
There was a problem hiding this comment.
have a few more things to look at and the tests, but here's a few comments for now
Signed-off-by: Timothy Seah <tseah@anyscale.com>
Signed-off-by: Timothy Seah <tseah@anyscale.com>
Signed-off-by: Timothy Seah <tseah@anyscale.com>
Signed-off-by: Timothy Seah <tseah@anyscale.com>
Signed-off-by: Timothy Seah <tseah@anyscale.com>
Signed-off-by: Timothy Seah <tseah@anyscale.com>
Signed-off-by: Timothy Seah <tseah@anyscale.com>
Signed-off-by: Timothy Seah <tseah@anyscale.com>
justinvyu
left a comment
There was a problem hiding this comment.
A few high level thoughts on the design. These are non-blocking but we can discuss offline
Signed-off-by: Timothy Seah <tseah@anyscale.com>
Signed-off-by: Timothy Seah <tseah@anyscale.com>
Signed-off-by: Timothy Seah <tseah@anyscale.com>
Signed-off-by: Timothy Seah <tseah@anyscale.com>
Signed-off-by: Timothy Seah <tseah@anyscale.com>
Signed-off-by: Timothy Seah <tseah@anyscale.com>
| # After sorting: node0/gpu0, node0/gpu1, node1/gpu0, node1/gpu1 | ||
| # Each worker is its own replica group, so local_rank=0, | ||
| # local_world_size=1, node_rank=0 for all. | ||
| [ | ||
| DistributedContext( |
There was a problem hiding this comment.
thanks for the test, makes the node assignment business clearer with an example
Signed-off-by: Timothy Seah <tseah@anyscale.com>
There was a problem hiding this comment.
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Reviewed by Cursor Bugbot for commit 126a549. Configure here.
…t#61475) # Summary This PR follows up on ray-project#61156 by handling torchft worker group failure recovery. Here are some of the design decisions: * `ReplicaGroup.shutdown` is similar to `WorkerGroup.shutdown` (they both shut down workers and clear state) but doesn't do some other stuff (e.g. callbacks and placementgroup cleanup). * `WorkerGroup.replace_replica_group` is similar to `WorkerGroup._start_impl` so I refactored their shared functionality accordingly. The main difference is that the former runs fewer callbacks. This PR also changes the semantics of get_world_rank and get_local_rank/get_node_rank: * get_world_rank/get_world_size still apply to all workers (across replica groups). This is because we often need one global rank 0 e.g. DDP rank 0 worker uploads checkpoint. Since we are doing DDP only, this is equivalent to get_replica_group_id. * get_local_rank/get_local_size/get_node_rank now apply to (replica group, node) pairs. One common use case for get_local_rank is to only download data to a node once - but since every replica group gets a different shard of data, every "local rank 0" needs to download data anyway. Also, creating a local_rank across all replica groups on the same node might not be feasible; if there is a node failure that results in a replica group getting scheduled on an existing node with another replica group, we would need to re-sort all the local ranks. Fortunately, every torchft replica group is effectively its own torchrun group (communication between torchrun groups is handled by the `Manager`), so this is consistent with their model. For this reason we don't need to worry about `CUDA_VISIBLE_DEVICES` either. * The order of workers in `WorkerGroupState.workers` is equivalent to `get_world_rank`. This invariant still holds. I also went through every single `WorkerGroupCallback` method and determined whether or not they are relevant for `ReplicaGroups`. In many cases, the behavior is the same for `WorkerGroups` and `ReplicaGroups`, so I also defined corresponding `ExecutionGroupCallback` methods that get called in both cases by default. `before_init_train_context`: Always the same between `WorkerGroup` and `ReplicaGroup` so I moved this behavior up to `ExecutionGroupCallback`. * `AcceleratorSetupCallback`: well handled and tested in this PR * `CheckpointManagerCallback` and `ValidationManagerCallback`: same behavior. Will test this more thoroughly in the followup `ray.train.report` PR. * `DatasetSetupCallback`: will test this more in the followup `data integration` PR. `before_worker_group_shutdown`: sometimes same (in which case the user can override `before_execution_group_shutdown`) sometimes different (in which case the user can override `before_worker_group_shutdown` or `before_replica_group_shutdown`). * `BackendSetupCallback`: implemented in `before_execution_group_shutdown`. Well handled and tested in this PR * `StateManagerCallback`: only implemented with `before_worker_group_shutdown` because replica groups are not reflected in train run state * `ReportCallbackHandler`: will be fixed in followup `ray.train.report` PR. The main idea is that `before_replica_group_shutdown` can also clear report states. `after_worker_group_start`: same as `before_worker_group_shutdown` * `BackendSetupCallback` and `WorkingDirectorySetupCallback`: implemented in `before_execution_group_shutdown`. Well handled and tested in this PR * `DatasetSetupCallback` and `ReportCallbackHandler` will be handled in the aforementioned future PR's * `StateManager` and `PlacementGroupCleanerCallback` only apply to worker groups. `after_worker_group_shutdown` and `after_worker_group_abort` are only used by `DatasetsCallback`. They should only apply to worker groups; when a replica group goes down, rather than shutting down any state, we simply send the state to the replacement worker. Of course, the aforementioned future data integration PR will test this better. All other `WorkerGroupCallback` methods are irrelevant: * `on_worker_group-start/on_worker_group_shutdown` are just for timing the worker group. * `before_worker_group_start` and `before_worker_group_abort` are only used by `StateManagerCallback` which is irrelevant as explained earlier. * `after_worker_group_training_start` is never used. * `after_worker_group_poll_status` is irrelevant because it operates on all the workers in the worker group, while replica groups are just thin wrappers around the workers. # Testing I'm open to more unit test suggestions. I basically tried to unit test different layers of the stack as follows: * `test_torch_trainer`: e2e test. I also verified that it works as expected. It's still disabled until I add torchft dependencies to the train CI. * `test_controller`: tests that we correctly decide when to do a replica group restart or a full worker group restart. Note that this also tests elastic training. * `test_worker_group`: tests that when we `replace_replica_group` we correctly update the relevant state (`WorkerGroupState`, replica groups, polling state). I added `mark.parametrize` to other unit tests to verify other behavior works as expected with both worker group and replica group restarts e.g. callbacks and worker initialization. Successfully ran a driver script in a workspace with simulated node failures (killed the raylet) and confirmed that it worked. Driver script: https://gist.github.com/TimothySeah/35aa96b81b2d98d77c23b10c0baa71c9 Logs Failure detected: https://gist.github.com/TimothySeah/923ce69987de98b774d8de214e6845ae Training stops due to unmet quorum ``` �[36m(LighthouseServerActor pid=24407)�[0m 2026-03-11T13:02:09.398 [INFO] [torchft::lighthouse] - Quorum status: New quorum not ready, only have 2 participants, need min_replicas 3 [2/2 participants healthy][2 heartbeating][shrink_only=false] �[33m(raylet)�[0m The node with node id: 0b015b238c8d9b950c92fe08a188001514a61a624cebdfe0e2aab715 and address: 10.0.82.255 and node name: 10.0.82.255 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a (1) raylet crashes unexpectedly (OOM, etc.) (2) raylet has lagging heartbeats due to slow network or busy workload ``` replace_replica_group fails because node is unschedulable: https://gist.github.com/TimothySeah/12030fe7814a72e1b21fd75c25710635 Autoscaling completes and training continues: https://gist.github.com/TimothySeah/98b0776199ff18a97ce64246ab7894c9 Eventually we reach the finished state. --------- Signed-off-by: Timothy Seah <tseah@anyscale.com>
…t#61475) # Summary This PR follows up on ray-project#61156 by handling torchft worker group failure recovery. Here are some of the design decisions: * `ReplicaGroup.shutdown` is similar to `WorkerGroup.shutdown` (they both shut down workers and clear state) but doesn't do some other stuff (e.g. callbacks and placementgroup cleanup). * `WorkerGroup.replace_replica_group` is similar to `WorkerGroup._start_impl` so I refactored their shared functionality accordingly. The main difference is that the former runs fewer callbacks. This PR also changes the semantics of get_world_rank and get_local_rank/get_node_rank: * get_world_rank/get_world_size still apply to all workers (across replica groups). This is because we often need one global rank 0 e.g. DDP rank 0 worker uploads checkpoint. Since we are doing DDP only, this is equivalent to get_replica_group_id. * get_local_rank/get_local_size/get_node_rank now apply to (replica group, node) pairs. One common use case for get_local_rank is to only download data to a node once - but since every replica group gets a different shard of data, every "local rank 0" needs to download data anyway. Also, creating a local_rank across all replica groups on the same node might not be feasible; if there is a node failure that results in a replica group getting scheduled on an existing node with another replica group, we would need to re-sort all the local ranks. Fortunately, every torchft replica group is effectively its own torchrun group (communication between torchrun groups is handled by the `Manager`), so this is consistent with their model. For this reason we don't need to worry about `CUDA_VISIBLE_DEVICES` either. * The order of workers in `WorkerGroupState.workers` is equivalent to `get_world_rank`. This invariant still holds. I also went through every single `WorkerGroupCallback` method and determined whether or not they are relevant for `ReplicaGroups`. In many cases, the behavior is the same for `WorkerGroups` and `ReplicaGroups`, so I also defined corresponding `ExecutionGroupCallback` methods that get called in both cases by default. `before_init_train_context`: Always the same between `WorkerGroup` and `ReplicaGroup` so I moved this behavior up to `ExecutionGroupCallback`. * `AcceleratorSetupCallback`: well handled and tested in this PR * `CheckpointManagerCallback` and `ValidationManagerCallback`: same behavior. Will test this more thoroughly in the followup `ray.train.report` PR. * `DatasetSetupCallback`: will test this more in the followup `data integration` PR. `before_worker_group_shutdown`: sometimes same (in which case the user can override `before_execution_group_shutdown`) sometimes different (in which case the user can override `before_worker_group_shutdown` or `before_replica_group_shutdown`). * `BackendSetupCallback`: implemented in `before_execution_group_shutdown`. Well handled and tested in this PR * `StateManagerCallback`: only implemented with `before_worker_group_shutdown` because replica groups are not reflected in train run state * `ReportCallbackHandler`: will be fixed in followup `ray.train.report` PR. The main idea is that `before_replica_group_shutdown` can also clear report states. `after_worker_group_start`: same as `before_worker_group_shutdown` * `BackendSetupCallback` and `WorkingDirectorySetupCallback`: implemented in `before_execution_group_shutdown`. Well handled and tested in this PR * `DatasetSetupCallback` and `ReportCallbackHandler` will be handled in the aforementioned future PR's * `StateManager` and `PlacementGroupCleanerCallback` only apply to worker groups. `after_worker_group_shutdown` and `after_worker_group_abort` are only used by `DatasetsCallback`. They should only apply to worker groups; when a replica group goes down, rather than shutting down any state, we simply send the state to the replacement worker. Of course, the aforementioned future data integration PR will test this better. All other `WorkerGroupCallback` methods are irrelevant: * `on_worker_group-start/on_worker_group_shutdown` are just for timing the worker group. * `before_worker_group_start` and `before_worker_group_abort` are only used by `StateManagerCallback` which is irrelevant as explained earlier. * `after_worker_group_training_start` is never used. * `after_worker_group_poll_status` is irrelevant because it operates on all the workers in the worker group, while replica groups are just thin wrappers around the workers. # Testing I'm open to more unit test suggestions. I basically tried to unit test different layers of the stack as follows: * `test_torch_trainer`: e2e test. I also verified that it works as expected. It's still disabled until I add torchft dependencies to the train CI. * `test_controller`: tests that we correctly decide when to do a replica group restart or a full worker group restart. Note that this also tests elastic training. * `test_worker_group`: tests that when we `replace_replica_group` we correctly update the relevant state (`WorkerGroupState`, replica groups, polling state). I added `mark.parametrize` to other unit tests to verify other behavior works as expected with both worker group and replica group restarts e.g. callbacks and worker initialization. Successfully ran a driver script in a workspace with simulated node failures (killed the raylet) and confirmed that it worked. Driver script: https://gist.github.com/TimothySeah/35aa96b81b2d98d77c23b10c0baa71c9 Logs Failure detected: https://gist.github.com/TimothySeah/923ce69987de98b774d8de214e6845ae Training stops due to unmet quorum ``` �[36m(LighthouseServerActor pid=24407)�[0m 2026-03-11T13:02:09.398 [INFO] [torchft::lighthouse] - Quorum status: New quorum not ready, only have 2 participants, need min_replicas 3 [2/2 participants healthy][2 heartbeating][shrink_only=false] �[33m(raylet)�[0m The node with node id: 0b015b238c8d9b950c92fe08a188001514a61a624cebdfe0e2aab715 and address: 10.0.82.255 and node name: 10.0.82.255 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a (1) raylet crashes unexpectedly (OOM, etc.) (2) raylet has lagging heartbeats due to slow network or busy workload ``` replace_replica_group fails because node is unschedulable: https://gist.github.com/TimothySeah/12030fe7814a72e1b21fd75c25710635 Autoscaling completes and training continues: https://gist.github.com/TimothySeah/98b0776199ff18a97ce64246ab7894c9 Eventually we reach the finished state. --------- Signed-off-by: Timothy Seah <tseah@anyscale.com>

Summary
This PR follows up on #61156 by handling torchft worker group failure recovery.
Here are some of the design decisions:
ReplicaGroup.shutdownis similar toWorkerGroup.shutdown(they both shut down workers and clear state) but doesn't do some other stuff (e.g. callbacks and placementgroup cleanup).WorkerGroup.replace_replica_groupis similar toWorkerGroup._start_implso I refactored their shared functionality accordingly. The main difference is that the former runs fewer callbacks.This PR also changes the semantics of get_world_rank and get_local_rank/get_node_rank:
Manager), so this is consistent with their model. For this reason we don't need to worry aboutCUDA_VISIBLE_DEVICESeither.WorkerGroupState.workersis equivalent toget_world_rank. This invariant still holds.I also went through every single
WorkerGroupCallbackmethod and determined whether or not they are relevant forReplicaGroups. In many cases, the behavior is the same forWorkerGroupsandReplicaGroups, so I also defined correspondingExecutionGroupCallbackmethods that get called in both cases by default.before_init_train_context: Always the same betweenWorkerGroupandReplicaGroupso I moved this behavior up toExecutionGroupCallback.AcceleratorSetupCallback: well handled and tested in this PRCheckpointManagerCallbackandValidationManagerCallback: same behavior. Will test this more thoroughly in the followupray.train.reportPR.DatasetSetupCallback: will test this more in the followupdata integrationPR.before_worker_group_shutdown: sometimes same (in which case the user can overridebefore_execution_group_shutdown) sometimes different (in which case the user can overridebefore_worker_group_shutdownorbefore_replica_group_shutdown).BackendSetupCallback: implemented inbefore_execution_group_shutdown. Well handled and tested in this PRStateManagerCallback: only implemented withbefore_worker_group_shutdownbecause replica groups are not reflected in train run stateReportCallbackHandler: will be fixed in followupray.train.reportPR. The main idea is thatbefore_replica_group_shutdowncan also clear report states.after_worker_group_start: same asbefore_worker_group_shutdownBackendSetupCallbackandWorkingDirectorySetupCallback: implemented inbefore_execution_group_shutdown. Well handled and tested in this PRDatasetSetupCallbackandReportCallbackHandlerwill be handled in the aforementioned future PR'sStateManagerandPlacementGroupCleanerCallbackonly apply to worker groups.after_worker_group_shutdownandafter_worker_group_abortare only used byDatasetsCallback. They should only apply to worker groups; when a replica group goes down, rather than shutting down any state, we simply send the state to the replacement worker. Of course, the aforementioned future data integration PR will test this better.All other
WorkerGroupCallbackmethods are irrelevant:on_worker_group-start/on_worker_group_shutdownare just for timing the worker group.before_worker_group_startandbefore_worker_group_abortare only used byStateManagerCallbackwhich is irrelevant as explained earlier.after_worker_group_training_startis never used.after_worker_group_poll_statusis irrelevant because it operates on all the workers in the worker group, while replica groups are just thin wrappers around the workers.Testing
I'm open to more unit test suggestions. I basically tried to unit test different layers of the stack as follows:
test_torch_trainer: e2e test. I also verified that it works as expected. It's still disabled until I add torchft dependencies to the train CI.test_controller: tests that we correctly decide when to do a replica group restart or a full worker group restart. Note that this also tests elastic training.test_worker_group: tests that when wereplace_replica_groupwe correctly update the relevant state (WorkerGroupState, replica groups, polling state). I addedmark.parametrizeto other unit tests to verify other behavior works as expected with both worker group and replica group restarts e.g. callbacks and worker initialization.Successfully ran a driver script in a workspace with simulated node failures (killed the raylet) and confirmed that it worked.
Driver script: https://gist.github.com/TimothySeah/35aa96b81b2d98d77c23b10c0baa71c9
Logs
Failure detected: https://gist.github.com/TimothySeah/923ce69987de98b774d8de214e6845ae
Training stops due to unmet quorum
replace_replica_group fails because node is unschedulable: https://gist.github.com/TimothySeah/12030fe7814a72e1b21fd75c25710635
Autoscaling completes and training continues: https://gist.github.com/TimothySeah/98b0776199ff18a97ce64246ab7894c9
Eventually we reach the finished state.