Skip to content

[Data] Speed up checkpoint filter and reduce memory usage#60294

Merged
richardliaw merged 5 commits into
ray-project:masterfrom
wxwmd:global_ckpt_filter
Apr 16, 2026
Merged

[Data] Speed up checkpoint filter and reduce memory usage#60294
richardliaw merged 5 commits into
ray-project:masterfrom
wxwmd:global_ckpt_filter

Conversation

@wxwmd

@wxwmd wxwmd commented Jan 19, 2026

Copy link
Copy Markdown
Contributor

source code for issue #60200

Current checkpoint:

Image

The current implementation has two issues:

  1. Each ReadTask copies an Arrow-typed checkpoint_id array and then converts it into a Numpy-typed array. This step is very time-consuming(see previous testing) The most time-consuming operation is repeated in every ReadTask.
  2. Each ReadTask holds a copy of the checkpoint_id array, resulting in high memory usage of the cluster.

Improved Checkpoint (Initial design, single actor):

Maintain a global checkpoint_filter actor that holds the checkpoint_ids array; this actor is responsible for filtering all input blocks.

Image

There are two advantages to this approach:

  1. The most time-consuming operation: the conversion from Arrow-typed array to Numpy-typed array is performed only once.
  2. Reduced memory usage: Each read task no longer needs to hold a large array; only the checkpoint_filter actor holds it.

Performance test

test code:

import shutil
from typing import Dict
import os
import time

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import ray
from ray.data.checkpoint import CheckpointConfig

INPUT_PATH="/tmp/ray_test/input/"
OUTPUT_PATH="/tmp/ray_test/output/"
CKPT_PATH="/tmp/ray_test/ckpt/"


