Skip to content

[data] add jax data iterator#61630

Merged
aslonnie merged 16 commits into
ray-project:masterfrom
siyuanfoundation:jax-dataloader
Apr 14, 2026
Merged

[data] add jax data iterator#61630
aslonnie merged 16 commits into
ray-project:masterfrom
siyuanfoundation:jax-dataloader

Conversation

@siyuanfoundation

@siyuanfoundation siyuanfoundation commented Mar 10, 2026

Copy link
Copy Markdown
Contributor

Description

This PR introduces the iter_jax_batches API for Ray Data, enabling seamless integration between Ray Datasets and JAX architectures in distributed training scenarios. This provides first-class support for processing data within JaxTrainer workloads.

Related issues

#55162

Additional information

This PR implements Option 4 in https://docs.jax.dev/en/latest/distributed_data_loading.html:
It loads ray data into devices using pure data parallel to be consistent with JaxTrainer (which is a DataParallelTrainer).
Users can reshard it to the desired shape after getting the DDP batch.

Here is an example of how to use it for jax inference

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

import ray
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from ray.util.tpu import slice_placement_group

# Connect to the Ray cluster (no-op when this process is already connected).
if not ray.is_initialized():
    ray.init()

# 1. Reserve one v6e TPU slice (4x4, 16 chips, 4 VMs).
print("Reserving TPU slice...")
slice_handle = slice_placement_group(
    topology="4x4",
    accelerator_version="v6e",
    num_slices=1,
    resources_per_bundle={"TPU": 4},
)
slice_pg = slice_handle.placement_group

print("Waiting for placement group to be ready...")
try:
    # Block for at most 10 minutes until the bundles are scheduled.
    ray.get(slice_pg.ready(), timeout=600)
except Exception:
    # Do not leave the slice reserved when it never became ready; the
    # shutdown call at the end of the script would not be reached.
    slice_handle.shutdown()
    raise
print("Placement group ready.")

# Number of bundles in the reservation; the script starts one worker per bundle.
num_workers = slice_handle.num_bundles


@ray.remote(num_cpus=1, resources={"TPU": 4})
class JAXInferenceWorker:
    """Ray actor that runs one JAX process of a multi-host inference job.

    Each actor requests 1 CPU and 4 "TPU" resources, joins the JAX distributed
    runtime as process ``rank`` out of ``world_size``, and holds the
    parameters of a small logistic-regression model.
    """

    def __init__(self, rank, world_size):
        """Store the process rank, the world size, and this node's IP."""
        self.rank = rank
        self.world_size = world_size
        self.ip = ray.util.get_node_ip_address()

    def get_ip(self):
        """Return the node IP address captured at construction time."""
        return self.ip

    def initialize_jax(self, coordinator_address):
        """Join the JAX distributed runtime and create the model parameters.

        Builds a one-axis mesh named ``"data"`` over ``jax.devices()`` and
        places ``W`` (8x1) and ``b`` (1,) with an empty ``PartitionSpec``.

        Args:
            coordinator_address: Host or IP of the coordinator process; port
                1234 is appended to it.
        """
        jax.distributed.initialize(
            coordinator_address=f"{coordinator_address}:1234",
            num_processes=self.world_size,
            process_id=self.rank,
        )
        devices = jax.devices()
        self.mesh = Mesh(devices, axis_names=("data",))

        # Initialize parameters for binary prediction (Logistic Regression)
        input_dim = 8
        output_dim = 1
        # Every worker uses the same seed, so all processes build equal values.
        key = jax.random.PRNGKey(42)
        k1, k2 = jax.random.split(key)

        self.param_sharding = NamedSharding(self.mesh, P())
        self.W = jax.device_put(
            jax.random.normal(k1, (input_dim, output_dim)), self.param_sharding
        )
        self.b = jax.device_put(
            jax.random.normal(k2, (output_dim,)), self.param_sharding
        )

        print(f"Worker {self.rank}: JAX initialized.")

    def run_inference(self, ds_shard):
        """Run the model over this worker's dataset shard.

        Must be called after ``initialize_jax``, which creates ``self.mesh``,
        ``self.W`` and ``self.b``.

        Args:
            ds_shard: Dataset shard exposing ``iter_jax_batches``; each batch
                is indexed by the ``"features"`` key.

        Returns:
            A NumPy array with this host's int32 predictions for all batches,
            concatenated along axis 0. When the shard yields no batches, an
            empty array of shape ``(0, output_dim)`` is returned.
        """

        # Logistic Regression Model
        @jax.jit
        def model_fn(x, W, b):
            logits = jnp.matmul(x, W) + b
            probs = jax.nn.sigmoid(logits)
            return (probs > 0.5).astype(jnp.int32)

        print(f"Worker {self.rank}: Starting inference...")

        # NOTE(review): the review discussion on this PR settled on batch_size
        # being the per-process (per-host) batch size, checked for divisibility
        # by jax.local_device_count(). If so, batch_size=4 on a 4-chip host is
        # 1 row per chip and a global batch of 16 rather than 64 -- confirm
        # against the merged iter_jax_batches docstring.
        batch_iterator = ds_shard.iter_jax_batches(
            batch_size=4, synchronize_batches=True, drop_last=True
        )

        data_sharding = NamedSharding(self.mesh, P("data"))

        local_predictions = []
        for i, batch in enumerate(batch_iterator):
            x = jax.device_put(batch["features"], data_sharding)
            preds = model_fn(x, self.W, self.b)

            # Each worker returns ONLY its local addressable data.
            # On a VM with 4 chips, preds.addressable_shards contains 4 shards.
            # We concatenate them to get the data corresponding to this host.
            host_shard = np.concatenate(
                [np.array(s.data) for s in preds.addressable_shards], axis=0
            )
            local_predictions.append(host_shard)

            if i % 2 == 0:
                print(f"Worker {self.rank}: Processed batch {i}...")

        if not local_predictions:
            # np.concatenate rejects an empty list; this can happen when the
            # iterator yields nothing (for example if drop_last discards the
            # only, partial batch). Return an empty result of matching width.
            return np.empty((0, self.W.shape[1]), dtype=np.int32)

        return np.concatenate(local_predictions, axis=0)


