[Data][1/2] Schema inference for non black box UDF logical operators#63387
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements static schema inference for Ray Data logical operators, enabling Dataset.schema() to resolve output schemas without plan execution. It introduces schema-related mixins and implements infer_schema for operators such as Project, Aggregate, Join, and Union, while also enhancing expression and aggregator classes to support type resolution. Feedback indicates that the Aggregate schema inference should be updated to handle multi-key groupings and that PyArrow schema truthiness checks should be replaced with length checks to correctly identify empty schemas.
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request implements static schema inference for Ray Data logical operators, allowing Dataset.schema() to resolve output schemas without falling back to plan execution for non-UDF chains. The changes include refactoring join logic into reusable utilities, introducing schema-inference mixins for logical operators, and enhancing expressions and aggregators to provide static type information. A critical issue was identified regarding a missing functools import in map_operator.py, which would cause a NameError when accessing the cached schema property.
| right_keys=right_on, | ||
| left_suffix=self._left_columns_suffix, | ||
| right_suffix=self._right_columns_suffix, | ||
| def join_tables( |
There was a problem hiding this comment.
the changes are just about pull the instance methods out.
There was a problem hiding this comment.
9514948 to
ddedb4f
Compare
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request implements static schema inference for Ray Data logical operators, enabling schema resolution without materializing blocks. It introduces mixin classes for schema propagation, adds infer_schema methods to various operators, and implements expression-level type resolution. Comprehensive tests are included to verify the new functionality. The review comment correctly identifies a missing functools import in python/ray/data/_internal/logical/operators/map_operator.py that needs to be addressed.
There was a problem hiding this comment.
Cursor Bugbot has reviewed your changes and found 1 potential issue.
There are 3 total unresolved issues (including 2 from previous reviews).
Reviewed by Cursor Bugbot for commit 7ae7fbfb02dc6343e829e84658c79a3f604fc615. Configure here.
iamjustinhsu
left a comment
There was a problem hiding this comment.
didn't look through all of the files, but my main concern is this https://github.com/ray-project/ray/pull/63387/changes#r3284806824
| # (``PandasBlockSchema`` chains fall back to ``limit(1)`` execution.) | ||
| if not isinstance(input_schema, pa.Schema): | ||
| return None | ||
| fields = exprlist_to_fields(self.exprs, input_schema) |
There was a problem hiding this comment.
hmm not blocking but i think long-term we want to a concept of resolved and unresolved expressions. Right now those expressions can be anything, and we'll need rules to convert those to resolved expressions. So let's say someone does this:
ds = (
ray.data.read_parquet() # Infer the schema on creation since we have no DSL/IR distinction
.streaming_repartition() # A new logical operator, no expressions
.select_columns(...) # Another logical operator, with an unresolved expression
)Now, we do this:
# Now we should go through these steps:
# 1) This should go through all operators, and resolve their expressions recursively against the schema
# This can lead to a resolved or unresolved (if child schemas unknown, or if there was an error)
# The resolution should be rule-based, ie, one for star expressions, one for attributes, etc...
# 2) If step 1 returns an invalid schema or error, then fallback to limit(1)
ds.schema()Now we do this:
# Since ds.schema() doesn't mutate the dataset, we still need to go through attribute resolution to get
# resolved and unresolved columns. Then we logical optimizations, logical -> physical, physical optimization
ds.materialize()What are ur thoughts on this? (i have a prototype here https://github.com/ray-project/ray/pull/59117/changes#diff-91eaab60fc55ba17ab52a7498bf46e64ef8630689880fbf72288b1f7a9d3d28bR1)
There was a problem hiding this comment.
Ok I will add this Analyzer as a follow up after expanding stars. Did some more reading on datafusion. The rules like TypeCoercion etc. make it clean as to what's being resolved in the planning phase.
1585770 to
4b6491c
Compare
Add ``LogicalOperator.infer_schema()`` overrides to every non-UDF operator so ``Dataset.schema()`` resolves typed pipelines without falling back to a ``limit(1)`` execution. Expressions: * New ``Expr.to_field(schema)`` / ``get_type(schema)`` / ``nullable(schema)`` API. Default delegates to ``data_type.to_arrow_dtype()`` (covers ``LiteralExpr``, ``UDFExpr``, ``DownloadExpr``, ``MonoIncId``, ``RandomExpr``, ``UUIDExpr``); schema-dependent subclasses override. * ``BinaryExpr``/``UnaryExpr`` type promotion via PyArrow compute kernels on empty arrays (same kernels the runtime uses). * ``exprlist_to_fields`` helper expands ``StarExpr`` inline. Logical operators: * ``LogicalOperatorPreservesSchema`` mixin: ``Filter``, ``Sort``, ``Limit``, ``Repartition``, ``RandomShuffle``, ``RandomizeBlocks``, ``StreamingRepartition``, ``StreamingSplit``, ``Write``. * ``LogicalOperatorUnifiesInputSchemas`` mixin: ``Union``, ``Mix``. * ``Project.infer_schema()`` via ``exprlist_to_fields``. * ``Aggregate.infer_schema()`` via ``AggregateFn.output_field``; implemented on ``Count``, ``Sum``, ``Min``, ``Max``, ``Mean``, ``Std``, ``AbsMax``. * ``Zip.infer_schema()`` reuses ``BlockAccessor.zip`` on empty tables. * ``Join.infer_schema()`` reuses the new shared ``join_tables`` utility extracted from ``JoiningAggregation.finalize``. * ``Count.infer_schema()``, ``Download.infer_schema()``. Tests: 37 unit + 18 integration verifying ``ds.schema(fetch_if_missing=False)`` resolves typed chains without execution. Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Goutam <goutam@anyscale.com>
Signed-off-by: Goutam <goutam@anyscale.com>
Signed-off-by: Goutam <goutam@anyscale.com>
Signed-off-by: Goutam <goutam@anyscale.com>
Signed-off-by: Goutam <goutam@anyscale.com>
Signed-off-by: Goutam <goutam@anyscale.com>
Signed-off-by: Goutam <goutam@anyscale.com>
Signed-off-by: Goutam <goutam@anyscale.com>
| A list of ``pa.Field`` in projection order, or ``None`` if | ||
| any expression is unresolvable. | ||
| """ | ||
| if input_schema is None: |
There was a problem hiding this comment.
from type annotation, seems like it can't be Optional?
There was a problem hiding this comment.
map_operator.py has this check, I can nuke this line.
|
|
||
| # Any rename whose source isn't in ``input_schema`` falls through | ||
| # here and will fail resolution -> None, matching the runtime's | ||
| # "column not found" error. | ||
| for expr in (*rename_by_source_name.values(), *non_rename_exprs): | ||
| if not _resolve_and_upsert(expr): | ||
| return None |
There was a problem hiding this comment.
wait so this code will also run if has_star is True? I'm under the impression that the previous if block will handle that case
There was a problem hiding this comment.
Yes but they handle 2 different cases. The 1st loop is resolving and upserting the input schema. The 2nd loop is handling the other expressions in the list.
Take this example:
[star(), (a+b).alias("sum")]:
- Loop 1 emits a, b (from input_schema), and never executes line 1890 (no renames).
- Loop 2 iterates [] + [(a+b).alias("sum")] → appends sum.
| op_fn = _ARROW_EXPR_OPS_MAP.get(op) | ||
| if op_fn is None: | ||
| return None |
There was a problem hiding this comment.
should this be an assertion?
There was a problem hiding this comment.
Yea technically we can't evaluate if the operation isn't supported
| .with_column("s", col("a") + col("b")) | ||
| .groupby("k") | ||
| .aggregate(Sum("a"), Mean("b")) | ||
| .sort("k") |
There was a problem hiding this comment.
i forget -- does "s" get propagated?
There was a problem hiding this comment.
no cause of the groupby().agg() so only k, sum(a), mean(b) will stay post-execution
| ds_a = ray.data.read_parquet(str(parquet_path)) | ||
| ds_b = ray.data.read_parquet(str(parquet_path)) | ||
| ds = ds_a.union(ds_b) |
There was a problem hiding this comment.
u should probably do a select(a,b) for ds_a, and select(b,k) for ds_b, so that u see the union?
There was a problem hiding this comment.
Added it
| @@ -332,6 +344,34 @@ def _validate(self, schema: Optional["Schema"]) -> None: | |||
| SortKey(self._target_col_name).validate_schema(schema) | |||
|
|
|||
|
|
|||
| def _agg_output_field( | |||
| name: str, | |||
There was a problem hiding this comment.
i think it would be good to add a docstring for name and target_col? Is this partition column?
There was a problem hiding this comment.
expanded the doc string
| @@ -902,6 +960,16 @@ def combine( | |||
| ) -> SupportsRichComparisonType: | |||
| return max(current_accumulator, new) | |||
|
|
|||
| def output_field(self, input_schema: "pa.Schema") -> Optional["pa.Field"]: | |||
There was a problem hiding this comment.
how come u can't use _agg_output_field here?
There was a problem hiding this comment.
We can
return _agg_output_field(
self.name,
input_schema,
self._target_col_name,
lambda a: pc.max(pc.abs(a)),
)
4b6491c to
b00c17f
Compare
Signed-off-by: Goutam <goutam@anyscale.com>
…ay-project#63387) ## Description This PR teaches Ray Data's logical plan to infer output schemas for non-UDF operators before execution. Previously, a typed pipeline could still make `Dataset.schema()` fall back to a `limit(1)` execution whenever an intermediate logical operator could not describe its output schema. With these overrides, `ds.schema(fetch_if_missing=False)` can resolve through typed pipelines made from projections, filters, shuffles/repartitioning, aggregate/count, union/mix, zip, join, download, write, and streaming split/repartition operators without sampling data. This does not apply to black-box UDF transforms such as `map`, `map_batches`, and similar APIs. Expression UDF schemas still come from their declared `return_dtype`; this PR does not infer result types from UDF implementation code. ## Related issues N/A ## Additional information The main pieces are: * Adds `LogicalOperator.infer_schema()` plus reusable mixins for operators that preserve or unify input schemas. * Adds expression-level field resolution through `Expr.to_field()`, `get_type()`, and `nullable()`, including projection handling for `*`, aliases, renames, and upserts. * Adds aggregate output fields for built-in aggregations such as count, sum, min, max, mean, std, and abs max. * Reuses runtime table logic for zip and join schema inference on empty Arrow tables so inferred schemas match execution behavior. * Covers the new behavior with 37 expression/unit cases and 18 integration cases that verify `ds.schema(fetch_if_missing=False)` resolves without execution. --------- Signed-off-by: Goutam <goutam@anyscale.com> Co-authored-by: Cursor <cursoragent@cursor.com>
…ay-project#63387) ## Description This PR teaches Ray Data's logical plan to infer output schemas for non-UDF operators before execution. Previously, a typed pipeline could still make `Dataset.schema()` fall back to a `limit(1)` execution whenever an intermediate logical operator could not describe its output schema. With these overrides, `ds.schema(fetch_if_missing=False)` can resolve through typed pipelines made from projections, filters, shuffles/repartitioning, aggregate/count, union/mix, zip, join, download, write, and streaming split/repartition operators without sampling data. This does not apply to black-box UDF transforms such as `map`, `map_batches`, and similar APIs. Expression UDF schemas still come from their declared `return_dtype`; this PR does not infer result types from UDF implementation code. ## Related issues N/A ## Additional information The main pieces are: * Adds `LogicalOperator.infer_schema()` plus reusable mixins for operators that preserve or unify input schemas. * Adds expression-level field resolution through `Expr.to_field()`, `get_type()`, and `nullable()`, including projection handling for `*`, aliases, renames, and upserts. * Adds aggregate output fields for built-in aggregations such as count, sum, min, max, mean, std, and abs max. * Reuses runtime table logic for zip and join schema inference on empty Arrow tables so inferred schemas match execution behavior. * Covers the new behavior with 37 expression/unit cases and 18 integration cases that verify `ds.schema(fetch_if_missing=False)` resolves without execution. --------- Signed-off-by: Goutam <goutam@anyscale.com> Co-authored-by: Cursor <cursoragent@cursor.com>

Description
This PR teaches Ray Data's logical plan to infer output schemas for non-UDF operators before execution.
Previously, a typed pipeline could still make
Dataset.schema()fall back to alimit(1)execution whenever an intermediate logical operator could not describe its output schema. With these overrides,ds.schema(fetch_if_missing=False)can resolve through typed pipelines made from projections, filters, shuffles/repartitioning, aggregate/count, union/mix, zip, join, download, write, and streaming split/repartition operators without sampling data.This does not apply to black-box UDF transforms such as
map,map_batches, and similar APIs. Expression UDF schemas still come from their declaredreturn_dtype; this PR does not infer result types from UDF implementation code.Related issues
N/A
Additional information
The main pieces are:
LogicalOperator.infer_schema()plus reusable mixins for operators that preserve or unify input schemas.Expr.to_field(),get_type(), andnullable(), including projection handling for*, aliases, renames, and upserts.ds.schema(fetch_if_missing=False)resolves without execution.