Skip to content
2 changes: 1 addition & 1 deletion doc/source/data/working-with-images.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ To view the full list of supported file formats, see the
Column Type
------ ----
image_url string
bytes null
bytes binary

.. tab-item:: NumPy

Expand Down
14 changes: 14 additions & 0 deletions python/ray/data/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -999,6 +999,20 @@ py_test(
],
)

py_test(
name = "test_infer_schema",
size = "small",
srcs = ["tests/test_infer_schema.py"],
tags = [
"exclusive",
"team:data",
],
deps = [
":conftest",
"//:ray_lib",
],
)

py_test(
name = "test_task_pool_map_operator",
size = "small",
Expand Down
498 changes: 265 additions & 233 deletions python/ray/data/_internal/execution/operators/join.py

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions python/ray/data/_internal/logical/interfaces/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from .logical_operator import (
LogicalOperator,
LogicalOperatorPreservesSchema,
LogicalOperatorSupportsPredicatePassThrough,
LogicalOperatorSupportsPredicatePushdown,
LogicalOperatorSupportsProjectionPushdown,
LogicalOperatorUnifiesInputSchemas,
PredicatePassThroughBehavior,
)
from .logical_plan import LogicalPlan
Expand All @@ -21,8 +23,10 @@
"Plan",
"Rule",
"SourceOperator",
"LogicalOperatorPreservesSchema",
"LogicalOperatorSupportsProjectionPushdown",
"LogicalOperatorSupportsPredicatePushdown",
"LogicalOperatorSupportsPredicatePassThrough",
"LogicalOperatorUnifiesInputSchemas",
"PredicatePassThroughBehavior",
]
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,47 @@ def apply_projection(
return self


class LogicalOperatorPreservesSchema(LogicalOperator):
"""Mixin for operators whose output column layout is identical to their
single input's. Provides a default ``infer_schema()`` that delegates to
the input. Use for ops like ``Filter``, ``Sort``, ``Limit``, etc., that
only re-order or filter rows.

List this mixin last in the bases of subclasses so the concrete operator
base (e.g., ``AbstractMap``, ``AbstractAllToAll``) drives ``__init__`` /
``super()`` chains.
"""

def infer_schema(self) -> Optional["Schema"]:
assert len(self.input_dependencies) == 1, len(self.input_dependencies)
return self.input_dependencies[0].infer_schema()


class LogicalOperatorUnifiesInputSchemas(LogicalOperator):
"""Mixin for n-ary operators whose output schema is the unification of
all inputs' schemas (e.g., ``Union``, ``Mix``). Provides a default
``infer_schema()`` that returns the result of
``unify_schemas_with_validation`` over each input's schema, or
``None`` if any input's schema is unresolvable.

List this mixin last in the bases of subclasses so the concrete operator
base (e.g., ``NAry``) drives ``__init__`` / ``super()`` chains.
"""

def infer_schema(self) -> Optional["Schema"]:
import pyarrow as pa

from ray.data._internal.util import unify_schemas_with_validation

input_schemas = [op.infer_schema() for op in self.input_dependencies]
if not all(isinstance(s, pa.Schema) for s in input_schemas):
return None
try:
return unify_schemas_with_validation(input_schemas)
except (pa.ArrowTypeError, pa.ArrowInvalid):
return None
Comment thread
cursor[bot] marked this conversation as resolved.


class LogicalOperatorSupportsPredicatePushdown(LogicalOperator):
"""Mixin for reading operators supporting predicate pushdown"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from ray.data._internal.logical.interfaces import (
LogicalOperator,
LogicalOperatorPreservesSchema,
LogicalOperatorSupportsPredicatePassThrough,
PredicatePassThroughBehavior,
)
Expand All @@ -14,7 +15,6 @@
from ray.data.block import BlockMetadata

if TYPE_CHECKING:

from ray.data.block import Schema

__all__ = [
Expand Down Expand Up @@ -70,7 +70,11 @@ def num_outputs(self) -> Optional[int]:


@dataclass(frozen=True, repr=False, eq=False)
class RandomizeBlocks(AbstractAllToAll, LogicalOperatorSupportsPredicatePassThrough):
class RandomizeBlocks(
AbstractAllToAll,
LogicalOperatorSupportsPredicatePassThrough,
LogicalOperatorPreservesSchema,
):
"""Logical operator for randomize_block_order."""

seed_config: Optional[RandomSeedConfig] = None
Expand All @@ -91,20 +95,17 @@ def infer_metadata(self) -> "BlockMetadata":
assert isinstance(self.input_dependencies[0], LogicalOperator)
return self.input_dependencies[0].infer_metadata()

def infer_schema(
self,
) -> Optional["Schema"]:
assert len(self.input_dependencies) == 1, len(self.input_dependencies)
assert isinstance(self.input_dependencies[0], LogicalOperator)
return self.input_dependencies[0].infer_schema()

def predicate_passthrough_behavior(self) -> PredicatePassThroughBehavior:
# Randomizing block order doesn't affect filtering correctness
return PredicatePassThroughBehavior.PASSTHROUGH


@dataclass(frozen=True, repr=False, eq=False)
class RandomShuffle(AbstractAllToAll, LogicalOperatorSupportsPredicatePassThrough):
class RandomShuffle(
AbstractAllToAll,
LogicalOperatorSupportsPredicatePassThrough,
LogicalOperatorPreservesSchema,
):
"""Logical operator for random_shuffle."""

name: InitVar[str] = "RandomShuffle"
Expand Down Expand Up @@ -140,20 +141,17 @@ def infer_metadata(self) -> "BlockMetadata":
assert isinstance(self.input_dependencies[0], LogicalOperator)
return self.input_dependencies[0].infer_metadata()

def infer_schema(
self,
) -> Optional["Schema"]:
assert len(self.input_dependencies) == 1, len(self.input_dependencies)
assert isinstance(self.input_dependencies[0], LogicalOperator)
return self.input_dependencies[0].infer_schema()

def predicate_passthrough_behavior(self) -> PredicatePassThroughBehavior:
# Random shuffle doesn't affect filtering correctness
return PredicatePassThroughBehavior.PASSTHROUGH


@dataclass(frozen=True, repr=False, eq=False)
class Repartition(AbstractAllToAll, LogicalOperatorSupportsPredicatePassThrough):
class Repartition(
AbstractAllToAll,
LogicalOperatorSupportsPredicatePassThrough,
LogicalOperatorPreservesSchema,
):
"""Logical operator for repartition."""

num_outputs: InitVar[int]
Expand Down Expand Up @@ -193,20 +191,17 @@ def infer_metadata(self) -> "BlockMetadata":
assert isinstance(self.input_dependencies[0], LogicalOperator)
return self.input_dependencies[0].infer_metadata()

def infer_schema(
self,
) -> Optional["Schema"]:
assert len(self.input_dependencies) == 1, len(self.input_dependencies)
assert isinstance(self.input_dependencies[0], LogicalOperator)
return self.input_dependencies[0].infer_schema()

def predicate_passthrough_behavior(self) -> PredicatePassThroughBehavior:
# Repartition doesn't affect filtering correctness
return PredicatePassThroughBehavior.PASSTHROUGH


@dataclass(frozen=True, repr=False, eq=False)
class Sort(AbstractAllToAll, LogicalOperatorSupportsPredicatePassThrough):
class Sort(
AbstractAllToAll,
LogicalOperatorSupportsPredicatePassThrough,
LogicalOperatorPreservesSchema,
):
"""Logical operator for sort."""

sort_key: SortKey
Expand Down Expand Up @@ -234,13 +229,6 @@ def infer_metadata(self) -> "BlockMetadata":
assert isinstance(self.input_dependencies[0], LogicalOperator)
return self.input_dependencies[0].infer_metadata()

def infer_schema(
self,
) -> Optional["Schema"]:
assert len(self.input_dependencies) == 1, len(self.input_dependencies)
assert isinstance(self.input_dependencies[0], LogicalOperator)
return self.input_dependencies[0].infer_schema()

def predicate_passthrough_behavior(self) -> PredicatePassThroughBehavior:
# Sort doesn't affect filtering correctness
return PredicatePassThroughBehavior.PASSTHROUGH
Expand All @@ -250,7 +238,7 @@ def predicate_passthrough_behavior(self) -> PredicatePassThroughBehavior:
class Aggregate(AbstractAllToAll):
"""Logical operator for aggregate."""

key: Optional[str]
key: Optional[str | List[str]]
aggs: List[AggregateFn]
num_partitions: Optional[int] = None
batch_format: Optional[str] = "default"
Expand All @@ -271,3 +259,29 @@ def __post_init__(self):
],
)
object.__setattr__(self, "_num_outputs", None)

def infer_schema(self) -> Optional["Schema"]:
# Output = key field(s) from input schema + one field per aggregator.
# Returns None if any aggregator can't declare its output field
# (callers fall back to limit(1)).
import pyarrow as pa

assert len(self.input_dependencies) == 1, len(self.input_dependencies)
input_schema = self.input_dependencies[0].infer_schema()
if not isinstance(input_schema, pa.Schema):
return None

fields: List[pa.Field] = []
if self.key is not None:
keys = self.key if isinstance(self.key, list) else [self.key]
for key in keys:
try:
fields.append(input_schema.field(key))
except (KeyError, TypeError, ValueError):
return None
for agg in self.aggs:
f = agg.output_field(input_schema)
if f is None:
return None
fields.append(f)
return pa.schema(fields)
12 changes: 11 additions & 1 deletion python/ray/data/_internal/logical/operators/count_operator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from dataclasses import dataclass, field
from typing import Optional
from typing import TYPE_CHECKING, Optional

from ray.data._internal.logical.interfaces import LogicalOperator

if TYPE_CHECKING:
from ray.data.block import Schema

__all__ = [
"Count",
]
Expand All @@ -28,3 +31,10 @@ def __post_init__(self):
@property
def num_outputs(self) -> Optional[int]:
return self._num_outputs

def infer_schema(self) -> Optional["Schema"]:
# Fixed output: one row per partial count with a single ``__num_rows``
# int64 column.
import pyarrow as pa

return pa.schema([pa.field(self.COLUMN_NAME, pa.int64(), nullable=False)])
36 changes: 36 additions & 0 deletions python/ray/data/_internal/logical/operators/join_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,39 @@ def _get_referenced_columns(self, expr: "Expr") -> set[str]:
visitor = _ColumnReferenceCollector()
visitor.visit(expr)
return set(visitor.get_column_refs())

def infer_schema(self) -> Optional["Schema"]:
"""Infer the output schema by running the shared ``join_tables``
utility on empty tables built from the input schemas. The same
utility runs at execution time, so plan-time and runtime schemas
agree by construction.
"""
import pyarrow as pa

from ray.data._internal.execution.operators.join import join_tables

left_schema = self.input_dependencies[0].infer_schema()
right_schema = self.input_dependencies[1].infer_schema()
if not isinstance(left_schema, pa.Schema) or not isinstance(
right_schema, pa.Schema
):
return None

join_type_enum = (
self.join_type
if isinstance(self.join_type, JoinType)
else JoinType(self.join_type)
)
try:
joined = join_tables(
left_schema.empty_table(),
right_schema.empty_table(),
join_type=join_type_enum,
left_key_col_names=tuple(self.left_key_columns),
right_key_col_names=tuple(self.right_key_columns),
left_columns_suffix=self.left_columns_suffix,
right_columns_suffix=self.right_columns_suffix,
)
except (pa.ArrowTypeError, pa.ArrowInvalid, pa.ArrowKeyError, ValueError):
return None
return joined.schema
Loading
Loading