# 2. Start one JAXInferenceWorker actor per bundle, scheduled into the
# reserved placement group.
workers = []
for worker_rank in range(num_workers):
    placement = PlacementGroupSchedulingStrategy(placement_group=slice_pg)
    actor_cls = JAXInferenceWorker.options(scheduling_strategy=placement)
    workers.append(actor_cls.remote(rank=worker_rank, world_size=num_workers))

# 3. Use the first worker's node IP as the JAX coordinator address, then
# initialize JAX on every worker and wait for all of them to finish.
ip_refs = [w.get_ip.remote() for w in workers]
worker_ips = ray.get(ip_refs)
coordinator_address = worker_ips[0]
init_refs = [w.initialize_jax.remote(coordinator_address) for w in workers]
ray.get(init_refs)

# 4. Prepare data (8 features)
print("Preparing dataset...")


def generate_data(batch):
    """Return one row of 8 random float32 features per entry in ``batch["id"]``."""
    num_rows = len(batch["id"])
    features = np.random.randn(num_rows, 8)
    return {"features": features.astype(np.float32)}


# 64 rows total, split equally so every worker receives the same number of
# rows (16 each with 4 workers).
# NOTE(review): how many batches each worker sees depends on whether
# batch_size in iter_jax_batches is per chip or per process; the review
# thread on this PR points to per process -- confirm.
ds = ray.data.range(64).map_batches(generate_data)
shards = ds.split(num_workers, equal=True)

try:
    # 5. Execute inference: worker k consumes shard k.
    print("Launching inference tasks...")
    results = ray.get(
        [
            worker.run_inference.remote(shard)
            for worker, shard in zip(workers, shards)
        ]
    )

    # 6. Show results (Aggregated on the driver)
    # With 4 workers, results holds 4 NumPy arrays, each expected to have
    # shape (16, 1).
    all_preds = np.concatenate(results, axis=0)

    print("\n--- Inference Results (Binary Predictions) ---")
    print(f"Total predictions collected: {len(all_preds)}")
    print(f"Worker result shapes: {[r.shape for r in results]}")
    print("Sample predictions (first 20):")
    print(all_preds[:20].flatten())
finally:
    # Cleanup: release the TPU slice even when inference or reporting fails.
    slice_handle.shutdown()
