[Serve] Optimize stop_replicas() to avoid pop-all/re-add cycle#60832
Conversation
Signed-off-by: abrar <abrar@anyscale.com>
There was a problem hiding this comment.
Code Review
This pull request optimizes stop_replicas to avoid a costly cycle of popping and re-adding all replicas when only a few need to be stopped. This is achieved by introducing a remove(replica_id) method in ReplicaStateContainer and using it in a loop within stop_replicas. The benchmarks clearly show significant performance improvements for stopping a small number of replicas.
My main concern, detailed in a specific comment, is a potential performance regression when stopping a large number of replicas, due to the O(k*N) complexity of the new approach. I've suggested an alternative that would be efficient for all cases.
| replica_ids = set(replica_ids) | ||
| removed = [] | ||
| remaining_to_find = len(replica_ids) | ||
| for state in ALL_REPLICA_STATES: | ||
| if remaining_to_find == 0: | ||
| break | ||
| found_any = False | ||
| remaining = [] | ||
| for replica in self._replicas[state]: | ||
| if remaining_to_find > 0 and replica.replica_id in replica_ids: | ||
| removed.append(replica) | ||
| remaining_to_find -= 1 | ||
| found_any = True | ||
| else: | ||
| remaining.append(replica) | ||
| if found_any: | ||
| self._replicas[state] = remaining | ||
| return removed |
There was a problem hiding this comment.
this would still require us to iterate over the states first, and then each replica, and then create state-wise individual list again.
would it be beneficial if we modify this _replicas to be Dict [ReplicaState, Dict[ReplicaId, DeploymentReplica]]? that we can just find the replica id in O(1) and then remove it. it might add some overhead as well, can we check how much it will be?
There was a problem hiding this comment.
Maybe there is optimization to be had there, but I prefer taking baby steps towards that. Will take that as a follow-up. Changing the data structure of _replicas also requires rewriting all other functions in this class.
…roject#60832) `stop_replicas()` pops **every** replica across all 7 states, checks set membership, and re-adds the vast majority back. Each re-add triggers `update_actor_details()` which rebuilds a `ReplicaDetails` pydantic object. When stopping 2 out of 4096 replicas, 4094 replicas get needlessly popped, rebuilt, and re-added. ### Fix Add a `remove(replica_ids)` method to `ReplicaStateContainer` that performs a single O(N) pass with O(1) set lookups. Non-matching replicas stay in place — no re-add, no update_state call. Early-exits once all targets are found, and only rebuilds the list for states where a match was found. ### Benchmark results <details> <summary> Benchmark script - AI </summary> ```python """Micro-benchmark: stop_replicas() pop-all vs selective-remove. Measures latency and peak memory when stopping a small fraction of replicas from a ReplicaStateContainer, comparing the old approach (pop all + re-add) against the new approach (selective remove by ID). Usage: python bench_stop_replicas.py """ import gc import math import random import statistics import time import tracemalloc from collections import defaultdict from dataclasses import dataclass, field from enum import Enum from typing import Dict, List, Optional, Set # --------------------------------------------------------------------------- # Lightweight stubs – just enough to exercise ReplicaStateContainer logic # without importing all of Ray Serve. # --------------------------------------------------------------------------- @DataClass(frozen=True) class DeploymentID: name: str app_name: str = "default" def __hash__(self): return hash((self.name, self.app_name)) @DataClass(frozen=True) class ReplicaID: unique_id: str deployment_id: DeploymentID def __hash__(self): return hash((self.unique_id, self.deployment_id)) class ReplicaState(str, Enum): STARTING = "STARTING" UPDATING = "UPDATING" RECOVERING = "RECOVERING" RUNNING = "RUNNING" STOPPING = "STOPPING" PENDING_MIGRATION = "PENDING_MIGRATION" ALL_REPLICA_STATES = list(ReplicaState) class _StubActorDetails: """Minimal stand-in for ReplicaDetails (pydantic model in production). Deliberately uses __slots__ to keep memory footprint realistic but light. """ __slots__ = ("state", "replica_id", "node_id") def __init__(self, state: ReplicaState, replica_id: str): self.state = state self.replica_id = replica_id self.node_id = "node-0" def dict(self): return { "state": self.state, "replica_id": self.replica_id, "node_id": self.node_id, } class FakeReplica: """Minimal stand-in for DeploymentReplica. Implements only the surface area touched by the container and stop_replicas. """ def __init__(self, replica_id: ReplicaID, state: ReplicaState): self._replica_id = replica_id self._actor_details = _StubActorDetails(state, replica_id.unique_id) self._update_state_calls = 0 @Property def replica_id(self) -> ReplicaID: return self._replica_id @Property def actor_details(self): return self._actor_details def update_state(self, state: ReplicaState) -> None: """Mirrors update_actor_details; rebuilds the details object.""" self._actor_details = _StubActorDetails(state, self._replica_id.unique_id) self._update_state_calls += 1 # --------------------------------------------------------------------------- # Container implementations # --------------------------------------------------------------------------- class _ContainerBase: """Shared helpers.""" def __init__(self): self._replicas: Dict[ReplicaState, List[FakeReplica]] = defaultdict(list) def add(self, state: ReplicaState, replica: FakeReplica): replica.update_state(state) self._replicas[state].append(replica) def pop( self, exclude_version=None, states=None, max_replicas=math.inf, ) -> List[FakeReplica]: if states is None: states = ALL_REPLICA_STATES replicas = [] for state in states: popped = [] remaining = [] for replica in self._replicas[state]: if len(replicas) + len(popped) == max_replicas: remaining.append(replica) else: popped.append(replica) self._replicas[state] = remaining replicas.extend(popped) return replicas class OldContainer(_ContainerBase): """Original stop_replicas: pop everything, re-add non-matching.""" def stop_replicas(self, replicas_to_stop: Set[ReplicaID]): stopped = [] for replica in self.pop(): if replica.replica_id in replicas_to_stop: # In production this calls _stop_replica(); we just record it. stopped.append(replica) else: self.add(replica.actor_details.state, replica) return stopped class NewContainer(_ContainerBase): """New stop_replicas: single-pass remove_many by ID set.""" def remove(self, replica_ids) -> List[FakeReplica]: replica_ids = set(replica_ids) removed = [] remaining_to_find = len(replica_ids) for state in ALL_REPLICA_STATES: if remaining_to_find == 0: break found_any = False remaining = [] for replica in self._replicas[state]: if remaining_to_find > 0 and replica.replica_id in replica_ids: removed.append(replica) remaining_to_find -= 1 found_any = True else: remaining.append(replica) if found_any: self._replicas[state] = remaining return removed def stop_replicas(self, replicas_to_stop: Set[ReplicaID]): return self.remove(replicas_to_stop) # --------------------------------------------------------------------------- # Benchmark harness # --------------------------------------------------------------------------- DEPLOYMENT_ID = DeploymentID("bench-deploy", "bench-app") WARMUP_ROUNDS = 3 MEASURE_ROUNDS = 20 def _make_replicas(n: int) -> List[FakeReplica]: return [ FakeReplica( ReplicaID(f"r-{i}", DEPLOYMENT_ID), ReplicaState.RUNNING, ) for i in range(n) ] def _fill_container(container, replicas: List[FakeReplica]): for r in replicas: container.add(ReplicaState.RUNNING, r) def _pick_targets(replicas: List[FakeReplica], k: int) -> Set[ReplicaID]: chosen = random.sample(replicas, k) return {r.replica_id for r in chosen} def bench_latency(container_cls, n_replicas: int, num_to_stop: int) -> dict: """Returns dict with median, p99, mean latency in microseconds.""" replicas_master = _make_replicas(n_replicas) targets = _pick_targets(replicas_master, num_to_stop) timings = [] for rnd in range(WARMUP_ROUNDS + MEASURE_ROUNDS): c = container_cls() _fill_container(c, replicas_master) gc.disable() t0 = time.perf_counter_ns() c.stop_replicas(targets) t1 = time.perf_counter_ns() gc.enable() if rnd >= WARMUP_ROUNDS: timings.append((t1 - t0) / 1_000) # ns -> µs timings.sort() return { "median_us": statistics.median(timings), "p99_us": timings[int(len(timings) * 0.99)], "mean_us": statistics.mean(timings), } def bench_memory(container_cls, n_replicas: int, num_to_stop: int) -> dict: """Returns dict with peak memory allocation delta in KiB. Averages over multiple runs to reduce noise. """ NUM_MEM_RUNS = 5 deltas = [] for _ in range(NUM_MEM_RUNS): replicas_master = _make_replicas(n_replicas) targets = _pick_targets(replicas_master, num_to_stop) c = container_cls() _fill_container(c, replicas_master) gc.collect() tracemalloc.start() snap_before = tracemalloc.take_snapshot() c.stop_replicas(targets) snap_after = tracemalloc.take_snapshot() tracemalloc.stop() stats = snap_after.compare_to(snap_before, "lineno") delta_bytes = sum(s.size_diff for s in stats if s.size_diff > 0) deltas.append(delta_bytes / 1024) return {"peak_delta_kib": statistics.median(deltas)} def bench_update_state_calls(n_replicas: int, num_to_stop: int) -> dict: """Returns update_state call counts for old vs new.""" replicas_old = _make_replicas(n_replicas) replicas_new = _make_replicas(n_replicas) targets = _pick_targets(replicas_old, num_to_stop) c_old = OldContainer() _fill_container(c_old, replicas_old) for r in replicas_old: r._update_state_calls = 0 c_new = NewContainer() _fill_container(c_new, replicas_new) for r in replicas_new: r._update_state_calls = 0 c_old.stop_replicas(targets) c_new.stop_replicas(targets) old_calls = sum(r._update_state_calls for r in replicas_old) new_calls = sum(r._update_state_calls for r in replicas_new) return {"old": old_calls, "new": new_calls} def run_scenario(label: str, replica_counts: List[int], stop_fn): """Run a full latency + memory + update_state scenario. Args: label: human-readable description (e.g. "stopping 2 replicas"). replica_counts: list of total replica counts to sweep. stop_fn: callable(n) -> number of replicas to stop for that n. """ hdr = ( f"{'Replicas':>8} {'k':>5}" f" │ {'Old µs':>9} {'New µs':>9} {'Speedup':>7}" f" │ {'Old KiB':>8} {'New KiB':>8} {'Saved':>8}" f" │ {'Old upd':>8} {'New upd':>8}" ) sep = "─" * len(hdr) print() print(f" {label}") print(sep) print(hdr) print(sep) for n in replica_counts: k = stop_fn(n) old_lat = bench_latency(OldContainer, n, k) new_lat = bench_latency(NewContainer, n, k) speedup = old_lat["median_us"] / new_lat["median_us"] if new_lat["median_us"] > 0 else float("inf") old_mem = bench_memory(OldContainer, n, k) new_mem = bench_memory(NewContainer, n, k) saved = old_mem["peak_delta_kib"] - new_mem["peak_delta_kib"] us = bench_update_state_calls(n, k) print( f"{n:>8} {k:>5}" f" │ {old_lat['median_us']:>8.1f}µ {new_lat['median_us']:>8.1f}µ {speedup:>6.1f}x" f" │ {old_mem['peak_delta_kib']:>7.1f} {new_mem['peak_delta_kib']:>7.1f} {saved:>7.1f}" f" │ {us['old']:>8,} {us['new']:>8,}" ) print(sep) def main(): replica_counts = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096] print() print("=" * 72) print(" stop_replicas() micro-benchmark") print(" Old = pop all + re-add non-matching (original)") print(" New = single-pass remove by ID set (optimized)") print("=" * 72) # Scenario 1: stop 2 replicas (typical downscale) run_scenario( "SCENARIO 1: Downscale — stop 2 out of N (typical)", replica_counts, stop_fn=lambda n: 2, ) # Scenario 2: stop all replicas (full teardown) run_scenario( "SCENARIO 2: Teardown — stop ALL N replicas (worst case for old per-ID approach)", replica_counts, stop_fn=lambda n: n, ) # Scenario 3: stop 10% of replicas run_scenario( "SCENARIO 3: Moderate downscale — stop 10% of N", [64, 128, 256, 512, 1024, 2048, 4096], stop_fn=lambda n: max(1, n // 10), ) print() if __name__ == "__main__": main() ``` </details> **Scenario 1 — Downscale: stop 2 out of N (typical)** | Replicas | k | Old µs | New µs | Speedup | Old KiB | New KiB | Saved KiB | Old `upd_state` | New `upd_state` | |----------|---|--------|--------|---------|---------|---------|-----------|-----------------|-----------------| | 64 | 2 | 53.6 | 8.4 | 6.4x | 5.1 | 1.2 | 3.8 | 62 | 0 | | 256 | 2 | 211.3 | 62.4 | 3.4x | 17.2 | 2.8 | 14.3 | 254 | 0 | | 1024 | 2 | 852.2 | 295.2 | 2.9x | 65.7 | 9.3 | 56.3 | 1,022 | 0 | | 4096 | 2 | 3402.1 | 586.7 | 5.8x | 257.3 | 32.9 | 224.3 | 4,094 | 0 | **Scenario 2 — Teardown: stop ALL N (no regression)** | Replicas | k | Old µs | New µs | Speedup | Old KiB | New KiB | Saved KiB | |----------|---|--------|--------|---------|---------|---------|-----------| | 256 | 256 | 98.8 | 84.6 | 1.2x | 1.2 | 0.7 | 0.5 | | 1024 | 1024 | 402.0 | 340.6 | 1.2x | 1.2 | 0.7 | 0.5 | | 4096 | 4096 | 1635.7 | 1378.4 | 1.2x | 1.2 | 0.7 | 0.5 | Memory is flat and nearly identical for both — when stopping everything, neither approach re-adds anything. **Scenario 3 — Moderate downscale: stop 10% of N** | Replicas | k | Old µs | New µs | Speedup | Old KiB | New KiB | Saved KiB | Old `upd_state` | New `upd_state` | |----------|---|--------|--------|---------|---------|---------|-----------|-----------------|-----------------| | 256 | 25 | 205.3 | 75.9 | 2.7x | 15.6 | 2.5 | 13.1 | 231 | 0 | | 1024 | 102 | 828.2 | 319.5 | 2.6x | 59.2 | 8.3 | 50.9 | 922 | 0 | | 4096 | 409 | 3301.4 | 1299.3 | 2.5x | 235.0 | 32.9 | 202.1 | 3,687 | 0 | The old code's memory grows linearly with N because `pop()` allocates a temporary list of all replicas, then `add()` rebuilds `ReplicaDetails` for each re-inserted one. The new code only allocates a `remaining` list for states where a match is found, and never touches non-matching replicas. related to ray-project#60680 --------- Signed-off-by: abrar <abrar@anyscale.com> Signed-off-by: Adel Nour <ans9868@nyu.edu>
…roject#60832) `stop_replicas()` pops **every** replica across all 7 states, checks set membership, and re-adds the vast majority back. Each re-add triggers `update_actor_details()` which rebuilds a `ReplicaDetails` pydantic object. When stopping 2 out of 4096 replicas, 4094 replicas get needlessly popped, rebuilt, and re-added. ### Fix Add a `remove(replica_ids)` method to `ReplicaStateContainer` that performs a single O(N) pass with O(1) set lookups. Non-matching replicas stay in place — no re-add, no update_state call. Early-exits once all targets are found, and only rebuilds the list for states where a match was found. ### Benchmark results <details> <summary> Benchmark script - AI </summary> ```python """Micro-benchmark: stop_replicas() pop-all vs selective-remove. Measures latency and peak memory when stopping a small fraction of replicas from a ReplicaStateContainer, comparing the old approach (pop all + re-add) against the new approach (selective remove by ID). Usage: python bench_stop_replicas.py """ import gc import math import random import statistics import time import tracemalloc from collections import defaultdict from dataclasses import dataclass, field from enum import Enum from typing import Dict, List, Optional, Set # --------------------------------------------------------------------------- # Lightweight stubs – just enough to exercise ReplicaStateContainer logic # without importing all of Ray Serve. # --------------------------------------------------------------------------- @DataClass(frozen=True) class DeploymentID: name: str app_name: str = "default" def __hash__(self): return hash((self.name, self.app_name)) @DataClass(frozen=True) class ReplicaID: unique_id: str deployment_id: DeploymentID def __hash__(self): return hash((self.unique_id, self.deployment_id)) class ReplicaState(str, Enum): STARTING = "STARTING" UPDATING = "UPDATING" RECOVERING = "RECOVERING" RUNNING = "RUNNING" STOPPING = "STOPPING" PENDING_MIGRATION = "PENDING_MIGRATION" ALL_REPLICA_STATES = list(ReplicaState) class _StubActorDetails: """Minimal stand-in for ReplicaDetails (pydantic model in production). Deliberately uses __slots__ to keep memory footprint realistic but light. """ __slots__ = ("state", "replica_id", "node_id") def __init__(self, state: ReplicaState, replica_id: str): self.state = state self.replica_id = replica_id self.node_id = "node-0" def dict(self): return { "state": self.state, "replica_id": self.replica_id, "node_id": self.node_id, } class FakeReplica: """Minimal stand-in for DeploymentReplica. Implements only the surface area touched by the container and stop_replicas. """ def __init__(self, replica_id: ReplicaID, state: ReplicaState): self._replica_id = replica_id self._actor_details = _StubActorDetails(state, replica_id.unique_id) self._update_state_calls = 0 @Property def replica_id(self) -> ReplicaID: return self._replica_id @Property def actor_details(self): return self._actor_details def update_state(self, state: ReplicaState) -> None: """Mirrors update_actor_details; rebuilds the details object.""" self._actor_details = _StubActorDetails(state, self._replica_id.unique_id) self._update_state_calls += 1 # --------------------------------------------------------------------------- # Container implementations # --------------------------------------------------------------------------- class _ContainerBase: """Shared helpers.""" def __init__(self): self._replicas: Dict[ReplicaState, List[FakeReplica]] = defaultdict(list) def add(self, state: ReplicaState, replica: FakeReplica): replica.update_state(state) self._replicas[state].append(replica) def pop( self, exclude_version=None, states=None, max_replicas=math.inf, ) -> List[FakeReplica]: if states is None: states = ALL_REPLICA_STATES replicas = [] for state in states: popped = [] remaining = [] for replica in self._replicas[state]: if len(replicas) + len(popped) == max_replicas: remaining.append(replica) else: popped.append(replica) self._replicas[state] = remaining replicas.extend(popped) return replicas class OldContainer(_ContainerBase): """Original stop_replicas: pop everything, re-add non-matching.""" def stop_replicas(self, replicas_to_stop: Set[ReplicaID]): stopped = [] for replica in self.pop(): if replica.replica_id in replicas_to_stop: # In production this calls _stop_replica(); we just record it. stopped.append(replica) else: self.add(replica.actor_details.state, replica) return stopped class NewContainer(_ContainerBase): """New stop_replicas: single-pass remove_many by ID set.""" def remove(self, replica_ids) -> List[FakeReplica]: replica_ids = set(replica_ids) removed = [] remaining_to_find = len(replica_ids) for state in ALL_REPLICA_STATES: if remaining_to_find == 0: break found_any = False remaining = [] for replica in self._replicas[state]: if remaining_to_find > 0 and replica.replica_id in replica_ids: removed.append(replica) remaining_to_find -= 1 found_any = True else: remaining.append(replica) if found_any: self._replicas[state] = remaining return removed def stop_replicas(self, replicas_to_stop: Set[ReplicaID]): return self.remove(replicas_to_stop) # --------------------------------------------------------------------------- # Benchmark harness # --------------------------------------------------------------------------- DEPLOYMENT_ID = DeploymentID("bench-deploy", "bench-app") WARMUP_ROUNDS = 3 MEASURE_ROUNDS = 20 def _make_replicas(n: int) -> List[FakeReplica]: return [ FakeReplica( ReplicaID(f"r-{i}", DEPLOYMENT_ID), ReplicaState.RUNNING, ) for i in range(n) ] def _fill_container(container, replicas: List[FakeReplica]): for r in replicas: container.add(ReplicaState.RUNNING, r) def _pick_targets(replicas: List[FakeReplica], k: int) -> Set[ReplicaID]: chosen = random.sample(replicas, k) return {r.replica_id for r in chosen} def bench_latency(container_cls, n_replicas: int, num_to_stop: int) -> dict: """Returns dict with median, p99, mean latency in microseconds.""" replicas_master = _make_replicas(n_replicas) targets = _pick_targets(replicas_master, num_to_stop) timings = [] for rnd in range(WARMUP_ROUNDS + MEASURE_ROUNDS): c = container_cls() _fill_container(c, replicas_master) gc.disable() t0 = time.perf_counter_ns() c.stop_replicas(targets) t1 = time.perf_counter_ns() gc.enable() if rnd >= WARMUP_ROUNDS: timings.append((t1 - t0) / 1_000) # ns -> µs timings.sort() return { "median_us": statistics.median(timings), "p99_us": timings[int(len(timings) * 0.99)], "mean_us": statistics.mean(timings), } def bench_memory(container_cls, n_replicas: int, num_to_stop: int) -> dict: """Returns dict with peak memory allocation delta in KiB. Averages over multiple runs to reduce noise. """ NUM_MEM_RUNS = 5 deltas = [] for _ in range(NUM_MEM_RUNS): replicas_master = _make_replicas(n_replicas) targets = _pick_targets(replicas_master, num_to_stop) c = container_cls() _fill_container(c, replicas_master) gc.collect() tracemalloc.start() snap_before = tracemalloc.take_snapshot() c.stop_replicas(targets) snap_after = tracemalloc.take_snapshot() tracemalloc.stop() stats = snap_after.compare_to(snap_before, "lineno") delta_bytes = sum(s.size_diff for s in stats if s.size_diff > 0) deltas.append(delta_bytes / 1024) return {"peak_delta_kib": statistics.median(deltas)} def bench_update_state_calls(n_replicas: int, num_to_stop: int) -> dict: """Returns update_state call counts for old vs new.""" replicas_old = _make_replicas(n_replicas) replicas_new = _make_replicas(n_replicas) targets = _pick_targets(replicas_old, num_to_stop) c_old = OldContainer() _fill_container(c_old, replicas_old) for r in replicas_old: r._update_state_calls = 0 c_new = NewContainer() _fill_container(c_new, replicas_new) for r in replicas_new: r._update_state_calls = 0 c_old.stop_replicas(targets) c_new.stop_replicas(targets) old_calls = sum(r._update_state_calls for r in replicas_old) new_calls = sum(r._update_state_calls for r in replicas_new) return {"old": old_calls, "new": new_calls} def run_scenario(label: str, replica_counts: List[int], stop_fn): """Run a full latency + memory + update_state scenario. Args: label: human-readable description (e.g. "stopping 2 replicas"). replica_counts: list of total replica counts to sweep. stop_fn: callable(n) -> number of replicas to stop for that n. """ hdr = ( f"{'Replicas':>8} {'k':>5}" f" │ {'Old µs':>9} {'New µs':>9} {'Speedup':>7}" f" │ {'Old KiB':>8} {'New KiB':>8} {'Saved':>8}" f" │ {'Old upd':>8} {'New upd':>8}" ) sep = "─" * len(hdr) print() print(f" {label}") print(sep) print(hdr) print(sep) for n in replica_counts: k = stop_fn(n) old_lat = bench_latency(OldContainer, n, k) new_lat = bench_latency(NewContainer, n, k) speedup = old_lat["median_us"] / new_lat["median_us"] if new_lat["median_us"] > 0 else float("inf") old_mem = bench_memory(OldContainer, n, k) new_mem = bench_memory(NewContainer, n, k) saved = old_mem["peak_delta_kib"] - new_mem["peak_delta_kib"] us = bench_update_state_calls(n, k) print( f"{n:>8} {k:>5}" f" │ {old_lat['median_us']:>8.1f}µ {new_lat['median_us']:>8.1f}µ {speedup:>6.1f}x" f" │ {old_mem['peak_delta_kib']:>7.1f} {new_mem['peak_delta_kib']:>7.1f} {saved:>7.1f}" f" │ {us['old']:>8,} {us['new']:>8,}" ) print(sep) def main(): replica_counts = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096] print() print("=" * 72) print(" stop_replicas() micro-benchmark") print(" Old = pop all + re-add non-matching (original)") print(" New = single-pass remove by ID set (optimized)") print("=" * 72) # Scenario 1: stop 2 replicas (typical downscale) run_scenario( "SCENARIO 1: Downscale — stop 2 out of N (typical)", replica_counts, stop_fn=lambda n: 2, ) # Scenario 2: stop all replicas (full teardown) run_scenario( "SCENARIO 2: Teardown — stop ALL N replicas (worst case for old per-ID approach)", replica_counts, stop_fn=lambda n: n, ) # Scenario 3: stop 10% of replicas run_scenario( "SCENARIO 3: Moderate downscale — stop 10% of N", [64, 128, 256, 512, 1024, 2048, 4096], stop_fn=lambda n: max(1, n // 10), ) print() if __name__ == "__main__": main() ``` </details> **Scenario 1 — Downscale: stop 2 out of N (typical)** | Replicas | k | Old µs | New µs | Speedup | Old KiB | New KiB | Saved KiB | Old `upd_state` | New `upd_state` | |----------|---|--------|--------|---------|---------|---------|-----------|-----------------|-----------------| | 64 | 2 | 53.6 | 8.4 | 6.4x | 5.1 | 1.2 | 3.8 | 62 | 0 | | 256 | 2 | 211.3 | 62.4 | 3.4x | 17.2 | 2.8 | 14.3 | 254 | 0 | | 1024 | 2 | 852.2 | 295.2 | 2.9x | 65.7 | 9.3 | 56.3 | 1,022 | 0 | | 4096 | 2 | 3402.1 | 586.7 | 5.8x | 257.3 | 32.9 | 224.3 | 4,094 | 0 | **Scenario 2 — Teardown: stop ALL N (no regression)** | Replicas | k | Old µs | New µs | Speedup | Old KiB | New KiB | Saved KiB | |----------|---|--------|--------|---------|---------|---------|-----------| | 256 | 256 | 98.8 | 84.6 | 1.2x | 1.2 | 0.7 | 0.5 | | 1024 | 1024 | 402.0 | 340.6 | 1.2x | 1.2 | 0.7 | 0.5 | | 4096 | 4096 | 1635.7 | 1378.4 | 1.2x | 1.2 | 0.7 | 0.5 | Memory is flat and nearly identical for both — when stopping everything, neither approach re-adds anything. **Scenario 3 — Moderate downscale: stop 10% of N** | Replicas | k | Old µs | New µs | Speedup | Old KiB | New KiB | Saved KiB | Old `upd_state` | New `upd_state` | |----------|---|--------|--------|---------|---------|---------|-----------|-----------------|-----------------| | 256 | 25 | 205.3 | 75.9 | 2.7x | 15.6 | 2.5 | 13.1 | 231 | 0 | | 1024 | 102 | 828.2 | 319.5 | 2.6x | 59.2 | 8.3 | 50.9 | 922 | 0 | | 4096 | 409 | 3301.4 | 1299.3 | 2.5x | 235.0 | 32.9 | 202.1 | 3,687 | 0 | The old code's memory grows linearly with N because `pop()` allocates a temporary list of all replicas, then `add()` rebuilds `ReplicaDetails` for each re-inserted one. The new code only allocates a `remaining` list for states where a match is found, and never touches non-matching replicas. related to ray-project#60680 --------- Signed-off-by: abrar <abrar@anyscale.com>
stop_replicas()pops every replica across all 7 states, checks set membership, and re-adds the vast majority back. Each re-add triggersupdate_actor_details()which rebuilds aReplicaDetailspydantic object. When stopping 2 out of 4096 replicas, 4094 replicas get needlessly popped, rebuilt, and re-added.Fix
Add a
remove(replica_ids)method toReplicaStateContainerthat performs a single O(N) pass with O(1) set lookups. Non-matching replicas stay in place — no re-add, no update_state call. Early-exits once all targets are found, and only rebuilds the list for states where a match was found.Benchmark results
Benchmark script - AI
Scenario 1 — Downscale: stop 2 out of N (typical)
upd_stateupd_stateScenario 2 — Teardown: stop ALL N (no regression)
Memory is flat and nearly identical for both — when stopping everything, neither approach re-adds anything.
Scenario 3 — Moderate downscale: stop 10% of N
upd_stateupd_stateThe old code's memory grows linearly with N because
pop()allocates a temporary list of all replicas, thenadd()rebuildsReplicaDetailsfor each re-inserted one. The new code only allocates aremaininglist for states where a match is found, and never touches non-matching replicas.related to #60680