[Data] Get rid of generators to avoid intermediate state pinning #60598
Code Review
This pull request is a solid piece of engineering that refactors the data processing pipeline to replace generator functions with iterator classes. This is a crucial change to prevent potential memory leaks caused by chained generators holding references to intermediate data. The changes are applied consistently across block_batching and map_transformer components. New iterator classes like _BatchingIterator, ShapeBlocksIterator, and _TransformingBatchIterator are introduced to encapsulate the iteration logic previously found in generators. A new test, test_chained_transforms_release_intermediates_between_batches, is added to verify that intermediate object references are correctly released, which is an excellent addition. The overall change is well-executed and improves memory management in Ray Data's critical path.
| res = [batch] | ||
| out_batch = next(self._cur_output_iter) | ||
| except StopIteration: | ||
| pass |
For improved clarity and robustness, it's better to explicitly reset self._cur_output_iter to None and continue the loop when the iterator is exhausted. This makes the state transition explicit and avoids relying on the iterator being overwritten later in the loop.
| pass | |
| self._cur_output_iter = None | |
| continue |
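The suggested state transition can be sketched in isolation. This is a minimal illustration, assuming a hypothetical `FlatteningIterator` (not a class from this PR) that applies a function returning an iterable to each input and emits the results one at a time. Only the attribute name `_cur_output_iter` is borrowed from the diff; everything else is made up for the example.

```python
from typing import Callable, Iterable, Iterator, Optional, TypeVar

T = TypeVar("T")
U = TypeVar("U")


class FlatteningIterator(Iterator[U]):
    """Hypothetical example class; not taken from the PR."""

    def __init__(self, inputs: Iterator[T], fn: Callable[[T], Iterable[U]]):
        self._inputs = inputs
        self._fn = fn
        self._cur_output_iter: Optional[Iterator[U]] = None

    def __iter__(self) -> "FlatteningIterator[U]":
        return self

    def __next__(self) -> U:
        while True:
            if self._cur_output_iter is None:
                # StopIteration raised by the input ends the whole iteration.
                batch = next(self._inputs)
                self._cur_output_iter = iter(self._fn(batch))
            try:
                return next(self._cur_output_iter)
            except StopIteration:
                # Explicit transition: drop the exhausted iterator, then
                # loop around to pull the next input.
                self._cur_output_iter = None
                continue


# The input 0 produces an empty output, which exercises the reset path.
flattened = list(FlatteningIterator(iter([1, 2, 0, 3]), lambda n: [n] * n))
```

Resetting to `None` also drops the reference to the exhausted inner iterator right away, instead of keeping it until the attribute is overwritten.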
e3e8439 to
44470b4
The idea and motivation look reasonable. Have you done any benchmarks on real workloads? E.g., how much memory can we save?
Not yet. But we can math it out: Currently we're using per single Map task With this change it will be just the block-size
47015a3 to
a817c4c
| try: | ||
| return next(self._input) | ||
| finally: | ||
| self._transformer._report_udf_time(time.perf_counter() - start) |
UDF timing records time even when exceptions occur
Medium Severity
The new _UDFTimingIterator.__next__ uses a finally block to record UDF time, which means timing is recorded even when next(self._input) raises an exception. The old _udf_timed_iter only recorded timing after a successful next() call (the timing increment came after the next() returned). This changes timing metrics behavior: if UDFs fail repeatedly, the new code accumulates time for each failure while the old code wouldn't record time for failures.
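The behavioral difference described here can be reproduced with two toy wrappers. This is a sketch under stated assumptions: `TimedWithFinally`, `TimedOnSuccess`, and `failing_source` are hypothetical names, and they record samples into a plain list rather than calling `_report_udf_time`.

```python
import time
from typing import Iterator, List


def failing_source() -> Iterator[int]:
    """Yields one item, then fails like a UDF raising mid-stream."""
    yield 1
    raise ValueError("simulated UDF failure")


class TimedWithFinally:
    """Records elapsed time in `finally`, so failed calls are counted too."""

    def __init__(self, inner: Iterator[int], samples: List[float]):
        self._inner = inner
        self._samples = samples

    def __iter__(self):
        return self

    def __next__(self) -> int:
        start = time.perf_counter()
        try:
            return next(self._inner)
        finally:
            # Runs on success, on StopIteration, and on any other exception.
            self._samples.append(time.perf_counter() - start)


class TimedOnSuccess(TimedWithFinally):
    """Records elapsed time only after next() returned normally."""

    def __next__(self) -> int:
        start = time.perf_counter()
        item = next(self._inner)
        self._samples.append(time.perf_counter() - start)
        return item


def count_samples(wrapper_cls) -> int:
    samples: List[float] = []
    try:
        for _ in wrapper_cls(failing_source(), samples):
            pass
    except ValueError:
        pass
    return len(samples)


finally_count = count_samples(TimedWithFinally)  # successful call + failing call
success_count = count_samples(TimedOnSuccess)    # successful call only
```

The `finally` variant also records a sample for the terminating `StopIteration` when the input is exhausted normally. Whether either difference matters for the reported metric is a judgment call for the maintainers.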
iamjustinhsu left a comment
changes make sense to me, do u have any release test results?
| self._active_timer.__enter__() | ||
| class _UnwrappingIterator(Iterator[DataBatch]): |
In the interest of simplification, do u think we can just fold this implementation into _UserTimingIterator since it's not being used elsewhere and all it does is index into .data?
| self._input = input | ||
| self._transformer = transformer | ||
| def __iter__(self) -> "MapTransformer._UDFTimingIterator": |
I see this a lot returning self. What are ur thoughts on making it default in the base class?
We don't have a base class though
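One relevant detail: `collections.abc.Iterator` already supplies `__iter__` as a mixin method that returns `self`, and subclassing `typing.Iterator[...]` inherits it. The sketch below assumes the `Iterator` these classes extend is one of those two (the import is not visible in the diff). `CountDown` is a hypothetical class used only for illustration.

```python
import collections.abc
import typing


class CountDown(typing.Iterator[int]):
    """Hypothetical iterator that defines only __next__."""

    def __init__(self, start: int):
        self._remaining = start

    def __next__(self) -> int:
        if self._remaining <= 0:
            raise StopIteration
        value = self._remaining
        self._remaining -= 1
        return value


countdown = CountDown(3)

# No __iter__ was written above, yet iter() works and returns the object itself.
returns_self = iter(countdown) is countdown

# The method comes from the abstract base class.
inherits_mixin = CountDown.__iter__ is collections.abc.Iterator.__iter__

values = list(countdown)
```

Under that assumption the explicit `__iter__` overrides mostly serve to narrow the return type annotation, so dropping them is a style choice.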
| return _UserTimingIterator(_UnwrappingIterator(batch_iter), stats) | ||
| class _UserTimingIterator(Iterator[DataBatch]): |
Since most classes have a docstring, I think you should also add one here too
I'm curious if what u are doing can be simplified if we just do this:
┌──────────────────────────────────────────────────────────────────────────┐
│                                Generator A                               │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │  def transform_a(inputs):                                            │ │
│ │      for batch in inputs:          ◄─── suspended at yield           │ │
│ │          batch = process(batch)         `batch` PINNED in frame      │ │
│ │          yield batch               ◄─── `result` PINNED in frame     │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
But I guess it still keeps batch and stack in frame
that I agree. but I'd still prefer verifying the effectiveness with a test. It could be a simple microbenchmark with multiple fused ops + large target block size. Also, another concern I have is that this PR creates too many ad-hoc iterator classes. It'd be nice to unify some of them. E.g. some of them follow the same pattern and can be replaced with this
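A toy version of such a microbenchmark, using only the standard library, could look like the sketch below. It is not a Ray Data benchmark: `generator_stage`, `IteratorStage`, the constants, and the use of `bytearray` as a stand-in for a block are all assumptions made for illustration, and the numbers it produces say nothing about real workloads.

```python
import tracemalloc
from typing import Callable, Iterator

BATCH_BYTES = 1_000_000  # size of every simulated batch
NUM_BATCHES = 5
NUM_STAGES = 3           # number of "fused" transforms


def source() -> Iterator[bytearray]:
    # Generator expression: the produced batch is never bound to a local name.
    return (bytearray(BATCH_BYTES) for _ in range(NUM_BATCHES))


def generator_stage(inputs: Iterator[bytearray]) -> Iterator[bytearray]:
    for batch in inputs:
        result = bytearray(len(batch))  # stand-in for a real transform
        yield result


class IteratorStage:
    def __init__(self, inputs: Iterator[bytearray]):
        self._inputs = inputs

    def __iter__(self) -> "IteratorStage":
        return self

    def __next__(self) -> bytearray:
        batch = next(self._inputs)
        return bytearray(len(batch))


def peak_bytes(make_stage: Callable[[Iterator[bytearray]], Iterator[bytearray]]) -> int:
    """Runs a chain of NUM_STAGES stages and returns the traced peak."""
    tracemalloc.start()
    try:
        chain = source()
        for _ in range(NUM_STAGES):
            chain = make_stage(chain)
        for _ in chain:
            pass
        _, peak = tracemalloc.get_traced_memory()
        return peak
    finally:
        tracemalloc.stop()


generator_peak = peak_bytes(generator_stage)
iterator_peak = peak_bytes(IteratorStage)
print(f"generator chain peak: {generator_peak / 1e6:.1f} MB")
print(f"iterator chain peak:  {iterator_peak / 1e6:.1f} MB")
```

In CPython the generator chain is expected to show a higher peak, because every suspended frame keeps its `batch` and `result` alive while the next batch is being produced.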
Thanks for raising great points @raulchen @iamjustinhsu. Will get back to this one after the release |
I've added the unit test for that, which asserts that we don't keep intermediate states anymore. Expectation now is that peak memory should be lower for cases
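The kind of assertion such a test can make is sketched below with weak references. This is not the test added in the PR (`test_chained_transforms_release_intermediates_between_batches`). It is a standalone illustration with hypothetical names (`Batch`, `WeakRefRecorder`, `intermediate_survives`), and it relies on CPython's reference counting.

```python
import gc
import weakref
from typing import Callable, Iterator, List


class Batch:
    """Stand-in for a data batch; instances support weak references."""

    def __init__(self, value: int):
        self.value = value


def generator_stage(inputs: Iterator[Batch]) -> Iterator[Batch]:
    for batch in inputs:
        result = Batch(batch.value + 1)
        yield result


class IteratorStage:
    def __init__(self, inputs: Iterator[Batch]):
        self._inputs = inputs

    def __iter__(self) -> "IteratorStage":
        return self

    def __next__(self) -> Batch:
        batch = next(self._inputs)
        return Batch(batch.value + 1)


class WeakRefRecorder:
    """Pass-through iterator that records a weak reference to each item."""

    def __init__(self, inputs: Iterator[Batch], refs: List[weakref.ref]):
        self._inputs = inputs
        self._refs = refs

    def __iter__(self) -> "WeakRefRecorder":
        return self

    def __next__(self) -> Batch:
        item = next(self._inputs)
        self._refs.append(weakref.ref(item))
        return item


def intermediate_survives(make_stage: Callable[[Iterator[Batch]], Iterator[Batch]]) -> bool:
    """True if the first stage's output is still alive after one output."""
    refs: List[weakref.ref] = []
    source = (Batch(i) for i in range(3))
    chain = make_stage(WeakRefRecorder(make_stage(source), refs))
    output = next(chain)  # hold the final output, as a consumer would
    gc.collect()
    return refs[0]() is not None


generator_pins = intermediate_survives(generator_stage)
iterator_pins = intermediate_survives(IteratorStage)
```

With generators the intermediate batch stays reachable from the suspended frames. With iterator classes it is collected as soon as the downstream `__next__` returns.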
Signed-off-by: Alexey Kudinkin <ak@anyscale.com>
…potentially pinning these objects in memory until next iteration Signed-off-by: Alexey Kudinkin <ak@anyscale.com>
… of generators Signed-off-by: Alexey Kudinkin <ak@anyscale.com>
Signed-off-by: Alexey Kudinkin <ak@anyscale.com> # Conflicts: # python/ray/data/_internal/execution/operators/map_operator.py # python/ray/data/_internal/execution/operators/map_transformer.py Signed-off-by: Alexey Kudinkin <ak@anyscale.com>
Tidying up Signed-off-by: Alexey Kudinkin <ak@anyscale.com>
a817c4c to
c109d93
Rebased iters onto MO Signed-off-by: Alexey Kudinkin <ak@anyscale.com>
…-project#60598)

## Description

I’ve realized that for fused Map transforms we’re holding a whole stack of intermediate results (batches) simply due to how `yield` works in Python:

- When a method yields, all of its frame state (local variables) is preserved, pinning its intermediate state until the next iteration instead of releasing it.
- This is in contrast with a plain `Iterator.__next__` method: when it returns, its stack frame is destroyed along with all of its intermediate state.

While this is not an issue most of the time, it's a big problem when multiple Maps are fused:

- With multiple operators and their corresponding transformations fused, the intermediate state, along with the inputs and outputs of each one, is pinned until the next iteration.
- The total required heap memory scales proportionally with the number of operators fused (i.e. more operators, more heap).
- This is exacerbated by the fact that `batch_size` is now None by default, meaning that the whole block is both an input and an output, substantially increasing memory requirements.

Consider the following example:

```
Generator Chain (Problem)
┌──────────────────────────────────────────────────────────────────────────┐
│                                Generator A                               │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │  def transform_a(inputs):                                            │ │
│ │      for batch in inputs:          ◄─── suspended at yield           │ │
│ │          result = process(batch)        `batch` PINNED in frame      │ │
│ │          yield result              ◄─── `result` PINNED in frame     │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
│                                     │                                    │
│                                     ▼                                    │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │  def transform_b(inputs):                                            │ │
│ │      for batch in inputs:          ◄─── suspended at yield           │ │
│ │          result = process(batch)        `batch` PINNED (output of A) │ │
│ │          yield result              ◄─── `result` PINNED in frame     │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
│                                     │                                    │
│                                     ▼                                    │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │  def transform_c(inputs):                                            │ │
│ │      for batch in inputs:          ◄─── suspended at yield           │ │
│ │          result = process(batch)        `batch` PINNED (output of B) │ │
│ │          yield result              ◄─── `result` PINNED in frame     │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
│                                     │                                    │
│                                     ▼                                    │
│                                to consumer                               │
└──────────────────────────────────────────────────────────────────────────┘

Memory at yield point:
┌─────────┬─────────┬─────────┬─────────┬─────────┬─────────┐
│  input  │ A.batch │ A.result│ B.batch │ B.result│ C.batch │ ... ALL PINNED
└─────────┴─────────┴─────────┴─────────┴─────────┴─────────┘
═══════════════════════════════════════════════
Cannot be GC'd until next iteration

Iterator Chain (Solution)
┌──────────────────────────────────────────────────────────────────────────┐
│                                Iterator A                                │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │  def __next__(self):                                                 │ │
│ │      batch = next(self._input)     # local var                       │ │
│ │      result = process(batch)       # local var                       │ │
│ │      return result                 ◄─── method RETURNS               │ │
│ │                                         locals GO OUT OF SCOPE       │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
│                                     │                                    │
│                                     ▼                                    │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │  def __next__(self):                                                 │ │
│ │      batch = next(self._input)     # local var                       │ │
│ │      result = process(batch)       # local var                       │ │
│ │      return result                 ◄─── method RETURNS               │ │
│ │                                         locals GO OUT OF SCOPE       │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
│                                     │                                    │
│                                     ▼                                    │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │  def __next__(self):                                                 │ │
│ │      batch = next(self._input)     # local var                       │ │
│ │      result = process(batch)       # local var                       │ │
│ │      return result                 ◄─── method RETURNS               │ │
│ │                                         locals GO OUT OF SCOPE       │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
│                                     │                                    │
│                                     ▼                                    │
│                                to consumer                               │
└──────────────────────────────────────────────────────────────────────────┘

Memory after return:
┌─────────┬─────────┐
│  input  │ output  │ ... ONLY 2 objects pinned
└─────────┴─────────┘
All intermediates eligible for GC immediately after each __next__ returns

Key Difference

         GENERATOR                           ITERATOR
─────────────────────────────────────────────────────────────────
  yield suspends execution    vs    return completes execution
  frame stays alive           vs    frame is destroyed
  locals pinned until resume  vs    locals released immediately

          ┌──────────┐                    ┌──────────┐
 yield ──►│ SUSPENDED│          return ──►│ COMPLETE │
          │  frame   │                    │  frame   │
          │  alive   │                    │destroyed │
          └──────────┘                    └──────────┘
               │                               │
               ▼                               ▼
           refs HELD                     refs RELEASED
```

## Related issues

> Link related issues: "Fixes ray-project#1234", "Closes ray-project#1234", or "Related to ray-project#1234".

## Additional information

> Optional: Add implementation details, API changes, usage examples, screenshots, etc.

---------

Signed-off-by: Alexey Kudinkin <ak@anyscale.com>
Signed-off-by: Adel Nour <ans9868@nyu.edu>
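The frame retention described in this commit message can be observed directly through a generator's `gi_frame` attribute. The sketch below uses a hypothetical `transform` function in which list concatenation stands in for `process(batch)`.

```python
from typing import Iterator, List


def transform(inputs: Iterator[List[int]]) -> Iterator[List[int]]:
    for batch in inputs:
        result = batch + batch  # stand-in for process(batch)
        yield result


gen = transform(iter([[1, 2, 3]]))
first_output = next(gen)

# The generator is suspended at `yield`, so its frame and locals still exist.
pinned_names = sorted(gen.gi_frame.f_locals)
result_is_pinned = gen.gi_frame.f_locals["result"] is first_output

# Only once the generator finishes is the frame discarded.
remaining = list(gen)
frame_after_exhaustion = gen.gi_frame
```

A `__next__` method on an iterator class has no counterpart to `gi_frame`: once it returns, only the attributes stored on `self` remain reachable.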