@siyuanfoundation siyuanfoundation requested review from a team as code owners March 10, 2026 18:17
Comment thread python/ray/data/util/jax_util.py Outdated
Comment thread python/ray/data/util/jax_util.py Outdated

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces iter_jax_batches to Ray Data, providing a way to integrate Ray Datasets with JAX for distributed training. The changes include adding the new API to Dataset and DataIterator, along with utility functions for JAX tensor conversion and comprehensive tests. The implementation looks solid and follows existing patterns in Ray Data.

I have a couple of suggestions:

  1. The example in the iter_jax_batches docstring in dataset.py is not runnable and could be improved for clarity.
  2. There is some redundant code in the new jax_util.py file that can be cleaned up.

Overall, this is a great addition that improves JAX integration with Ray Data.

Note: Security Review did not run due to the size of the PR.

Comment thread python/ray/data/dataset.py
Comment thread python/ray/data/util/jax_util.py
@ray-gardener ray-gardener Bot added data Ray Data-related issues community-contribution Contributed by the community labels Mar 10, 2026
Comment thread python/ray/data/util/jax_util.py
Comment thread python/ray/data/util/jax_util.py Outdated
@siyuanfoundation siyuanfoundation force-pushed the jax-dataloader branch 2 times, most recently from 23b6c9f to 03108fb Compare March 12, 2026 18:46
Comment thread python/ray/data/util/jax_util.py Outdated
Comment thread python/ray/data/util/jax_util.py Outdated
Comment thread python/ray/data/util/jax_util.py Outdated
Comment thread python/ray/data/dataset.py
Comment thread python/ray/data/tests/test_dataset_iter.py Outdated
Comment thread python/ray/data/dataset.py
Comment thread python/ray/train/v2/tests/test_jax_trainer.py
Comment thread python/ray/data/util/jax_util.py Outdated
Comment thread python/ray/data/dataset.py
Comment thread python/ray/data/util/jax_util.py Outdated
@ryanaoleary

Copy link
Copy Markdown
Contributor
Comment thread python/ray/data/iterator.py Outdated
Comment thread python/ray/train/v2/tests/test_jax_trainer.py Outdated
Comment thread python/ray/data/iterator.py Outdated
num_total_devices = jax.device_count()
num_hosts = jax.process_count()

if batch_size is not None and batch_size % num_total_devices != 0:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible to use batch_size here as the per process batch_size?

I think the batch_size in iter_torch_batches currently means the per-process batch size; with this change, iter_jax_batches would diverge from iter_torch_batches.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good. Reverted the last commit

@liulehui liulehui left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tysm siyuan!

Comment thread python/ray/data/iterator.py Outdated

num_local_devices = jax.local_device_count()

if batch_size is not None and batch_size % num_local_devices != 0:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, this is needed because the util first computes the physical layout with jax.devices() and views it as a flat 1D array.

I am a bit confused for this one e.g.

say we have a set up of a v5e 2x4 slice, with 2 hosts, each host has 4 chips.
we would like to express the mesh as (data, model) = (2, 4) and global_batch_size = 2, basically means, each process will get its own shard of data by ray data with 1 row in it.
since model parallel is 4, means the 4 chips on the host will see the same 1 row of data.

this config should be valid, while it will fail here since local_batch_size = 1 and num_local_devices = 4.

I think the local_batch_size should be checked for divisibility by the data dimension or by jax.process_count().

Please correct me if I understand wrong!

