[Data] Speed up checkpoint filter and reduce memory usage#60294
Conversation
There was a problem hiding this comment.
Code Review
This pull request refactors the checkpoint filtering mechanism to use a global actor that holds checkpointed IDs as a NumPy array. This is a significant improvement for memory efficiency and performance by avoiding passing large checkpoint data to each task. The changes to the build and CI scripts seem appropriate for the internal environment.
My review identified a critical issue where the batch-based filtering path (filter_rows_for_batch and its caller) was not updated to align with the new actor-based architecture, which will lead to runtime errors. I also found a medium-severity performance issue due to using ray.get() inside a loop and a minor issue with a leftover debug print statement.
a71364b to
b971082
Compare
b971082 to
b2fc954
Compare
|
Great job! I have a few questions:
Maybe we could turn the actor into a sharded design (multiple actors, partitioned by |
|
b2fc954 to
eb93d53
Compare
|
@owenowenisme @raulchen please check this when you have time 😊 |
eb93d53 to
0276b22
Compare
0276b22 to
df3e7aa
Compare
There was a problem hiding this comment.
Sorry for reviewing this so late, and thanks for the beautiful diagram!
When I was reviewing your global actor approach, I have another idea. The actor introduces a serial bottleneck — every read task has to ship its block to the single actor for filtering and wait for the result back. Without max_concurrency, calls are processed one at a time, which could be a significant throughput regression from the old design where each worker filtered locally in parallel.
Instead, what if we keep the filtering local in each worker but broadcast the checkpoint IDs as a numpy array via the object store?
The approach would be:
- Load checkpoint data and convert to a sorted numpy array (your PR already does this in _postprocess_block — nice work on that part!)
- Use a remote task to do the heavy conversion, then
ray.put()the numpy array into the object store - Pass the ObjectRef to each read task via
add_map_task_kwargs_fn(the old mechanism) - Each worker calls
ray.get(ref)to get a zero-copy read-only view from the local object store, then does searchsorted locally
This gives us:
- Parallelism: filtering is parallel across all workers, no bottleneck
- Memory efficiency: Ray's object store stores one copy per node in shared memory, and all workers on the same node share it via zero-copy
- Minimize the re-computation of converting arrow blocks into numpy
- Simplicity: No actor needed
hi youcheng, thanks for reviewing! When I first solve this problem, I had the same idea as you: perform the I implemented and tested this approach, and it has one issue: the NumPy array is too large. having each worker keep a ~10 GB object in memory is unacceptable. Our cluster has about 1,000 nodes, which means roughly 10,000 GB of memory would be used only for checkpoint. |
|
Got it, I think this is valid, one problem is that we should avoid this actor becoming bottleneck. |
Yes, I will use concurrency to enhance the actor in this PR. |
@owenowenisme what do you think if we implement a checkpoint actor-pool? i think this will solve the single-actor bottleneck |
e7ed29d to
b345fd4
Compare
b345fd4 to
d10e010
Compare
20a35dc to
067e6df
Compare
067e6df to
c0a1b86
Compare
c0a1b86 to
e1b7c8e
Compare
d8b0ea0 to
95ddfd2
Compare
| ckpt_chunks = combined_checkpointed_ids[id_column].chunks | ||
|
|
||
| checkpointed_ids_ndarray = [] | ||
| for ckpt_chunk in ckpt_chunks: |
There was a problem hiding this comment.
Then why not just loop the chunks? Is there any benefit to combine block?
| ckpt_chunks = combined_checkpointed_ids[id_column].chunks | ||
|
|
||
| checkpointed_ids_ndarray = [] | ||
| for ckpt_chunk in ckpt_chunks: |
There was a problem hiding this comment.
Wait maybe we don't need to handle those chunks and concat by ourselves, I think transform_pyarrow.to_numpy can already handle that?
|
Also currently users cannot know we are loading checkpoint just by looking at the log and they might think pipeline is stalled, we can do something like adding a log before loading starts: |
073b21e to
701f561
Compare
701f561 to
53b6ffd
Compare
done |
Signed-off-by: xiaowen.wxw <wxw403883@alibaba-inc.com> # Conflicts: # python/ray/data/_internal/planner/planner.py # python/ray/data/checkpoint/load_checkpoint_callback.py
Signed-off-by: xiaowen.wxw <wxw403883@alibaba-inc.com>
Signed-off-by: xiaowen.wxw <wxw403883@alibaba-inc.com>
53b6ffd to
40255c8
Compare
|
@owenowenisme PTAL |
There was a problem hiding this comment.
LGTM! Thanks for the thorough work on this PR and for iterating on the review feedback.
This is a hella huge improvement to our ray data checkpoint efficiency and stability.
Thanks @wxwmd for the design and driving this!
| checkpoint_actor_pool_min_size: int = DEFAULT_CHECKPOINT_ACTOR_POOL_MIN_SIZE | ||
|
|
||
| checkpoint_actor_pool_max_size: int = DEFAULT_CHECKPOINT_ACTOR_POOL_MAX_SIZE | ||
|
|
||
| checkpoint_actor_memory_bytes: int = DEFAULT_CHECKPOINT_ACTOR_MEMORY_BYTES | ||
|
|
There was a problem hiding this comment.
Do we actually need to introduce these variables?
There was a problem hiding this comment.
actually shouldn't this go into _checkpoint_config?
There was a problem hiding this comment.
Hi, I previously discussed this with @xinyuangui2 , and we thinked that these variables are low-level — we don't want users to be aware of them in _checkpoint_config. At the same time, we want to keep an entry point for users to specify these variables if needed, so they are moved into DataContext.
#60294 (comment)
There was a problem hiding this comment.
can you put them as private variables into checkpoint_config? they still are exposed to the user anyhow
Signed-off-by: xiaowen.wxw <wxw403883@alibaba-inc.com>
There was a problem hiding this comment.
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Reviewed by Cursor Bugbot for commit 53d39fc. Configure here.
…t#60294) > source code for issue https://github.com/issues/created?issue=ray-project%7Cray%7C60200 ### Current checkpoint: <img width="2796" height="1198" alt="Image" src="https://github.com/user-attachments/assets/528ad72b-6975-4e96-8f01-39e373990647" /> The current implementation has two issues: 1. Each ReadTask copies an Arrow-typed checkpoint_id array and then converts it into a Numpy-typed array. This step is very time-consuming(see [previous testing](ray-project#60002)) The most time-consuming operation is repeated in every ReadTask. 2. Each ReadTask holds a copy of the checkpoint_id array, resulting in high memory usage of the cluster. ### Improved Checkpoint (Initial design, single actor): Maintain a global `checkpoint_filter` actor that holds the `checkpoint_ids` array; this actor is responsible for filtering all input blocks. <img width="2096" height="1278" alt="Image" src="https://github.com/user-attachments/assets/b9956eff-c807-45c4-bc4c-f0497974370d" /> There are two advantages to this approach: 1. The most time-consuming operation: the conversion from Arrow-typed array to Numpy-typed array is performed only once. 2. Reduced memory usage: Each read task no longer needs to hold a large array; only the `checkpoint_filter `actor holds it. ### Performance test test code: ``` import shutil from typing import Dict import os import time import numpy as np import pandas as pd import pyarrow as pa import pyarrow.parquet as pq import ray from ray.data.checkpoint import CheckpointConfig INPUT_PATH="/tmp/ray_test/input/" OUTPUT_PATH="/tmp/ray_test/output/" CKPT_PATH="/tmp/ray_test/ckpt/" class Qwen3ASRPredictor: def __init__(self): print("download ckpt") def __call__(self, batch_input: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: return batch_input def setup(): if os.path.exists(INPUT_PATH): shutil.rmtree(INPUT_PATH) if os.path.exists(CKPT_PATH): shutil.rmtree(CKPT_PATH) if os.path.exists(OUTPUT_PATH): shutil.rmtree(OUTPUT_PATH) # generate input data if not os.path.exists(INPUT_PATH): os.makedirs(INPUT_PATH) for i in range(10000): ids = [str(i) for i in range(i * 10000, (i + 1) * 10000)] df = pd.DataFrame({'id': ids}) table = pa.Table.from_pandas(df) pq.write_table(table, os.path.join(INPUT_PATH, f"{i}.parquet")) # generate checkpoint if not os.path.exists(CKPT_PATH): os.makedirs(CKPT_PATH) ids = [str(i) for i in range(0, 80_000_000)] df = pd.DataFrame({'id': ids}) table = pa.Table.from_pandas(df) pq.write_table(table, os.path.join(CKPT_PATH, "ckpt.parquet")) if __name__ == "__main__": ray.init() setup() ctx = ray.data.DataContext.get_current() ctx.checkpoint_config = CheckpointConfig( id_column="id", checkpoint_path=CKPT_PATH, delete_checkpoint_on_success=False, ) start_time = time.time() input = ray.data.read_parquet( INPUT_PATH, parallelism=1000, # memory=8 * 1024 **3 # set for origin ray to avoid oom ) pred = input.map_batches(Qwen3ASRPredictor, batch_size=1000) pred.write_parquet(OUTPUT_PATH) end_time = time.time() print(f"costs: {end_time - start_time}s") # check result result_ds = ray.data.read_parquet(OUTPUT_PATH) assert result_ds.count() == 20_000_000 ``` node: 16 cores with 64GB memory (make sure you have memory at least 16GB to avoid oom) #### origin ray: ``` pip install ray==2.54.0 python test.py ``` #### Speedup: ``` pip install https://ray-wheel.oss-cn-beijing.aliyuncs.com/speedup/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl python test.py ``` #### Test Result origin: 680s speedup: 190s You can see that even the end2end running time of the task has been accelerated by 3.6 times. #### Memory If we delete this row: ``` memory=8 * 1024 **3 # set for origin ray to avoid oom ``` original ray will oom, the fixed ray passed. This demonstrates that this PR has enhanced the stability. ---- ### Updated 20260225 (ActorPool) As @owenowenisme methoned, if filtering is performed by a single actor, the single actor could be the bottleneck. Therefore, I extended a single Actor into an ActorPool. For more details, please refer to the link. ray-project#60294 (comment)  --------- Co-authored-by: xiaowen.wxw <wxw403883@alibaba-inc.com> Co-authored-by: You-Cheng Lin <106612301+owenowenisme@users.noreply.github.com>
…t#60294) > source code for issue https://github.com/issues/created?issue=ray-project%7Cray%7C60200 ### Current checkpoint: <img width="2796" height="1198" alt="Image" src="https://github.com/user-attachments/assets/528ad72b-6975-4e96-8f01-39e373990647" /> The current implementation has two issues: 1. Each ReadTask copies an Arrow-typed checkpoint_id array and then converts it into a Numpy-typed array. This step is very time-consuming(see [previous testing](ray-project#60002)) The most time-consuming operation is repeated in every ReadTask. 2. Each ReadTask holds a copy of the checkpoint_id array, resulting in high memory usage of the cluster. ### Improved Checkpoint (Initial design, single actor): Maintain a global `checkpoint_filter` actor that holds the `checkpoint_ids` array; this actor is responsible for filtering all input blocks. <img width="2096" height="1278" alt="Image" src="https://github.com/user-attachments/assets/b9956eff-c807-45c4-bc4c-f0497974370d" /> There are two advantages to this approach: 1. The most time-consuming operation: the conversion from Arrow-typed array to Numpy-typed array is performed only once. 2. Reduced memory usage: Each read task no longer needs to hold a large array; only the `checkpoint_filter `actor holds it. ### Performance test test code: ``` import shutil from typing import Dict import os import time import numpy as np import pandas as pd import pyarrow as pa import pyarrow.parquet as pq import ray from ray.data.checkpoint import CheckpointConfig INPUT_PATH="/tmp/ray_test/input/" OUTPUT_PATH="/tmp/ray_test/output/" CKPT_PATH="/tmp/ray_test/ckpt/" class Qwen3ASRPredictor: def __init__(self): print("download ckpt") def __call__(self, batch_input: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: return batch_input def setup(): if os.path.exists(INPUT_PATH): shutil.rmtree(INPUT_PATH) if os.path.exists(CKPT_PATH): shutil.rmtree(CKPT_PATH) if os.path.exists(OUTPUT_PATH): shutil.rmtree(OUTPUT_PATH) # generate input data if not os.path.exists(INPUT_PATH): os.makedirs(INPUT_PATH) for i in range(10000): ids = [str(i) for i in range(i * 10000, (i + 1) * 10000)] df = pd.DataFrame({'id': ids}) table = pa.Table.from_pandas(df) pq.write_table(table, os.path.join(INPUT_PATH, f"{i}.parquet")) # generate checkpoint if not os.path.exists(CKPT_PATH): os.makedirs(CKPT_PATH) ids = [str(i) for i in range(0, 80_000_000)] df = pd.DataFrame({'id': ids}) table = pa.Table.from_pandas(df) pq.write_table(table, os.path.join(CKPT_PATH, "ckpt.parquet")) if __name__ == "__main__": ray.init() setup() ctx = ray.data.DataContext.get_current() ctx.checkpoint_config = CheckpointConfig( id_column="id", checkpoint_path=CKPT_PATH, delete_checkpoint_on_success=False, ) start_time = time.time() input = ray.data.read_parquet( INPUT_PATH, parallelism=1000, # memory=8 * 1024 **3 # set for origin ray to avoid oom ) pred = input.map_batches(Qwen3ASRPredictor, batch_size=1000) pred.write_parquet(OUTPUT_PATH) end_time = time.time() print(f"costs: {end_time - start_time}s") # check result result_ds = ray.data.read_parquet(OUTPUT_PATH) assert result_ds.count() == 20_000_000 ``` node: 16 cores with 64GB memory (make sure you have memory at least 16GB to avoid oom) #### origin ray: ``` pip install ray==2.54.0 python test.py ``` #### Speedup: ``` pip install https://ray-wheel.oss-cn-beijing.aliyuncs.com/speedup/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl python test.py ``` #### Test Result origin: 680s speedup: 190s You can see that even the end2end running time of the task has been accelerated by 3.6 times. #### Memory If we delete this row: ``` memory=8 * 1024 **3 # set for origin ray to avoid oom ``` original ray will oom, the fixed ray passed. This demonstrates that this PR has enhanced the stability. ---- ### Updated 20260225 (ActorPool) As @owenowenisme methoned, if filtering is performed by a single actor, the single actor could be the bottleneck. Therefore, I extended a single Actor into an ActorPool. For more details, please refer to the link. ray-project#60294 (comment)  --------- Co-authored-by: xiaowen.wxw <wxw403883@alibaba-inc.com> Co-authored-by: You-Cheng Lin <106612301+owenowenisme@users.noreply.github.com>


Current checkpoint:
The current implementation has two issues:
Improved Checkpoint (Initial design, single actor):
Maintain a global
checkpoint_filteractor that holds thecheckpoint_idsarray; this actor is responsible for filtering all input blocks.There are two advantages to this approach:
checkpoint_filteractor holds it.Performance test
test code:
node: 16 cores with 64GB memory (make sure you have memory at least 16GB to avoid oom)
origin ray:
Speedup:
Test Result
origin: 680s
speedup: 190s
You can see that even the end2end running time of the task has been accelerated by 3.6 times.
Memory
If we delete this row:
original ray will oom, the fixed ray passed. This demonstrates that this PR has enhanced the stability.
Updated 20260225 (ActorPool)
As @owenowenisme methoned, if filtering is performed by a single actor, the single actor could be the bottleneck. Therefore, I extended a single Actor into an ActorPool. For more details, please refer to the link. #60294 (comment)