[Data] boost hash_partition w/ sort_indices + zero-copy slices #63498
Merged
goutamvenkat-anyscale merged 6 commits on Jun 4, 2026
Code Review
This pull request optimizes the hash_partition function in python/ray/data/_internal/arrow_ops/transform_pyarrow.py by replacing multiple independent take operations with a single take followed by zero-copy slices. The implementation uses pyarrow.compute.sort_indices to sort the table by partition ID and numpy.bincount to determine partition boundaries. I have no feedback to provide as there were no review comments to evaluate.
goutamvenkat-anyscale approved these changes on Jun 4, 2026
rueian pushed a commit to rueian/ray that referenced this pull request on Jun 4, 2026
…roject#63498)

## Description

`hash_partition` previously did three expensive things in sequence (N = `num_partition`, R = `num_rows`):

1. Built per-partition index arrays via `N × np.where(part_ids == p)`: O(N · R) scans
2. Defragmented the input via `try_combine_chunked_columns(table)`: a full-table copy
3. Ran `N` independent `table.take(indices[p])` calls

This change replaces all three with:

1. `pyarrow.compute.sort_indices(partition_ids)`: radix sort on integers, one O(R) pass
2. One `take_table(table, sort_indices)` on the original (possibly chunked) input
3. `N` zero-copy `Table.slice()` calls

The `N` takes together form a permutation of the table, so consolidating them into one sort + N zero-copy slices is equivalent and strictly cheaper (fixed take overhead paid once instead of N times).

The defrag copy can also be removed: the original Arrow problem (apache/arrow#35126) is that every `take` on a chunked table internally concatenates all chunks first, so `try_combine_chunked_columns` exists to pay that concat once externally and let the subsequent N takes use the fast path. By calling `take` only once, the internal concat happens just once anyway, so the external defrag becomes redundant.

And because the take output already arranges each partition's rows contiguously, we can carve out the N partitions with zero-copy slices instead of materializing a second copy, which would be another 1 GB for a 1 GB input.

## Benchmark: 1GB block → 1000 partitions

Single thread, PyArrow 23.0.1. `K` = number of chunks in the input table; `K=256` mirrors realistic multi-chunk input.

| Block shape | K | Time before | Time after | Speedup | Peak Arrow before | Peak Arrow after |
|---|---|---|---|---|---|---|
| 16M rows × 8 int64 | 1 | 6542 ms | **939 ms** | **6.96×** | 1024 MB (no copy needed) | 1152 MB |
| 16M rows × 8 int64 | 256 | 6887 ms | **1057 ms** | **6.51×** | 2048 MB | **1280 MB** |
| 8M rows × 16 int64 | 1 | 4818 ms | **825 ms** | **5.84×** | 1024 MB (no copy needed) | 1088 MB |
| 8M rows × 16 int64 | 256 | 5086 ms | **982 ms** | **5.18×** | 2048 MB | **1152 MB** |
| 2M rows × 64 int64 | 1 | 2468 ms | **322 ms** | **7.66×** | 1026 MB (no copy needed) | 1040 MB |
| 2M rows × 64 int64 | 256 | 2216 ms | **369 ms** | **6.01×** | 2050 MB | **1056 MB** |

- **Throughput**: 5–8× faster across all shapes.
- **Peak Arrow allocation** on chunked inputs (K=256): ~2.0 GB → ~1.1 GB (~40% reduction), because the input no longer has to coexist with a defragmented copy.

| | Before | After | Speedup |
|---|---|---|---|
| `aggregate_groups` (84 groups, mean) | 61 s | **40 s** | **1.53×** |

## Related issues

> Link related issues: "Fixes ray-project#1234", "Closes ray-project#1234", or "Related to ray-project#1234".

## Additional information

> Optional: Add implementation details, API changes, usage examples, screenshots, etc.

---------

Signed-off-by: You-Cheng Lin <mses010108@gmail.com>
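As a back-of-the-envelope check on the benchmark table above, the block sizes, speedups, and peak-allocation reductions can be recomputed from the reported figures. The sketch below rests on two assumptions the PR does not state: that "M" in the row counts means 2**20 rows, and that the index array produced by `sort_indices` uses 8 bytes per row (PyArrow returns unsigned 64-bit indices). The dictionary names are illustrative only.

```python
MIB = 2**20  # bytes per MiB; also the assumed meaning of "M" in the row counts

# Block shapes from the table: (rows, number of int64 columns).
shapes = {
    "16M x 8": (16 * MIB, 8),
    "8M x 16": (8 * MIB, 16),
    "2M x 64": (2 * MIB, 64),
}

# K=256 rows copied from the table: (before, after) in ms and in MB.
k256 = {
    "16M x 8": {"ms": (6887, 1057), "peak_mb": (2048, 1280)},
    "8M x 16": {"ms": (5086, 982), "peak_mb": (2048, 1152)},
    "2M x 64": {"ms": (2216, 369), "peak_mb": (2050, 1056)},
}

# Size of the int64 payload and of one 8-byte index per row, in MiB.
block_mib = {name: rows * cols * 8 // MIB for name, (rows, cols) in shapes.items()}
index_mib = {name: rows * 8 // MIB for name, (rows, _) in shapes.items()}

# Speedup and relative peak reduction derived from the reported numbers.
speedup = {name: d["ms"][0] / d["ms"][1] for name, d in k256.items()}
reduction = {
    name: 1 - d["peak_mb"][1] / d["peak_mb"][0] for name, d in k256.items()
}

for name in shapes:
    print(
        f"{name}: block={block_mib[name]} MiB, index={index_mib[name]} MiB, "
        f"speedup={speedup[name]:.2f}x, peak reduction={reduction[name]:.1%}"
    )
```

Under these assumptions every shape is a 1024 MiB block, and block size plus index size gives 1152, 1088, and 1040 MiB, which coincides with the K=1 "Peak Arrow after" column. The PR itself does not break the peak figure down, so treat that reading as an inference.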
edoakes pushed a commit to edoakes/ray that referenced this pull request on Jun 5, 2026
limarkdcunha pushed a commit to limarkdcunha/ray that referenced this pull request on Jun 30, 2026