@siyuanfoundation siyuanfoundation Mar 23, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the next improvement, you are right. But in the current implementation, we load the data in pure data-parallel form first and then reshard (Option 4 in https://docs.jax.dev/en/latest/distributed_data_loading.html), so it requires that the data can be sharded across devices.
In the long run, we would like to add advanced parallelism, but that would mean changing DataParallelTrainer or moving JaxTrainer away from it. We do plan to work on it, but it would take considerably more effort. In the meantime, can we add something easier to implement first?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

understood.

@liulehui liulehui Mar 23, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case, do you think it is acceptable to split the utils into 2 like below in the short term:

batch = iter_jax_batches(
    *,
    batch_size=...,
    synchronize_batches=True,
    ...
)

from ray.data.util.jax_util import reshard_jax_batch
batch = reshard_jax_batch(batch, named_sharding)

this way we will have:

  1. iterator = loading/conversion
  2. helper = advanced reshaping

also can give us the benefit to not pass in named_sharding to iter_jax_batches()

WDYT?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed named_sharding argument.
For the helper function, I think the user can just call jax.device_put directly, like in the test example

    for batch in ds_shard.iter_jax_batches(
        batch_size=16,
        drop_last=drop_last,
    ):
        arr = jax.device_put(batch["features"], named_sharding)
        assert arr.sharding == named_sharding
Comment thread python/ray/data/util/jax_util.py Outdated
Comment thread python/ray/data/dataset.py Outdated
def iter_jax_batches(
self,
*,
named_sharding: "jax.sharding.NamedSharding" = None, # noqa: F821

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in the long run, when we would like to express advanced parallelism, we can pass in replicate/rank metadata in for both jax/torch batches.

I am wondering if we should leave this arg out at first and only assume DDP training on the JAX trainer, which is the same as torch today and the same behavior as passing None here.

but I am also ok to keep it for now to allow user for this experimental api and maybe deprecate it later for end of March.

WDYT? @justinvyu @xinyuangui2

Comment thread python/ray/data/util/jax_util.py Outdated
)
return

# Multi-host synchronization with lookahead

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ty, I think this part makes sense to me; I just want to confirm my understanding.

Basically IIUC, say we have 200 examples on 2x4 TPU V5E, 2 jax processes, global batch size = 32, local batch size = 16.
each process will have 6x16 full batches, and a 6 row batch left, while it is ok for torch to train, it is not ok for Jax, since it require same shape.

even if a 8 row batch left is valid, it will trigger a recompilation.

I think Ray Data also tries to distribute batches evenly, but I agree it is not as strict as JAX here. I am wondering how often this recompilation can happen: if Ray Data does a good job of distributing rows evenly and drop_last=True, how about we set the default synchronize_batches=False and suggest flipping it to True if we see a lot of recompilation?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, if there is an 8-row batch left, it would keep the batch because 8 is divisible across the 2x4 devices. It would drop the 6 if drop_last=True because JAX requires the same shape across devices.
I don't think recompilation would be a big deal since it most likely happens at the end.
If we can trust Ray Data to distribute batches evenly most of the time, I have no problem with setting synchronize_batches=False.

@siyuanfoundation siyuanfoundation force-pushed the jax-dataloader branch 2 times, most recently from 8fe6b82 to 5273fc7 Compare April 14, 2026 01:59
Signed-off-by: siyuanfoundation <sizhang@google.com>
Signed-off-by: siyuanfoundation <sizhang@google.com>
[data] add option to skip synchronize_batches for jax iterator and sync optimization with lookahead

Signed-off-by: siyuanfoundation <sizhang@google.com>
Signed-off-by: siyuanfoundation <sizhang@google.com>
Signed-off-by: siyuanfoundation <sizhang@google.com>
Signed-off-by: siyuanfoundation <sizhang@google.com>
Signed-off-by: siyuanfoundation <sizhang@google.com>
Signed-off-by: siyuanfoundation <sizhang@google.com>
Signed-off-by: siyuanfoundation <sizhang@google.com>
Signed-off-by: siyuanfoundation <sizhang@google.com>
Signed-off-by: siyuanfoundation <sizhang@google.com>
Signed-off-by: siyuanfoundation <sizhang@google.com>
Signed-off-by: siyuanfoundation <sizhang@google.com>
Signed-off-by: siyuanfoundation <sizhang@google.com>
Signed-off-by: siyuanfoundation <sizhang@google.com>
Signed-off-by: siyuanfoundation <sizhang@google.com>
@aslonnie aslonnie merged commit cf3939d into ray-project:master Apr 14, 2026
5 of 6 checks passed
HLDKNotFound pushed a commit to chichic21039/ray that referenced this pull request Apr 22, 2026
## Description
This PR introduces the `iter_jax_batches` API for Ray Data, enabling
seamless integration between Ray Datasets and JAX architectures in
distributed training scenarios. This provides first-class support for
processing data within `JaxTrainer` workloads.

## Related issues
ray-project#55162

## Additional information
This PR implements Option 4 in
https://docs.jax.dev/en/latest/distributed_data_loading.html:
It loads ray data into devices using pure data parallel to be consistent
with `JaxTrainer` (which is a `DataParallelTrainer`).
Users can reshard it to the desired shape after getting the DDP batch.

Here is an example of how to use it for jax inference
```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

import ray
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from ray.util.tpu import slice_placement_group

# Connect to the Ray cluster (no-op when this process is already connected).
if not ray.is_initialized():
    ray.init()

# 1. Reserve one v6e TPU slice (4x4, 16 chips, 4 VMs).
print("Reserving TPU slice...")
slice_handle = slice_placement_group(
    topology="4x4",
    accelerator_version="v6e",
    num_slices=1,
    resources_per_bundle={"TPU": 4},
)
slice_pg = slice_handle.placement_group

print("Waiting for placement group to be ready...")
try:
    # Block for at most 10 minutes until the bundles are scheduled.
    ray.get(slice_pg.ready(), timeout=600)
except Exception:
    # Do not leave the slice reserved when it never became ready; the
    # shutdown call at the end of the script would not be reached.
    slice_handle.shutdown()
    raise
print("Placement group ready.")

# Number of bundles in the reservation; the script starts one worker per bundle.
num_workers = slice_handle.num_bundles


@ray.remote(num_cpus=1, resources={"TPU": 4})
class JAXInferenceWorker:
    """Ray actor that runs one JAX process of a multi-host inference job.

    Each actor requests 1 CPU and 4 "TPU" resources, joins the JAX distributed
    runtime as process ``rank`` out of ``world_size``, and holds the
    parameters of a small logistic-regression model.
    """

    def __init__(self, rank, world_size):
        """Store the process rank, the world size, and this node's IP."""
        self.rank = rank
        self.world_size = world_size
        self.ip = ray.util.get_node_ip_address()

    def get_ip(self):
        """Return the node IP address captured at construction time."""
        return self.ip

    def initialize_jax(self, coordinator_address):
        """Join the JAX distributed runtime and create the model parameters.

        Builds a one-axis mesh named ``"data"`` over ``jax.devices()`` and
        places ``W`` (8x1) and ``b`` (1,) with an empty ``PartitionSpec``.

        Args:
            coordinator_address: Host or IP of the coordinator process; port
                1234 is appended to it.
        """
        jax.distributed.initialize(
            coordinator_address=f"{coordinator_address}:1234",
            num_processes=self.world_size,
            process_id=self.rank,
        )
        devices = jax.devices()
        self.mesh = Mesh(devices, axis_names=("data",))

        # Initialize parameters for binary prediction (Logistic Regression)
        input_dim = 8
        output_dim = 1
        # Every worker uses the same seed, so all processes build equal values.
        key = jax.random.PRNGKey(42)
        k1, k2 = jax.random.split(key)

        self.param_sharding = NamedSharding(self.mesh, P())
        self.W = jax.device_put(
            jax.random.normal(k1, (input_dim, output_dim)), self.param_sharding
        )
        self.b = jax.device_put(
            jax.random.normal(k2, (output_dim,)), self.param_sharding
        )

        print(f"Worker {self.rank}: JAX initialized.")

    def run_inference(self, ds_shard):
        """Run the model over this worker's dataset shard.

        Must be called after ``initialize_jax``, which creates ``self.mesh``,
        ``self.W`` and ``self.b``.

        Args:
            ds_shard: Dataset shard exposing ``iter_jax_batches``; each batch
                is indexed by the ``"features"`` key.

        Returns:
            A NumPy array with this host's int32 predictions for all batches,
            concatenated along axis 0. When the shard yields no batches, an
            empty array of shape ``(0, output_dim)`` is returned.
        """

        # Logistic Regression Model
        @jax.jit
        def model_fn(x, W, b):
            logits = jnp.matmul(x, W) + b
            probs = jax.nn.sigmoid(logits)
            return (probs > 0.5).astype(jnp.int32)

        print(f"Worker {self.rank}: Starting inference...")

        # NOTE(review): the review discussion on this PR settled on batch_size
        # being the per-process (per-host) batch size, checked for divisibility
        # by jax.local_device_count(). If so, batch_size=4 on a 4-chip host is
        # 1 row per chip and a global batch of 16 rather than 64 -- confirm
        # against the merged iter_jax_batches docstring.
        batch_iterator = ds_shard.iter_jax_batches(
            batch_size=4, synchronize_batches=True, drop_last=True
        )

        data_sharding = NamedSharding(self.mesh, P("data"))

        local_predictions = []
        for i, batch in enumerate(batch_iterator):
            x = jax.device_put(batch["features"], data_sharding)
            preds = model_fn(x, self.W, self.b)

            # Each worker returns ONLY its local addressable data.
            # On a VM with 4 chips, preds.addressable_shards contains 4 shards.
            # We concatenate them to get the data corresponding to this host.
            host_shard = np.concatenate(
                [np.array(s.data) for s in preds.addressable_shards], axis=0
            )
            local_predictions.append(host_shard)

            if i % 2 == 0:
                print(f"Worker {self.rank}: Processed batch {i}...")

        if not local_predictions:
            # np.concatenate rejects an empty list; this can happen when the
            # iterator yields nothing (for example if drop_last discards the
            # only, partial batch). Return an empty result of matching width.
            return np.empty((0, self.W.shape[1]), dtype=np.int32)

        return np.concatenate(local_predictions, axis=0)


# 2. Start one JAXInferenceWorker actor per bundle, scheduled into the
# reserved placement group.
workers = []
for worker_rank in range(num_workers):
    placement = PlacementGroupSchedulingStrategy(placement_group=slice_pg)
    actor_cls = JAXInferenceWorker.options(scheduling_strategy=placement)
    workers.append(actor_cls.remote(rank=worker_rank, world_size=num_workers))

# 3. Use the first worker's node IP as the JAX coordinator address, then
# initialize JAX on every worker and wait for all of them to finish.
ip_refs = [w.get_ip.remote() for w in workers]
worker_ips = ray.get(ip_refs)
coordinator_address = worker_ips[0]
init_refs = [w.initialize_jax.remote(coordinator_address) for w in workers]
ray.get(init_refs)

# 4. Prepare data (8 features)
print("Preparing dataset...")


def generate_data(batch):
    """Return one row of 8 random float32 features per entry in ``batch["id"]``."""
    num_rows = len(batch["id"])
    features = np.random.randn(num_rows, 8)
    return {"features": features.astype(np.float32)}


# 64 rows total, split equally so every worker receives the same number of
# rows (16 each with 4 workers).
# NOTE(review): how many batches each worker sees depends on whether
# batch_size in iter_jax_batches is per chip or per process; the review
# thread on this PR points to per process -- confirm.
ds = ray.data.range(64).map_batches(generate_data)
shards = ds.split(num_workers, equal=True)

try:
    # 5. Execute inference: worker k consumes shard k.
    print("Launching inference tasks...")
    results = ray.get(
        [
            worker.run_inference.remote(shard)
            for worker, shard in zip(workers, shards)
        ]
    )

    # 6. Show results (Aggregated on the driver)
    # With 4 workers, results holds 4 NumPy arrays, each expected to have
    # shape (16, 1).
    all_preds = np.concatenate(results, axis=0)

    print("\n--- Inference Results (Binary Predictions) ---")
    print(f"Total predictions collected: {len(all_preds)}")
    print(f"Worker result shapes: {[r.shape for r in results]}")
    print("Sample predictions (first 20):")
    print(all_preds[:20].flatten())
finally:
    # Cleanup: release the TPU slice even when inference or reporting fails.
    slice_handle.shutdown()
```

---------

Signed-off-by: siyuanfoundation <sizhang@google.com>
ryanaoleary added a commit that referenced this pull request May 12, 2026
## Description
#61630 added support for JAX to
Ray Data, specifically implementing an `iter_jax_batches` util to yield
natively sharded `jax.Arrays`. This provides first-class support for
processing data within JaxTrainer workloads. This PR updates an existing
GPT-2 guide using the `JaxTrainer` to showcase how this new util could
simplify the Train code.

## Related issues
#55162

## Additional information
> Optional: Add implementation details, API changes, usage examples,
screenshots, etc.

---------

Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Lucas61000 pushed a commit to Lucas61000/ray that referenced this pull request May 15, 2026
## Description
This PR introduces the `iter_jax_batches` API for Ray Data, enabling
seamless integration between Ray Datasets and JAX architectures in
distributed training scenarios. This provides first-class support for
processing data within `JaxTrainer` workloads.

## Related issues
ray-project#55162

## Additional information
This PR implements Option 4 in
https://docs.jax.dev/en/latest/distributed_data_loading.html:
It loads ray data into devices using pure data parallel to be consistent
with `JaxTrainer` (which is a `DataParallelTrainer`).
Users can reshard it to the desired shape after getting the DDP batch.

Here is an example of how to use it for jax inference
```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

import ray
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from ray.util.tpu import slice_placement_group

# Connect to the Ray cluster (no-op when this process is already connected).
if not ray.is_initialized():
    ray.init()

# 1. Reserve one v6e TPU slice (4x4, 16 chips, 4 VMs).
print("Reserving TPU slice...")
slice_handle = slice_placement_group(
    topology="4x4",
    accelerator_version="v6e",
    num_slices=1,
    resources_per_bundle={"TPU": 4},
)
slice_pg = slice_handle.placement_group

print("Waiting for placement group to be ready...")
try:
    # Block for at most 10 minutes until the bundles are scheduled.
    ray.get(slice_pg.ready(), timeout=600)
except Exception:
    # Do not leave the slice reserved when it never became ready; the
    # shutdown call at the end of the script would not be reached.
    slice_handle.shutdown()
    raise
print("Placement group ready.")

# Number of bundles in the reservation; the script starts one worker per bundle.
num_workers = slice_handle.num_bundles


@ray.remote(num_cpus=1, resources={"TPU": 4})
class JAXInferenceWorker:
    """Ray actor that runs one JAX process of a multi-host inference job.

    Each actor requests 1 CPU and 4 "TPU" resources, joins the JAX distributed
    runtime as process ``rank`` out of ``world_size``, and holds the
    parameters of a small logistic-regression model.
    """

    def __init__(self, rank, world_size):
        """Store the process rank, the world size, and this node's IP."""
        self.rank = rank
        self.world_size = world_size
        self.ip = ray.util.get_node_ip_address()

    def get_ip(self):
        """Return the node IP address captured at construction time."""
        return self.ip

    def initialize_jax(self, coordinator_address):
        """Join the JAX distributed runtime and create the model parameters.

        Builds a one-axis mesh named ``"data"`` over ``jax.devices()`` and
        places ``W`` (8x1) and ``b`` (1,) with an empty ``PartitionSpec``.

        Args:
            coordinator_address: Host or IP of the coordinator process; port
                1234 is appended to it.
        """
        jax.distributed.initialize(
            coordinator_address=f"{coordinator_address}:1234",
            num_processes=self.world_size,
            process_id=self.rank,
        )
        devices = jax.devices()
        self.mesh = Mesh(devices, axis_names=("data",))

        # Initialize parameters for binary prediction (Logistic Regression)
        input_dim = 8
        output_dim = 1
        # Every worker uses the same seed, so all processes build equal values.
        key = jax.random.PRNGKey(42)
        k1, k2 = jax.random.split(key)

        self.param_sharding = NamedSharding(self.mesh, P())
        self.W = jax.device_put(
            jax.random.normal(k1, (input_dim, output_dim)), self.param_sharding
        )
        self.b = jax.device_put(
            jax.random.normal(k2, (output_dim,)), self.param_sharding
        )

        print(f"Worker {self.rank}: JAX initialized.")

    def run_inference(self, ds_shard):
        """Run the model over this worker's dataset shard.

        Must be called after ``initialize_jax``, which creates ``self.mesh``,
        ``self.W`` and ``self.b``.

        Args:
            ds_shard: Dataset shard exposing ``iter_jax_batches``; each batch
                is indexed by the ``"features"`` key.

        Returns:
            A NumPy array with this host's int32 predictions for all batches,
            concatenated along axis 0. When the shard yields no batches, an
            empty array of shape ``(0, output_dim)`` is returned.
        """

        # Logistic Regression Model
        @jax.jit
        def model_fn(x, W, b):
            logits = jnp.matmul(x, W) + b
            probs = jax.nn.sigmoid(logits)
            return (probs > 0.5).astype(jnp.int32)

        print(f"Worker {self.rank}: Starting inference...")

        # NOTE(review): the review discussion on this PR settled on batch_size
        # being the per-process (per-host) batch size, checked for divisibility
        # by jax.local_device_count(). If so, batch_size=4 on a 4-chip host is
        # 1 row per chip and a global batch of 16 rather than 64 -- confirm
        # against the merged iter_jax_batches docstring.
        batch_iterator = ds_shard.iter_jax_batches(
            batch_size=4, synchronize_batches=True, drop_last=True
        )

        data_sharding = NamedSharding(self.mesh, P("data"))

        local_predictions = []
        for i, batch in enumerate(batch_iterator):
            x = jax.device_put(batch["features"], data_sharding)
            preds = model_fn(x, self.W, self.b)

            # Each worker returns ONLY its local addressable data.
            # On a VM with 4 chips, preds.addressable_shards contains 4 shards.
            # We concatenate them to get the data corresponding to this host.
            host_shard = np.concatenate(
                [np.array(s.data) for s in preds.addressable_shards], axis=0
            )
            local_predictions.append(host_shard)

            if i % 2 == 0:
                print(f"Worker {self.rank}: Processed batch {i}...")

        if not local_predictions:
            # np.concatenate rejects an empty list; this can happen when the
            # iterator yields nothing (for example if drop_last discards the
            # only, partial batch). Return an empty result of matching width.
            return np.empty((0, self.W.shape[1]), dtype=np.int32)

        return np.concatenate(local_predictions, axis=0)


# 2. Start one JAXInferenceWorker actor per bundle, scheduled into the
# reserved placement group.
workers = []
for worker_rank in range(num_workers):
    placement = PlacementGroupSchedulingStrategy(placement_group=slice_pg)
    actor_cls = JAXInferenceWorker.options(scheduling_strategy=placement)
    workers.append(actor_cls.remote(rank=worker_rank, world_size=num_workers))

# 3. Use the first worker's node IP as the JAX coordinator address, then
# initialize JAX on every worker and wait for all of them to finish.
ip_refs = [w.get_ip.remote() for w in workers]
worker_ips = ray.get(ip_refs)
coordinator_address = worker_ips[0]
init_refs = [w.initialize_jax.remote(coordinator_address) for w in workers]
ray.get(init_refs)

# 4. Prepare data (8 features)
print("Preparing dataset...")


def generate_data(batch):
    """Return one row of 8 random float32 features per entry in ``batch["id"]``."""
    num_rows = len(batch["id"])
    features = np.random.randn(num_rows, 8)
    return {"features": features.astype(np.float32)}


# 64 rows total, split equally so every worker receives the same number of
# rows (16 each with 4 workers).
# NOTE(review): how many batches each worker sees depends on whether
# batch_size in iter_jax_batches is per chip or per process; the review
# thread on this PR points to per process -- confirm.
ds = ray.data.range(64).map_batches(generate_data)
shards = ds.split(num_workers, equal=True)

try:
    # 5. Execute inference: worker k consumes shard k.
    print("Launching inference tasks...")
    results = ray.get(
        [
            worker.run_inference.remote(shard)
            for worker, shard in zip(workers, shards)
        ]
    )

    # 6. Show results (Aggregated on the driver)
    # With 4 workers, results holds 4 NumPy arrays, each expected to have
    # shape (16, 1).
    all_preds = np.concatenate(results, axis=0)

    print("\n--- Inference Results (Binary Predictions) ---")
    print(f"Total predictions collected: {len(all_preds)}")
    print(f"Worker result shapes: {[r.shape for r in results]}")
    print("Sample predictions (first 20):")
    print(all_preds[:20].flatten())
finally:
    # Cleanup: release the TPU slice even when inference or reporting fails.
    slice_handle.shutdown()
```

---------

Signed-off-by: siyuanfoundation <sizhang@google.com>
Lucas61000 pushed a commit to Lucas61000/ray that referenced this pull request May 15, 2026
…ject#63294)

## Description
ray-project#61630 added support for JAX to
Ray Data, specifically implementing an `iter_jax_batches` util to yield
natively sharded `jax.Arrays`. This provides first-class support for
processing data within JaxTrainer workloads. This PR updates an existing
GPT-2 guide using the `JaxTrainer` to showcase how this new util could
simplify the Train code.

## Related issues
ray-project#55162

## Additional information
> Optional: Add implementation details, API changes, usage examples,
screenshots, etc.

---------

Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution Contributed by the community data Ray Data-related issues go add ONLY when ready to merge, run all tests

7 participants