class Qwen3ASRPredictor:
    def __init__(self):
        print("download ckpt")

    def __call__(self, batch_input: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        return batch_input

def setup():
    if os.path.exists(INPUT_PATH):
        shutil.rmtree(INPUT_PATH)
    if os.path.exists(CKPT_PATH):
        shutil.rmtree(CKPT_PATH)
    if os.path.exists(OUTPUT_PATH):
        shutil.rmtree(OUTPUT_PATH)

    # generate input data
    if not os.path.exists(INPUT_PATH):
        os.makedirs(INPUT_PATH)
    for i in range(10000):
        ids = [str(i) for i in range(i * 10000, (i + 1) * 10000)]
        df = pd.DataFrame({'id': ids})
        table = pa.Table.from_pandas(df)
        pq.write_table(table, os.path.join(INPUT_PATH, f"{i}.parquet"))

    # generate checkpoint
    if not os.path.exists(CKPT_PATH):
        os.makedirs(CKPT_PATH)
    ids = [str(i) for i in range(0, 80_000_000)]
    df = pd.DataFrame({'id': ids})
    table = pa.Table.from_pandas(df)
    pq.write_table(table, os.path.join(CKPT_PATH, "ckpt.parquet"))



if __name__ == "__main__":
    ray.init()

    setup()

    ctx = ray.data.DataContext.get_current()
    ctx.checkpoint_config = CheckpointConfig(
        id_column="id",
        checkpoint_path=CKPT_PATH,
        delete_checkpoint_on_success=False,
    )

    start_time = time.time()

    input = ray.data.read_parquet(
        INPUT_PATH,
        parallelism=1000,
        # memory=8 * 1024 **3 # set for origin ray to avoid oom
    )

    pred = input.map_batches(Qwen3ASRPredictor, batch_size=1000)

    pred.write_parquet(OUTPUT_PATH)

    end_time = time.time()

    print(f"costs: {end_time - start_time}s")

    # check result
    result_ds = ray.data.read_parquet(OUTPUT_PATH)
    assert result_ds.count() == 20_000_000

node: 16 cores with 64GB memory (make sure you have memory at least 16GB to avoid oom)

origin ray:

pip install ray==2.54.0
python test.py

Speedup:

pip install https://ray-wheel.oss-cn-beijing.aliyuncs.com/speedup/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl
python test.py

Test Result

origin: 680s
speedup: 190s
You can see that even the end2end running time of the task has been accelerated by 3.6 times.

Memory

If we delete this row:

memory=8 * 1024 **3 # set for origin ray to avoid oom

original ray will oom, the fixed ray passed. This demonstrates that this PR has enhanced the stability.


Updated 20260225 (ActorPool)

As @owenowenisme methoned, if filtering is performed by a single actor, the single actor could be the bottleneck. Therefore, I extended a single Actor into an ActorPool. For more details, please refer to the link. #60294 (comment)

20260223162250

@wxwmd wxwmd requested a review from a team as a code owner January 19, 2026 11:54

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the checkpoint filtering mechanism to use a global actor that holds checkpointed IDs as a NumPy array. This is a significant improvement for memory efficiency and performance by avoiding passing large checkpoint data to each task. The changes to the build and CI scripts seem appropriate for the internal environment.

My review identified a critical issue where the batch-based filtering path (filter_rows_for_batch and its caller) was not updated to align with the new actor-based architecture, which will lead to runtime errors. I also found a medium-severity performance issue due to using ray.get() inside a loop and a minor issue with a leftover debug print statement.

Comment thread python/ray/data/checkpoint/checkpoint_filter.py Outdated
Comment thread python/ray/data/checkpoint/util.py Outdated
cursor[bot]

This comment was marked as outdated.

cursor[bot]

This comment was marked as outdated.

@ray-gardener ray-gardener Bot added data Ray Data-related issues community-contribution Contributed by the community labels Jan 19, 2026
@wxwmd wxwmd force-pushed the global_ckpt_filter branch from b971082 to b2fc954 Compare January 20, 2026 03:33
cursor[bot]

This comment was marked as outdated.

@daiping8

daiping8 commented Jan 20, 2026

Copy link
Copy Markdown
Contributor

Great job! I have a few questions:

  • A single global actor may become a bottleneck.
    All Read-related filtering requests go through the same BatchBasedCheckpointFilter actor. Could this cause filtering requests to queue up and stall the entire read pipeline on checkpoint filtering?
  • checkpointed_ids is fully materialized as a numpy array and kept resident in memory.
    We call combine_chunks and then to_numpy once to build a single large ndarray.
    If the checkpoint is large, the actor process must have enough contiguous memory to hold the entire ID column.

Maybe we could turn the actor into a sharded design (multiple actors, partitioned by hash(id) or by ID ranges), and support a “partial loading + partial filtering” mode instead of materializing the entire ndarray at once.

@wxwmd

wxwmd commented Jan 21, 2026

Copy link
Copy Markdown
Contributor Author

Great job! I have a few questions:

Thanks.

  1. In my test, filtering requests are processed very fast. If checkpoint has 115millon rows, each block has 10k+ rows, each filtering request can be processed in 0.2s. See my log:
image
@wxwmd

wxwmd commented Jan 21, 2026

Copy link
Copy Markdown
Contributor Author

Great job! I have a few questions:

  • A single global actor may become a bottleneck.
    All Read-related filtering requests go through the same BatchBasedCheckpointFilter actor. Could this cause filtering requests to queue up and stall the entire read pipeline on checkpoint filtering?
  • checkpointed_ids is fully materialized as a numpy array and kept resident in memory.
    We call combine_chunks and then to_numpy once to build a single large ndarray.
    If the checkpoint is large, the actor process must have enough contiguous memory to hold the entire ID column.

Maybe we could turn the actor into a sharded design (multiple actors, partitioned by hash(id) or by ID ranges), and support a “partial loading + partial filtering” mode instead of materializing the entire ndarray at once.

  1. I think that is a good idea. I am interested in implementing in future
@wxwmd wxwmd force-pushed the global_ckpt_filter branch from b2fc954 to eb93d53 Compare January 23, 2026 08:43
cursor[bot]

This comment was marked as outdated.

@wxwmd

wxwmd commented Jan 23, 2026

Copy link
Copy Markdown
Contributor Author

@owenowenisme @raulchen please check this when you have time 😊

@wxwmd wxwmd force-pushed the global_ckpt_filter branch from eb93d53 to 0276b22 Compare January 23, 2026 09:16
Comment thread python/ray/data/checkpoint/util.py Outdated
Comment thread python/ray/data/checkpoint/util.py Outdated
Comment thread python/ray/data/checkpoint/load_checkpoint_callback.py Outdated
Comment thread python/ray/data/checkpoint/load_checkpoint_callback.py Outdated

@owenowenisme owenowenisme left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for reviewing this so late, and thanks for the beautiful diagram!

When I was reviewing your global actor approach, I have another idea. The actor introduces a serial bottleneck — every read task has to ship its block to the single actor for filtering and wait for the result back. Without max_concurrency, calls are processed one at a time, which could be a significant throughput regression from the old design where each worker filtered locally in parallel.

Instead, what if we keep the filtering local in each worker but broadcast the checkpoint IDs as a numpy array via the object store?
The approach would be:

  1. Load checkpoint data and convert to a sorted numpy array (your PR already does this in _postprocess_block — nice work on that part!)
  2. Use a remote task to do the heavy conversion, then ray.put() the numpy array into the object store
  3. Pass the ObjectRef to each read task via add_map_task_kwargs_fn (the old mechanism)
  4. Each worker calls ray.get(ref) to get a zero-copy read-only view from the local object store, then does searchsorted locally

This gives us:

  • Parallelism: filtering is parallel across all workers, no bottleneck
  • Memory efficiency: Ray's object store stores one copy per node in shared memory, and all workers on the same node share it via zero-copy
  • Minimize the re-computation of converting arrow blocks into numpy
  • Simplicity: No actor needed
@wxwmd

wxwmd commented Feb 10, 2026

Copy link
Copy Markdown
Contributor Author

Sorry for reviewing this so late, and thanks for the beautiful diagram!

When I was reviewing your global actor approach, I have another idea. The actor introduces a serial bottleneck — every read task has to ship its block to the single actor for filtering and wait for the result back. Without max_concurrency, calls are processed one at a time, which could be a significant throughput regression from the old design where each worker filtered locally in parallel.

Instead, what if we keep the filtering local in each worker but broadcast the checkpoint IDs as a numpy array via the object store? The approach would be:

  1. Load checkpoint data and convert to a sorted numpy array (your PR already does this in _postprocess_block — nice work on that part!)
  2. Use a remote task to do the heavy conversion, then ray.put() the numpy array into the object store
  3. Pass the ObjectRef to each read task via add_map_task_kwargs_fn (the old mechanism)
  4. Each worker calls ray.get(ref) to get a zero-copy read-only view from the local object store, then does searchsorted locally

This gives us:

  • Parallelism: filtering is parallel across all workers, no bottleneck
  • Memory efficiency: Ray's object store stores one copy per node in shared memory, and all workers on the same node share it via zero-copy
  • Minimize the re-computation of converting arrow blocks into numpy
  • Simplicity: No actor needed

hi youcheng, thanks for reviewing!

When I first solve this problem, I had the same idea as you: perform the Arrow->NumPy only once, then broadcast that NumPy array.

I implemented and tested this approach, and it has one issue: the NumPy array is too large.
For example, I have 100 million string IDs; storing them with Arrow takes 2 GB, but with NumPy it takes about ~10 GB. See the demo below:

import sys

import numpy as np

N = 10000_000 # set to 1kw to avoid oom
arr = np.array([f"text_{i}" for i in range(N)])

mem = arr.nbytes + sum(sys.getsizeof(s) for s in arr)

print(f"the 1kw arr costs {mem / 1024**3}GB, array of size 10000_0000 will costs {10 * mem / 1024**3}GB")

having each worker keep a ~10 GB object in memory is unacceptable. Our cluster has about 1,000 nodes, which means roughly 10,000 GB of memory would be used only for checkpoint.
This is also the second issue my diagram aims to address: redundant memory usage.

@owenowenisme

owenowenisme commented Feb 10, 2026

Copy link
Copy Markdown
Member

Got it, I think this is valid, one problem is that we should avoid this actor becoming bottleneck.
Do you have any plan to avoid this?
Also this could affect our back pressure right?

@wxwmd

wxwmd commented Feb 11, 2026

Copy link
Copy Markdown
Contributor Author

Got it, I think this is valid, one problem is that we should avoid this actor becoming bottleneck. Do you have any plan to avoid this? Also this could affect our back pressure right?

Yes, I will use concurrency to enhance the actor in this PR.
As for backpressure: yes, the actor’s speed is the limiting factor, some read tasks will have to wait. However, based on the experiments above, the read task is still faster than it is now.

@wxwmd

wxwmd commented Feb 11, 2026

Copy link
Copy Markdown
Contributor Author

Got it, I think this is valid, one problem is that we should avoid this actor becoming bottleneck. Do you have any plan to avoid this? Also this could affect our back pressure right?

@owenowenisme what do you think if we implement a checkpoint actor-pool? i think this will solve the single-actor bottleneck

@wxwmd wxwmd force-pushed the global_ckpt_filter branch from e7ed29d to b345fd4 Compare February 23, 2026 07:55
Comment thread python/ray/data/checkpoint/checkpoint_filter.py
@wxwmd wxwmd force-pushed the global_ckpt_filter branch from b345fd4 to d10e010 Compare February 23, 2026 08:11
@wxwmd wxwmd force-pushed the global_ckpt_filter branch from 20a35dc to 067e6df Compare March 24, 2026 10:51
Comment thread python/ray/data/tests/test_checkpoint.py
@wxwmd wxwmd force-pushed the global_ckpt_filter branch from 067e6df to c0a1b86 Compare March 25, 2026 02:04
Comment thread python/ray/data/_internal/planner/planner.py
@wxwmd wxwmd force-pushed the global_ckpt_filter branch from c0a1b86 to e1b7c8e Compare March 25, 2026 02:14
Comment thread python/ray/data/checkpoint/checkpoint_filter.py Outdated
@wxwmd wxwmd force-pushed the global_ckpt_filter branch 2 times, most recently from d8b0ea0 to 95ddfd2 Compare March 25, 2026 12:27
Comment thread python/ray/data/checkpoint/checkpoint_filter.py Outdated
ckpt_chunks = combined_checkpointed_ids[id_column].chunks

checkpointed_ids_ndarray = []
for ckpt_chunk in ckpt_chunks:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then why not just loop the chunks? Is there any benefit to combine block?

ckpt_chunks = combined_checkpointed_ids[id_column].chunks

checkpointed_ids_ndarray = []
for ckpt_chunk in ckpt_chunks:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait maybe we don't need to handle those chunks and concat by ourselves, I think transform_pyarrow.to_numpy can already handle that?

Comment thread python/ray/data/checkpoint/checkpoint_filter.py
Comment thread python/ray/data/checkpoint/checkpoint_filter.py Outdated
Comment thread python/ray/data/checkpoint/checkpoint_filter.py Outdated
Comment thread python/ray/data/_internal/planner/checkpoint/plan_read_op.py Outdated
@owenowenisme

Copy link
Copy Markdown
Member

Also currently users cannot know we are loading checkpoint just by looking at the log and they might think pipeline is stalled, we can do something like adding a log before loading starts: logger.info("Loading checkpoint from ...")

Comment thread python/ray/data/tests/test_checkpoint.py Outdated
@wxwmd wxwmd force-pushed the global_ckpt_filter branch from 073b21e to 701f561 Compare April 3, 2026 03:50
Comment thread python/ray/data/checkpoint/checkpoint_filter.py
@wxwmd wxwmd force-pushed the global_ckpt_filter branch from 701f561 to 53b6ffd Compare April 3, 2026 04:00
@wxwmd

wxwmd commented Apr 3, 2026

Copy link
Copy Markdown
Contributor Author

Also currently users cannot know we are loading checkpoint just by looking at the log and they might think pipeline is stalled, we can do something like adding a log before loading starts: logger.info("Loading checkpoint from ...")

done

xiaowen.wxw added 3 commits April 3, 2026 14:12
Signed-off-by: xiaowen.wxw <wxw403883@alibaba-inc.com>

# Conflicts:
#	python/ray/data/_internal/planner/planner.py
#	python/ray/data/checkpoint/load_checkpoint_callback.py
Signed-off-by: xiaowen.wxw <wxw403883@alibaba-inc.com>
Signed-off-by: xiaowen.wxw <wxw403883@alibaba-inc.com>
@wxwmd wxwmd force-pushed the global_ckpt_filter branch from 53b6ffd to 40255c8 Compare April 3, 2026 06:18
@wxwmd

wxwmd commented Apr 9, 2026

Copy link
Copy Markdown
Contributor Author

@owenowenisme owenowenisme left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for the thorough work on this PR and for iterating on the review feedback.
This is a hella huge improvement to our ray data checkpoint efficiency and stability.

Thanks @wxwmd for the design and driving this!

Comment thread python/ray/data/context.py Outdated
Comment on lines +829 to +834
checkpoint_actor_pool_min_size: int = DEFAULT_CHECKPOINT_ACTOR_POOL_MIN_SIZE

checkpoint_actor_pool_max_size: int = DEFAULT_CHECKPOINT_ACTOR_POOL_MAX_SIZE

checkpoint_actor_memory_bytes: int = DEFAULT_CHECKPOINT_ACTOR_MEMORY_BYTES

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we actually need to introduce these variables?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually shouldn't this go into _checkpoint_config?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, I previously discussed this with @xinyuangui2 , and we thinked that these variables are low-level — we don't want users to be aware of them in _checkpoint_config. At the same time, we want to keep an entry point for users to specify these variables if needed, so they are moved into DataContext.
#60294 (comment)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you put them as private variables into checkpoint_config? they still are exposed to the user anyhow

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you!

Signed-off-by: xiaowen.wxw <wxw403883@alibaba-inc.com>

@cursor cursor Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes and found 1 potential issue.

Fix All in Cursor

Reviewed by Cursor Bugbot for commit 53d39fc. Configure here.

Comment thread python/ray/data/checkpoint/checkpoint_filter.py
@richardliaw richardliaw merged commit 56f7ec1 into ray-project:master Apr 16, 2026
6 checks passed
HLDKNotFound pushed a commit to chichic21039/ray that referenced this pull request Apr 22, 2026
…t#60294)

> source code for issue
https://github.com/issues/created?issue=ray-project%7Cray%7C60200

### Current checkpoint:

<img width="2796" height="1198" alt="Image"
src="https://github.com/user-attachments/assets/528ad72b-6975-4e96-8f01-39e373990647"
/>

The current implementation has two issues:
1. Each ReadTask copies an Arrow-typed checkpoint_id array and then
converts it into a Numpy-typed array. This step is very
time-consuming(see [previous
testing](ray-project#60002)) The most
time-consuming operation is repeated in every ReadTask.
2. Each ReadTask holds a copy of the checkpoint_id array, resulting in
high memory usage of the cluster.


### Improved Checkpoint (Initial design, single actor):
Maintain a global `checkpoint_filter` actor that holds the
`checkpoint_ids` array; this actor is responsible for filtering all
input blocks.

<img width="2096" height="1278" alt="Image"
src="https://github.com/user-attachments/assets/b9956eff-c807-45c4-bc4c-f0497974370d"
/>

There are two advantages to this approach:
1. The most time-consuming operation: the conversion from Arrow-typed
array to Numpy-typed array is performed only once.
2. Reduced memory usage: Each read task no longer needs to hold a large
array; only the `checkpoint_filter `actor holds it.

### Performance test
test code:
```
import shutil
from typing import Dict
import os
import time

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import ray
from ray.data.checkpoint import CheckpointConfig

INPUT_PATH="/tmp/ray_test/input/"
OUTPUT_PATH="/tmp/ray_test/output/"
CKPT_PATH="/tmp/ray_test/ckpt/"


class Qwen3ASRPredictor:
    def __init__(self):
        print("download ckpt")

    def __call__(self, batch_input: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        return batch_input

def setup():
    if os.path.exists(INPUT_PATH):
        shutil.rmtree(INPUT_PATH)
    if os.path.exists(CKPT_PATH):
        shutil.rmtree(CKPT_PATH)
    if os.path.exists(OUTPUT_PATH):
        shutil.rmtree(OUTPUT_PATH)

    # generate input data
    if not os.path.exists(INPUT_PATH):
        os.makedirs(INPUT_PATH)
    for i in range(10000):
        ids = [str(i) for i in range(i * 10000, (i + 1) * 10000)]
        df = pd.DataFrame({'id': ids})
        table = pa.Table.from_pandas(df)
        pq.write_table(table, os.path.join(INPUT_PATH, f"{i}.parquet"))

    # generate checkpoint
    if not os.path.exists(CKPT_PATH):
        os.makedirs(CKPT_PATH)
    ids = [str(i) for i in range(0, 80_000_000)]
    df = pd.DataFrame({'id': ids})
    table = pa.Table.from_pandas(df)
    pq.write_table(table, os.path.join(CKPT_PATH, "ckpt.parquet"))



if __name__ == "__main__":
    ray.init()

    setup()

    ctx = ray.data.DataContext.get_current()
    ctx.checkpoint_config = CheckpointConfig(
        id_column="id",
        checkpoint_path=CKPT_PATH,
        delete_checkpoint_on_success=False,
    )

    start_time = time.time()

    input = ray.data.read_parquet(
        INPUT_PATH,
        parallelism=1000,
        # memory=8 * 1024 **3 # set for origin ray to avoid oom
    )

    pred = input.map_batches(Qwen3ASRPredictor, batch_size=1000)

    pred.write_parquet(OUTPUT_PATH)

    end_time = time.time()

    print(f"costs: {end_time - start_time}s")

    # check result
    result_ds = ray.data.read_parquet(OUTPUT_PATH)
    assert result_ds.count() == 20_000_000
```

node: 16 cores with 64GB memory (make sure you have memory at least 16GB
to avoid oom)

#### origin ray:
```
pip install ray==2.54.0
python test.py
```
#### Speedup:
```
pip install https://ray-wheel.oss-cn-beijing.aliyuncs.com/speedup/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl
python test.py
```

#### Test Result
origin: 680s
speedup: 190s
You can see that even the end2end running time of the task has been
accelerated by 3.6 times.

#### Memory
If we delete this row:
```
memory=8 * 1024 **3 # set for origin ray to avoid oom
```
original ray will oom, the fixed ray passed. This demonstrates that this
PR has enhanced the stability.


----
### Updated 20260225 (ActorPool)
As @owenowenisme methoned, if filtering is performed by a single actor,
the single actor could be the bottleneck. Therefore, I extended a single
Actor into an ActorPool. For more details, please refer to the link.
ray-project#60294 (comment)


![20260223162250](https://github.com/user-attachments/assets/9bd1067f-f2a8-47dd-8f99-e232be64155e)

---------

Co-authored-by: xiaowen.wxw <wxw403883@alibaba-inc.com>
Co-authored-by: You-Cheng Lin <106612301+owenowenisme@users.noreply.github.com>
Lucas61000 pushed a commit to Lucas61000/ray that referenced this pull request May 15, 2026
…t#60294)

> source code for issue
https://github.com/issues/created?issue=ray-project%7Cray%7C60200

### Current checkpoint:

<img width="2796" height="1198" alt="Image"
src="https://github.com/user-attachments/assets/528ad72b-6975-4e96-8f01-39e373990647"
/>

The current implementation has two issues:
1. Each ReadTask copies an Arrow-typed checkpoint_id array and then
converts it into a Numpy-typed array. This step is very
time-consuming(see [previous
testing](ray-project#60002)) The most
time-consuming operation is repeated in every ReadTask.
2. Each ReadTask holds a copy of the checkpoint_id array, resulting in
high memory usage of the cluster.


### Improved Checkpoint (Initial design, single actor):
Maintain a global `checkpoint_filter` actor that holds the
`checkpoint_ids` array; this actor is responsible for filtering all
input blocks.

<img width="2096" height="1278" alt="Image"
src="https://github.com/user-attachments/assets/b9956eff-c807-45c4-bc4c-f0497974370d"
/>

There are two advantages to this approach:
1. The most time-consuming operation: the conversion from Arrow-typed
array to Numpy-typed array is performed only once.
2. Reduced memory usage: Each read task no longer needs to hold a large
array; only the `checkpoint_filter `actor holds it.

### Performance test
test code:
```
import shutil
from typing import Dict
import os
import time

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import ray
from ray.data.checkpoint import CheckpointConfig

INPUT_PATH="/tmp/ray_test/input/"
OUTPUT_PATH="/tmp/ray_test/output/"
CKPT_PATH="/tmp/ray_test/ckpt/"


class Qwen3ASRPredictor:
    def __init__(self):
        print("download ckpt")

    def __call__(self, batch_input: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        return batch_input

def setup():
    if os.path.exists(INPUT_PATH):
        shutil.rmtree(INPUT_PATH)
    if os.path.exists(CKPT_PATH):
        shutil.rmtree(CKPT_PATH)
    if os.path.exists(OUTPUT_PATH):
        shutil.rmtree(OUTPUT_PATH)

    # generate input data
    if not os.path.exists(INPUT_PATH):
        os.makedirs(INPUT_PATH)
    for i in range(10000):
        ids = [str(i) for i in range(i * 10000, (i + 1) * 10000)]
        df = pd.DataFrame({'id': ids})
        table = pa.Table.from_pandas(df)
        pq.write_table(table, os.path.join(INPUT_PATH, f"{i}.parquet"))

    # generate checkpoint
    if not os.path.exists(CKPT_PATH):
        os.makedirs(CKPT_PATH)
    ids = [str(i) for i in range(0, 80_000_000)]
    df = pd.DataFrame({'id': ids})
    table = pa.Table.from_pandas(df)
    pq.write_table(table, os.path.join(CKPT_PATH, "ckpt.parquet"))



if __name__ == "__main__":
    ray.init()

    setup()

    ctx = ray.data.DataContext.get_current()
    ctx.checkpoint_config = CheckpointConfig(
        id_column="id",
        checkpoint_path=CKPT_PATH,
        delete_checkpoint_on_success=False,
    )

    start_time = time.time()

    input = ray.data.read_parquet(
        INPUT_PATH,
        parallelism=1000,
        # memory=8 * 1024 **3 # set for origin ray to avoid oom
    )

    pred = input.map_batches(Qwen3ASRPredictor, batch_size=1000)

    pred.write_parquet(OUTPUT_PATH)

    end_time = time.time()

    print(f"costs: {end_time - start_time}s")

    # check result
    result_ds = ray.data.read_parquet(OUTPUT_PATH)
    assert result_ds.count() == 20_000_000
```

node: 16 cores with 64GB memory (make sure you have memory at least 16GB
to avoid oom)

#### origin ray:
```
pip install ray==2.54.0
python test.py
```
#### Speedup:
```
pip install https://ray-wheel.oss-cn-beijing.aliyuncs.com/speedup/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl
python test.py
```

#### Test Result
origin: 680s
speedup: 190s
You can see that even the end2end running time of the task has been
accelerated by 3.6 times.

#### Memory
If we delete this row:
```
memory=8 * 1024 **3 # set for origin ray to avoid oom
```
original ray will oom, the fixed ray passed. This demonstrates that this
PR has enhanced the stability.


----
### Updated 20260225 (ActorPool)
As @owenowenisme methoned, if filtering is performed by a single actor,
the single actor could be the bottleneck. Therefore, I extended a single
Actor into an ActorPool. For more details, please refer to the link.
ray-project#60294 (comment)


![20260223162250](https://github.com/user-attachments/assets/9bd1067f-f2a8-47dd-8f99-e232be64155e)

---------

Co-authored-by: xiaowen.wxw <wxw403883@alibaba-inc.com>
Co-authored-by: You-Cheng Lin <106612301+owenowenisme@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution Contributed by the community data Ray Data-related issues go add ONLY when ready to merge, run all tests

7 participants