[Serve] Optimize stop_replicas() to avoid pop-all/re-add cycle #60832

Merged: abrarsheikh merged 2 commits into master from optimize-stop-replicas on Feb 9, 2026

@abrarsheikh abrarsheikh commented Feb 7, 2026

Contributor

stop_replicas() pops every replica across all replica states, checks set membership, and re-adds the vast majority back. Each re-add triggers update_actor_details(), which rebuilds a ReplicaDetails pydantic object. When stopping 2 out of 4096 replicas, 4094 replicas get needlessly popped, rebuilt, and re-added.

Fix

Add a remove(replica_ids) method to ReplicaStateContainer that performs a single O(N) pass with O(1) set lookups. Non-matching replicas stay in place — no re-add, no update_state call. Early-exits once all targets are found, and only rebuilds the list for states where a match was found.
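A minimal sketch of the selective-removal idea described above. Plain strings stand in for replicas and their IDs, and `remove_by_ids` is an illustrative helper, not the actual `ReplicaStateContainer` API. It shows that lists for states without a match keep their identity and that states after the last match are never visited:

```python
from typing import Dict, Iterable, List


def remove_by_ids(by_state: Dict[str, List[str]], ids: Iterable[str]) -> List[str]:
    """Remove entries whose ID is in `ids`; leave all other entries in place.

    Hypothetical helper: strings act as both the replica and its ID.
    """
    wanted = set(ids)  # O(1) membership checks
    removed: List[str] = []
    left_to_find = len(wanted)
    for state in list(by_state):
        if left_to_find == 0:
            break  # early exit: every target has been found
        kept: List[str] = []
        found_any = False
        for replica in by_state[state]:
            if left_to_find > 0 and replica in wanted:
                removed.append(replica)
                left_to_find -= 1
                found_any = True
            else:
                kept.append(replica)
        if found_any:
            # Only swap in a new list for a state that actually lost a replica.
            by_state[state] = kept
    return removed


container = {"STARTING": ["a", "b"], "RUNNING": ["c", "d", "e"], "STOPPING": ["f"]}
starting_list = container["STARTING"]
stopping_list = container["STOPPING"]
removed = remove_by_ids(container, {"d"})
print(removed, container["RUNNING"])
```

Because nothing is re-added, no per-replica state update is needed for the entries that stay.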

Benchmark results

Benchmark script - AI
"""Micro-benchmark: stop_replicas() pop-all vs selective-remove.

Measures latency and peak memory when stopping a small fraction of replicas
from a ReplicaStateContainer, comparing the old approach (pop all + re-add)
against the new approach (selective remove by ID).

Usage:
    python bench_stop_replicas.py
"""

import gc
import math
import random
import statistics
import time
import tracemalloc
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Set


# ---------------------------------------------------------------------------
# Lightweight stubs – just enough to exercise ReplicaStateContainer logic
# without importing all of Ray Serve.
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class DeploymentID:
    name: str
    app_name: str = "default"

    def __hash__(self):
        return hash((self.name, self.app_name))


@dataclass(frozen=True)
class ReplicaID:
    unique_id: str
    deployment_id: DeploymentID

    def __hash__(self):
        return hash((self.unique_id, self.deployment_id))


class ReplicaState(str, Enum):
    STARTING = "STARTING"
    UPDATING = "UPDATING"
    RECOVERING = "RECOVERING"
    RUNNING = "RUNNING"
    STOPPING = "STOPPING"
    PENDING_MIGRATION = "PENDING_MIGRATION"


ALL_REPLICA_STATES = list(ReplicaState)


class _StubActorDetails:
    """Minimal stand-in for ReplicaDetails (pydantic model in production).

    Deliberately uses __slots__ to keep memory footprint realistic but light.
    """

    __slots__ = ("state", "replica_id", "node_id")

    def __init__(self, state: ReplicaState, replica_id: str):
        self.state = state
        self.replica_id = replica_id
        self.node_id = "node-0"

    def dict(self):
        return {
            "state": self.state,
            "replica_id": self.replica_id,
            "node_id": self.node_id,
        }


class FakeReplica:
    """Minimal stand-in for DeploymentReplica.

    Implements only the surface area touched by the container and stop_replicas.
    """

    def __init__(self, replica_id: ReplicaID, state: ReplicaState):
        self._replica_id = replica_id
        self._actor_details = _StubActorDetails(state, replica_id.unique_id)
        self._update_state_calls = 0

    @property
    def replica_id(self) -> ReplicaID:
        return self._replica_id

    @property
    def actor_details(self):
        return self._actor_details

    def update_state(self, state: ReplicaState) -> None:
        """Mirrors update_actor_details; rebuilds the details object."""
        self._actor_details = _StubActorDetails(state, self._replica_id.unique_id)
        self._update_state_calls += 1


# ---------------------------------------------------------------------------
# Container implementations
# ---------------------------------------------------------------------------


class _ContainerBase:
    """Shared helpers."""

    def __init__(self):
        self._replicas: Dict[ReplicaState, List[FakeReplica]] = defaultdict(list)

    def add(self, state: ReplicaState, replica: FakeReplica):
        replica.update_state(state)
        self._replicas[state].append(replica)

    def pop(
        self,
        exclude_version=None,
        states=None,
        max_replicas=math.inf,
    ) -> List[FakeReplica]:
        if states is None:
            states = ALL_REPLICA_STATES
        replicas = []
        for state in states:
            popped = []
            remaining = []
            for replica in self._replicas[state]:
                if len(replicas) + len(popped) == max_replicas:
                    remaining.append(replica)
                else:
                    popped.append(replica)
            self._replicas[state] = remaining
            replicas.extend(popped)
        return replicas


class OldContainer(_ContainerBase):
    """Original stop_replicas: pop everything, re-add non-matching."""

    def stop_replicas(self, replicas_to_stop: Set[ReplicaID]):
        stopped = []
        for replica in self.pop():
            if replica.replica_id in replicas_to_stop:
                # In production this calls _stop_replica(); we just record it.
                stopped.append(replica)
            else:
                self.add(replica.actor_details.state, replica)
        return stopped


class NewContainer(_ContainerBase):
    """New stop_replicas: single-pass remove() by ID set."""

    def remove(self, replica_ids) -> List[FakeReplica]:
        replica_ids = set(replica_ids)
        removed = []
        remaining_to_find = len(replica_ids)
        for state in ALL_REPLICA_STATES:
            if remaining_to_find == 0:
                break
            found_any = False
            remaining = []
            for replica in self._replicas[state]:
                if remaining_to_find > 0 and replica.replica_id in replica_ids:
                    removed.append(replica)
                    remaining_to_find -= 1
                    found_any = True
                else:
                    remaining.append(replica)
            if found_any:
                self._replicas[state] = remaining
        return removed

    def stop_replicas(self, replicas_to_stop: Set[ReplicaID]):
        return self.remove(replicas_to_stop)


# ---------------------------------------------------------------------------
# Benchmark harness
# ---------------------------------------------------------------------------

DEPLOYMENT_ID = DeploymentID("bench-deploy", "bench-app")
WARMUP_ROUNDS = 3
MEASURE_ROUNDS = 20


def _make_replicas(n: int) -> List[FakeReplica]:
    return [
        FakeReplica(
            ReplicaID(f"r-{i}", DEPLOYMENT_ID),
            ReplicaState.RUNNING,
        )
        for i in range(n)
    ]


def _fill_container(container, replicas: List[FakeReplica]):
    for r in replicas:
        container.add(ReplicaState.RUNNING, r)


def _pick_targets(replicas: List[FakeReplica], k: int) -> Set[ReplicaID]:
    chosen = random.sample(replicas, k)
    return {r.replica_id for r in chosen}


def bench_latency(container_cls, n_replicas: int, num_to_stop: int) -> dict:
    """Returns dict with median, p99, mean latency in microseconds."""
    replicas_master = _make_replicas(n_replicas)
    targets = _pick_targets(replicas_master, num_to_stop)

    timings = []

    for rnd in range(WARMUP_ROUNDS + MEASURE_ROUNDS):
        c = container_cls()
        _fill_container(c, replicas_master)

        gc.disable()
        t0 = time.perf_counter_ns()
        c.stop_replicas(targets)
        t1 = time.perf_counter_ns()
        gc.enable()

        if rnd >= WARMUP_ROUNDS:
            timings.append((t1 - t0) / 1_000)  # ns -> µs

    timings.sort()
    return {
        "median_us": statistics.median(timings),
        "p99_us": timings[int(len(timings) * 0.99)],
        "mean_us": statistics.mean(timings),
    }


def bench_memory(container_cls, n_replicas: int, num_to_stop: int) -> dict:
    """Returns dict with the net allocation growth across stop_replicas(), in KiB.

    Sums the positive size diffs between two tracemalloc snapshots and takes
    the median over multiple runs to reduce noise.
    """
    NUM_MEM_RUNS = 5
    deltas = []
    for _ in range(NUM_MEM_RUNS):
        replicas_master = _make_replicas(n_replicas)
        targets = _pick_targets(replicas_master, num_to_stop)

        c = container_cls()
        _fill_container(c, replicas_master)

        gc.collect()
        tracemalloc.start()

        snap_before = tracemalloc.take_snapshot()
        c.stop_replicas(targets)
        snap_after = tracemalloc.take_snapshot()

        tracemalloc.stop()

        stats = snap_after.compare_to(snap_before, "lineno")
        delta_bytes = sum(s.size_diff for s in stats if s.size_diff > 0)
        deltas.append(delta_bytes / 1024)

    return {"peak_delta_kib": statistics.median(deltas)}


def bench_update_state_calls(n_replicas: int, num_to_stop: int) -> dict:
    """Returns update_state call counts for old vs new."""
    replicas_old = _make_replicas(n_replicas)
    replicas_new = _make_replicas(n_replicas)
    targets = _pick_targets(replicas_old, num_to_stop)

    c_old = OldContainer()
    _fill_container(c_old, replicas_old)
    for r in replicas_old:
        r._update_state_calls = 0

    c_new = NewContainer()
    _fill_container(c_new, replicas_new)
    for r in replicas_new:
        r._update_state_calls = 0

    c_old.stop_replicas(targets)
    c_new.stop_replicas(targets)

    old_calls = sum(r._update_state_calls for r in replicas_old)
    new_calls = sum(r._update_state_calls for r in replicas_new)
    return {"old": old_calls, "new": new_calls}


def run_scenario(label: str, replica_counts: List[int], stop_fn):
    """Run a full latency + memory + update_state scenario.

    Args:
        label: human-readable description (e.g. "stopping 2 replicas").
        replica_counts: list of total replica counts to sweep.
        stop_fn: callable(n) -> number of replicas to stop for that n.
    """
    hdr = (
        f"{'Replicas':>8} {'k':>5}"
        f"  │ {'Old µs':>9} {'New µs':>9} {'Speedup':>7}"
        f"  │ {'Old KiB':>8} {'New KiB':>8} {'Saved':>8}"
        f"  │ {'Old upd':>8} {'New upd':>8}"
    )
    sep = "─" * len(hdr)

    print()
    print(f"  {label}")
    print(sep)
    print(hdr)
    print(sep)

    for n in replica_counts:
        k = stop_fn(n)
        old_lat = bench_latency(OldContainer, n, k)
        new_lat = bench_latency(NewContainer, n, k)
        speedup = old_lat["median_us"] / new_lat["median_us"] if new_lat["median_us"] > 0 else float("inf")
        old_mem = bench_memory(OldContainer, n, k)
        new_mem = bench_memory(NewContainer, n, k)
        saved = old_mem["peak_delta_kib"] - new_mem["peak_delta_kib"]
        us = bench_update_state_calls(n, k)

        print(
            f"{n:>8} {k:>5}"
            f"  │ {old_lat['median_us']:>8.1f}µ {new_lat['median_us']:>8.1f}µ {speedup:>6.1f}x"
            f"  │ {old_mem['peak_delta_kib']:>7.1f}  {new_mem['peak_delta_kib']:>7.1f} {saved:>7.1f}"
            f"  │ {us['old']:>8,} {us['new']:>8,}"
        )

    print(sep)


def main():
    replica_counts = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]

    print()
    print("=" * 72)
    print("  stop_replicas() micro-benchmark")
    print("  Old = pop all + re-add non-matching  (original)")
    print("  New = single-pass remove by ID set   (optimized)")
    print("=" * 72)

    # Scenario 1: stop 2 replicas (typical downscale)
    run_scenario(
        "SCENARIO 1: Downscale — stop 2 out of N  (typical)",
        replica_counts,
        stop_fn=lambda n: 2,
    )

    # Scenario 2: stop all replicas (full teardown)
    run_scenario(
        "SCENARIO 2: Teardown — stop ALL N replicas  (worst case for old per-ID approach)",
        replica_counts,
        stop_fn=lambda n: n,
    )

    # Scenario 3: stop 10% of replicas
    run_scenario(
        "SCENARIO 3: Moderate downscale — stop 10% of N",
        [64, 128, 256, 512, 1024, 2048, 4096],
        stop_fn=lambda n: max(1, n // 10),
    )

    print()


if __name__ == "__main__":
    main()

Scenario 1 — Downscale: stop 2 out of N (typical)

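| Replicas | k | Old µs | New µs | Speedup | Old KiB | New KiB | Saved KiB | Old upd_state | New upd_state |
|----------|---|--------|--------|---------|---------|---------|-----------|---------------|---------------|
| 64 | 2 | 53.6 | 8.4 | 6.4x | 5.1 | 1.2 | 3.8 | 62 | 0 |
| 256 | 2 | 211.3 | 62.4 | 3.4x | 17.2 | 2.8 | 14.3 | 254 | 0 |
| 1024 | 2 | 852.2 | 295.2 | 2.9x | 65.7 | 9.3 | 56.3 | 1,022 | 0 |
| 4096 | 2 | 3402.1 | 586.7 | 5.8x | 257.3 | 32.9 | 224.3 | 4,094 | 0 |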

Scenario 2 — Teardown: stop ALL N (no regression)

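| Replicas | k | Old µs | New µs | Speedup | Old KiB | New KiB | Saved KiB |
|----------|---|--------|--------|---------|---------|---------|-----------|
| 256 | 256 | 98.8 | 84.6 | 1.2x | 1.2 | 0.7 | 0.5 |
| 1024 | 1024 | 402.0 | 340.6 | 1.2x | 1.2 | 0.7 | 0.5 |
| 4096 | 4096 | 1635.7 | 1378.4 | 1.2x | 1.2 | 0.7 | 0.5 |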

Memory is flat and nearly identical for both — when stopping everything, neither approach re-adds anything.

Scenario 3 — Moderate downscale: stop 10% of N

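| Replicas | k | Old µs | New µs | Speedup | Old KiB | New KiB | Saved KiB | Old upd_state | New upd_state |
|----------|---|--------|--------|---------|---------|---------|-----------|---------------|---------------|
| 256 | 25 | 205.3 | 75.9 | 2.7x | 15.6 | 2.5 | 13.1 | 231 | 0 |
| 1024 | 102 | 828.2 | 319.5 | 2.6x | 59.2 | 8.3 | 50.9 | 922 | 0 |
| 4096 | 409 | 3301.4 | 1299.3 | 2.5x | 235.0 | 32.9 | 202.1 | 3,687 | 0 |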

The old code's memory grows linearly with N because pop() allocates a temporary list of all replicas, then add() rebuilds ReplicaDetails for each re-inserted one. The new code only allocates a remaining list for states where a match is found, and never touches non-matching replicas.
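The upd_state columns follow from the description: the old path re-adds each of the N - k non-matching replicas, while the new path re-adds none. A quick consistency check of the reported rows, with values copied from Scenarios 1 and 3 (the variable names are only for illustration):

```python
# (N, k, reported "Old upd_state") rows copied from Scenarios 1 and 3.
rows = [
    (64, 2, 62),
    (256, 2, 254),
    (1024, 2, 1022),
    (4096, 2, 4094),
    (256, 25, 231),
    (1024, 102, 922),
    (4096, 409, 3687),
]

# The old path re-adds every non-matching replica, so it should make N - k calls.
mismatches = [row for row in rows if row[0] - row[1] != row[2]]

# Scenario 3 picks k with max(1, n // 10) in the benchmark script.
scenario3_k = {n: max(1, n // 10) for n in (256, 1024, 4096)}

print(mismatches)   # an empty list means every reported count equals N - k
print(scenario3_k)
```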

related to #60680

Signed-off-by: abrar <abrar@anyscale.com>
@abrarsheikh abrarsheikh requested a review from a team as a code owner February 7, 2026 07:53

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor


Code Review

This pull request optimizes stop_replicas to avoid a costly cycle of popping and re-adding all replicas when only a few need to be stopped. This is achieved by introducing a remove(replica_id) method in ReplicaStateContainer and using it in a loop within stop_replicas. The benchmarks clearly show significant performance improvements for stopping a small number of replicas.

My main concern, detailed in a specific comment, is a potential performance regression when stopping a large number of replicas, due to the O(k*N) complexity of the new approach. I've suggested an alternative that would be efficient for all cases.
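To illustrate the concern raised in this review, here is a rough sketch that counts element visits for a per-ID removal loop versus a single pass over a set. This is not the reviewed code; both functions and the integer stand-ins for replicas are assumptions made for the example:

```python
from typing import List, Set, Tuple


def remove_one_at_a_time(items: List[int], targets: Set[int]) -> Tuple[List[int], int]:
    """Per-ID removal: one scan of the list per target, O(k * N) visits."""
    visits = 0
    remaining = list(items)
    for target in sorted(targets):  # sorted only to make the count deterministic
        for index, item in enumerate(remaining):
            visits += 1
            if item == target:
                del remaining[index]
                break
    return remaining, visits


def remove_in_one_pass(items: List[int], targets: Set[int]) -> Tuple[List[int], int]:
    """Batch removal: one scan with O(1) set lookups, O(N) visits."""
    visits = 0
    remaining = []
    for item in items:
        visits += 1
        if item not in targets:
            remaining.append(item)
    return remaining, visits


items = list(range(1000))
targets = set(range(500, 1000))  # stop the second half of the replicas
slow_remaining, slow_visits = remove_one_at_a_time(items, targets)
fast_remaining, fast_visits = remove_in_one_pass(items, targets)
print(slow_visits, fast_visits)
```

Both variants leave the same survivors; only the amount of scanning differs, which is why the set-based `remove(replica_ids)` form avoids the regression for large k.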

Comment thread python/ray/serve/_private/deployment_state.py Outdated
Signed-off-by: abrar <abrar@anyscale.com>
@abrarsheikh abrarsheikh added the go add ONLY when ready to merge, run all tests label Feb 7, 2026
Comment on lines +1650 to +1667
replica_ids = set(replica_ids)
removed = []
remaining_to_find = len(replica_ids)
for state in ALL_REPLICA_STATES:
    if remaining_to_find == 0:
        break
    found_any = False
    remaining = []
    for replica in self._replicas[state]:
        if remaining_to_find > 0 and replica.replica_id in replica_ids:
            removed.append(replica)
            remaining_to_find -= 1
            found_any = True
        else:
            remaining.append(replica)
    if found_any:
        self._replicas[state] = remaining
return removed

@harshit-anyscale harshit-anyscale Feb 9, 2026

Contributor


This would still require us to iterate over the states first, then over each replica, and then rebuild the per-state lists.

Would it be beneficial to change _replicas to Dict[ReplicaState, Dict[ReplicaID, DeploymentReplica]], so that we can find a replica ID in O(1) and then remove it? It might add some overhead as well; can we check how much?

Contributor Author


Maybe there is optimization to be had there, but I prefer taking baby steps towards that. Will take that as a follow-up. Changing the data structure of _replicas also requires rewriting all other functions in this class.
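For reference, a rough sketch of what the dict-of-dicts layout suggested in this thread could look like. This is not part of the PR: `IndexedContainer`, its methods, and the string stand-ins for states and replicas are all hypothetical, and the overhead question raised above is not measured here:

```python
from collections import defaultdict
from typing import Dict, Hashable, Iterable, List


class IndexedContainer:
    """Hypothetical container keyed by state, then by replica ID."""

    def __init__(self) -> None:
        self._replicas: Dict[str, Dict[Hashable, object]] = defaultdict(dict)

    def add(self, state: str, replica_id: Hashable, replica: object) -> None:
        self._replicas[state][replica_id] = replica

    def remove(self, replica_ids: Iterable[Hashable]) -> List[object]:
        """Remove by ID with one dict lookup per state: O(k * S), not O(N)."""
        removed: List[object] = []
        for replica_id in replica_ids:
            for by_id in self._replicas.values():
                if replica_id in by_id:
                    removed.append(by_id.pop(replica_id))
                    break  # an ID lives in at most one state
        return removed

    def count(self, state: str) -> int:
        return len(self._replicas[state])


c = IndexedContainer()
for i in range(5):
    c.add("RUNNING", f"r-{i}", f"replica-{i}")
c.add("STARTING", "r-5", "replica-5")

removed = c.remove(["r-1", "r-5", "missing"])
print(removed, c.count("RUNNING"), c.count("STARTING"))
```

Python dicts preserve insertion order, so methods that rely on list order could still iterate replicas in the order they were added, but every other method of the class would need to be adapted, as noted in the reply above.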

@abrarsheikh abrarsheikh merged commit 3b69ec3 into master Feb 9, 2026
7 checks passed
@abrarsheikh abrarsheikh deleted the optimize-stop-replicas branch February 9, 2026 17:57
ans9868 pushed a commit to ans9868/ray that referenced this pull request Feb 18, 2026
…roject#60832)


Signed-off-by: abrar <abrar@anyscale.com>
Signed-off-by: Adel Nour <ans9868@nyu.edu>
Aydin-ab pushed a commit to kunling-anyscale/ray that referenced this pull request Feb 20, 2026
…roject#60832)

`stop_replicas()` pops **every** replica across all replica states, checks set
membership, and re-adds the vast majority back. Each re-add triggers
`update_actor_details()` which rebuilds a `ReplicaDetails` pydantic
object. When stopping 2 out of 4096 replicas, 4094 replicas get
needlessly popped, rebuilt, and re-added.

### Fix

Add a `remove(replica_ids)` method to `ReplicaStateContainer` that
performs a single O(N) pass with O(1) set lookups. Non-matching replicas
stay in place — no re-add, no update_state call. Early-exits once all
targets are found, and only rebuilds the list for states where a match
was found.

### Benchmark results

<details>

<summary> Benchmark script - AI </summary>

```python
"""Micro-benchmark: stop_replicas() pop-all vs selective-remove.

Measures latency and peak memory when stopping a small fraction of replicas
from a ReplicaStateContainer, comparing the old approach (pop all + re-add)
against the new approach (selective remove by ID).

Usage:
    python bench_stop_replicas.py
"""

import gc
import math
import random
import statistics
import time
import tracemalloc
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Set


# ---------------------------------------------------------------------------
# Lightweight stubs – just enough to exercise ReplicaStateContainer logic
# without importing all of Ray Serve.
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class DeploymentID:
    name: str
    app_name: str = "default"

    def __hash__(self):
        return hash((self.name, self.app_name))


@dataclass(frozen=True)
class ReplicaID:
    unique_id: str
    deployment_id: DeploymentID

    def __hash__(self):
        return hash((self.unique_id, self.deployment_id))


class ReplicaState(str, Enum):
    STARTING = "STARTING"
    UPDATING = "UPDATING"
    RECOVERING = "RECOVERING"
    RUNNING = "RUNNING"
    STOPPING = "STOPPING"
    PENDING_MIGRATION = "PENDING_MIGRATION"


ALL_REPLICA_STATES = list(ReplicaState)


class _StubActorDetails:
    """Minimal stand-in for ReplicaDetails (pydantic model in production).

    Deliberately uses __slots__ to keep memory footprint realistic but light.
    """

    __slots__ = ("state", "replica_id", "node_id")

    def __init__(self, state: ReplicaState, replica_id: str):
        self.state = state
        self.replica_id = replica_id
        self.node_id = "node-0"

    def dict(self):
        return {
            "state": self.state,
            "replica_id": self.replica_id,
            "node_id": self.node_id,
        }


class FakeReplica:
    """Minimal stand-in for DeploymentReplica.

    Implements only the surface area touched by the container and stop_replicas.
    """

    def __init__(self, replica_id: ReplicaID, state: ReplicaState):
        self._replica_id = replica_id
        self._actor_details = _StubActorDetails(state, replica_id.unique_id)
        self._update_state_calls = 0

    @property
    def replica_id(self) -> ReplicaID:
        return self._replica_id

    @property
    def actor_details(self):
        return self._actor_details

    def update_state(self, state: ReplicaState) -> None:
        """Mirrors update_actor_details; rebuilds the details object."""
        self._actor_details = _StubActorDetails(state, self._replica_id.unique_id)
        self._update_state_calls += 1


# ---------------------------------------------------------------------------
# Container implementations
# ---------------------------------------------------------------------------


class _ContainerBase:
    """Shared helpers."""

    def __init__(self):
        self._replicas: Dict[ReplicaState, List[FakeReplica]] = defaultdict(list)

    def add(self, state: ReplicaState, replica: FakeReplica):
        replica.update_state(state)
        self._replicas[state].append(replica)

    def pop(
        self,
        exclude_version=None,
        states=None,
        max_replicas=math.inf,
    ) -> List[FakeReplica]:
        if states is None:
            states = ALL_REPLICA_STATES
        replicas = []
        for state in states:
            popped = []
            remaining = []
            for replica in self._replicas[state]:
                if len(replicas) + len(popped) == max_replicas:
                    remaining.append(replica)
                else:
                    popped.append(replica)
            self._replicas[state] = remaining
            replicas.extend(popped)
        return replicas


class OldContainer(_ContainerBase):
    """Original stop_replicas: pop everything, re-add non-matching."""

    def stop_replicas(self, replicas_to_stop: Set[ReplicaID]):
        stopped = []
        for replica in self.pop():
            if replica.replica_id in replicas_to_stop:
                # In production this calls _stop_replica(); we just record it.
                stopped.append(replica)
            else:
                self.add(replica.actor_details.state, replica)
        return stopped


class NewContainer(_ContainerBase):
    """New stop_replicas: single-pass remove() by ID set."""

    def remove(self, replica_ids) -> List[FakeReplica]:
        replica_ids = set(replica_ids)
        removed = []
        remaining_to_find = len(replica_ids)
        for state in ALL_REPLICA_STATES:
            if remaining_to_find == 0:
                break
            found_any = False
            remaining = []
            for replica in self._replicas[state]:
                if remaining_to_find > 0 and replica.replica_id in replica_ids:
                    removed.append(replica)
                    remaining_to_find -= 1
                    found_any = True
                else:
                    remaining.append(replica)
            if found_any:
                self._replicas[state] = remaining
        return removed

    def stop_replicas(self, replicas_to_stop: Set[ReplicaID]):
        return self.remove(replicas_to_stop)


# ---------------------------------------------------------------------------
# Benchmark harness
# ---------------------------------------------------------------------------

DEPLOYMENT_ID = DeploymentID("bench-deploy", "bench-app")
WARMUP_ROUNDS = 3
MEASURE_ROUNDS = 20


def _make_replicas(n: int) -> List[FakeReplica]:
    return [
        FakeReplica(
            ReplicaID(f"r-{i}", DEPLOYMENT_ID),
            ReplicaState.RUNNING,
        )
        for i in range(n)
    ]


def _fill_container(container, replicas: List[FakeReplica]):
    for r in replicas:
        container.add(ReplicaState.RUNNING, r)


def _pick_targets(replicas: List[FakeReplica], k: int) -> Set[ReplicaID]:
    chosen = random.sample(replicas, k)
    return {r.replica_id for r in chosen}


def bench_latency(container_cls, n_replicas: int, num_to_stop: int) -> dict:
    """Returns dict with median, p99, mean latency in microseconds."""
    replicas_master = _make_replicas(n_replicas)
    targets = _pick_targets(replicas_master, num_to_stop)

    timings = []

    for rnd in range(WARMUP_ROUNDS + MEASURE_ROUNDS):
        c = container_cls()
        _fill_container(c, replicas_master)

        gc.disable()
        t0 = time.perf_counter_ns()
        c.stop_replicas(targets)
        t1 = time.perf_counter_ns()
        gc.enable()

        if rnd >= WARMUP_ROUNDS:
            timings.append((t1 - t0) / 1_000)  # ns -> µs

    timings.sort()
    return {
        "median_us": statistics.median(timings),
        "p99_us": timings[int(len(timings) * 0.99)],
        "mean_us": statistics.mean(timings),
    }


def bench_memory(container_cls, n_replicas: int, num_to_stop: int) -> dict:
    """Returns dict with the memory allocation delta in KiB.

    Sums the positive size diffs between two tracemalloc snapshots and takes
    the median over multiple runs to reduce noise.
    """
    NUM_MEM_RUNS = 5
    deltas = []
    for _ in range(NUM_MEM_RUNS):
        replicas_master = _make_replicas(n_replicas)
        targets = _pick_targets(replicas_master, num_to_stop)

        c = container_cls()
        _fill_container(c, replicas_master)

        gc.collect()
        tracemalloc.start()

        snap_before = tracemalloc.take_snapshot()
        c.stop_replicas(targets)
        snap_after = tracemalloc.take_snapshot()

        tracemalloc.stop()

        stats = snap_after.compare_to(snap_before, "lineno")
        delta_bytes = sum(s.size_diff for s in stats if s.size_diff > 0)
        deltas.append(delta_bytes / 1024)

    return {"peak_delta_kib": statistics.median(deltas)}


def bench_update_state_calls(n_replicas: int, num_to_stop: int) -> dict:
    """Returns update_state call counts for old vs new."""
    replicas_old = _make_replicas(n_replicas)
    replicas_new = _make_replicas(n_replicas)
    targets = _pick_targets(replicas_old, num_to_stop)

    c_old = OldContainer()
    _fill_container(c_old, replicas_old)
    for r in replicas_old:
        r._update_state_calls = 0

    c_new = NewContainer()
    _fill_container(c_new, replicas_new)
    for r in replicas_new:
        r._update_state_calls = 0

    c_old.stop_replicas(targets)
    c_new.stop_replicas(targets)

    old_calls = sum(r._update_state_calls for r in replicas_old)
    new_calls = sum(r._update_state_calls for r in replicas_new)
    return {"old": old_calls, "new": new_calls}


def run_scenario(label: str, replica_counts: List[int], stop_fn):
    """Run a full latency + memory + update_state scenario.

    Args:
        label: human-readable description (e.g. "stopping 2 replicas").
        replica_counts: list of total replica counts to sweep.
        stop_fn: callable(n) -> number of replicas to stop for that n.
    """
    hdr = (
        f"{'Replicas':>8} {'k':>5}"
        f"  │ {'Old µs':>9} {'New µs':>9} {'Speedup':>7}"
        f"  │ {'Old KiB':>8} {'New KiB':>8} {'Saved':>8}"
        f"  │ {'Old upd':>8} {'New upd':>8}"
    )
    sep = "─" * len(hdr)

    print()
    print(f"  {label}")
    print(sep)
    print(hdr)
    print(sep)

    for n in replica_counts:
        k = stop_fn(n)
        old_lat = bench_latency(OldContainer, n, k)
        new_lat = bench_latency(NewContainer, n, k)
        speedup = old_lat["median_us"] / new_lat["median_us"] if new_lat["median_us"] > 0 else float("inf")
        old_mem = bench_memory(OldContainer, n, k)
        new_mem = bench_memory(NewContainer, n, k)
        saved = old_mem["peak_delta_kib"] - new_mem["peak_delta_kib"]
        us = bench_update_state_calls(n, k)

        print(
            f"{n:>8} {k:>5}"
            f"  │ {old_lat['median_us']:>8.1f}µ {new_lat['median_us']:>8.1f}µ {speedup:>6.1f}x"
            f"  │ {old_mem['peak_delta_kib']:>7.1f}  {new_mem['peak_delta_kib']:>7.1f} {saved:>7.1f}"
            f"  │ {us['old']:>8,} {us['new']:>8,}"
        )

    print(sep)


def main():
    replica_counts = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]

    print()
    print("=" * 72)
    print("  stop_replicas() micro-benchmark")
    print("  Old = pop all + re-add non-matching  (original)")
    print("  New = single-pass remove by ID set   (optimized)")
    print("=" * 72)

    # Scenario 1: stop 2 replicas (typical downscale)
    run_scenario(
        "SCENARIO 1: Downscale — stop 2 out of N  (typical)",
        replica_counts,
        stop_fn=lambda n: 2,
    )

    # Scenario 2: stop all replicas (full teardown)
    run_scenario(
        "SCENARIO 2: Teardown — stop ALL N replicas  (worst case for old per-ID approach)",
        replica_counts,
        stop_fn=lambda n: n,
    )

    # Scenario 3: stop 10% of replicas
    run_scenario(
        "SCENARIO 3: Moderate downscale — stop 10% of N",
        [64, 128, 256, 512, 1024, 2048, 4096],
        stop_fn=lambda n: max(1, n // 10),
    )

    print()


if __name__ == "__main__":
    main()

```

</details>

**Scenario 1 — Downscale: stop 2 out of N (typical)**

| Replicas | k | Old µs | New µs | Speedup | Old KiB | New KiB | Saved KiB | Old `upd_state` | New `upd_state` |
|----------|---|--------|--------|---------|---------|---------|-----------|-----------------|-----------------|
| 64 | 2 | 53.6 | 8.4 | 6.4x | 5.1 | 1.2 | 3.8 | 62 | 0 |
| 256 | 2 | 211.3 | 62.4 | 3.4x | 17.2 | 2.8 | 14.3 | 254 | 0 |
| 1024 | 2 | 852.2 | 295.2 | 2.9x | 65.7 | 9.3 | 56.3 | 1,022 | 0 |
| 4096 | 2 | 3402.1 | 586.7 | 5.8x | 257.3 | 32.9 | 224.3 | 4,094 | 0 |

**Scenario 2 — Teardown: stop ALL N (no regression)**

| Replicas | k | Old µs | New µs | Speedup | Old KiB | New KiB | Saved KiB |
|----------|---|--------|--------|---------|---------|---------|-----------|
| 256 | 256 | 98.8 | 84.6 | 1.2x | 1.2 | 0.7 | 0.5 |
| 1024 | 1024 | 402.0 | 340.6 | 1.2x | 1.2 | 0.7 | 0.5 |
| 4096 | 4096 | 1635.7 | 1378.4 | 1.2x | 1.2 | 0.7 | 0.5 |

Memory is flat and nearly identical for both — when stopping everything,
neither approach re-adds anything.

**Scenario 3 — Moderate downscale: stop 10% of N**

| Replicas | k | Old µs | New µs | Speedup | Old KiB | New KiB | Saved KiB | Old `upd_state` | New `upd_state` |
|----------|---|--------|--------|---------|---------|---------|-----------|-----------------|-----------------|
| 256 | 25 | 205.3 | 75.9 | 2.7x | 15.6 | 2.5 | 13.1 | 231 | 0 |
| 1024 | 102 | 828.2 | 319.5 | 2.6x | 59.2 | 8.3 | 50.9 | 922 | 0 |
| 4096 | 409 | 3301.4 | 1299.3 | 2.5x | 235.0 | 32.9 | 202.1 | 3,687 | 0 |
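
The Old `upd_state` column in the tables above equals N − k: under the pop-all approach every replica that is not a stop target gets re-added once. A small sketch of that arithmetic follows, with `count_readds_pop_all` as a hypothetical helper and made-up string IDs standing in for replica IDs.

```python
from typing import Iterable, List


def count_readds_pop_all(all_ids: List[str], ids_to_stop: Iterable[str]) -> int:
    """Count the replicas the pop-all approach would pop and then re-add."""
    targets = set(ids_to_stop)
    return sum(1 for replica_id in all_ids if replica_id not in targets)


ids = [f"r-{i}" for i in range(4096)]
downscale_two = count_readds_pop_all(ids, ids[:2])      # Scenario 1, N=4096
downscale_10pct = count_readds_pop_all(ids, ids[:409])  # Scenario 3, N=4096
teardown = count_readds_pop_all(ids, ids)               # Scenario 2, N=4096
```

With k = N the count drops to zero, which is consistent with the teardown scenario showing no meaningful difference between the two approaches.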

The old code's memory grows linearly with N because `pop()` allocates a
temporary list of all replicas, then `add()` rebuilds `ReplicaDetails`
for each re-inserted one. The new code only allocates a `remaining` list
for states where a match is found, and never touches non-matching
replicas.
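
The benchmark script sums positive snapshot diffs. As a hedged alternative for checking the temporary-list effect described above, `tracemalloc.get_traced_memory()` reports the peak traced size directly. The helper name `measure_peak_kib` and the toy workloads below are assumptions for illustration, not part of the PR, and `tracemalloc.reset_peak()` requires Python 3.9 or newer.

```python
import tracemalloc
from typing import Callable


def measure_peak_kib(fn: Callable[[], object]) -> float:
    """Return the peak traced allocation (KiB) above the baseline while fn runs."""
    tracemalloc.start()
    try:
        baseline, _ = tracemalloc.get_traced_memory()
        tracemalloc.reset_peak()  # measure the peak from this point on
        fn()
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return max(0, peak - baseline) / 1024


items = list(range(10_000))
# Copying every element mimics building a temporary list of all replicas.
copy_all_kib = measure_peak_kib(lambda: [x for x in items])
# Doing nothing mimics leaving non-matching replicas in place.
no_copy_kib = measure_peak_kib(lambda: None)
```

Absolute numbers depend on the interpreter and platform, so only the relative difference between the two measurements is meaningful here.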


related to ray-project#60680

---------

Signed-off-by: abrar <abrar@anyscale.com>
