[data] add jax data iterator #61630
There was a problem hiding this comment.
Code Review
This pull request introduces iter_jax_batches to Ray Data, providing a way to integrate Ray Datasets with JAX for distributed training. The changes include adding the new API to Dataset and DataIterator, along with utility functions for JAX tensor conversion and comprehensive tests. The implementation looks solid and follows existing patterns in Ray Data.
I have a couple of suggestions:
- The example in the `iter_jax_batches` docstring in `dataset.py` is not runnable and could be improved for clarity.
- There is some redundant code in the new `jax_util.py` file that can be cleaned up.
Overall, this is a great addition that improves JAX integration with Ray Data.
Note: Security Review did not run due to the size of the PR.
cc: @edoakes @richardliaw
num_total_devices = jax.device_count()
num_hosts = jax.process_count()

if batch_size is not None and batch_size % num_total_devices != 0:
There was a problem hiding this comment.
Is it possible to treat batch_size here as the per-process batch size?
I think batch_size in iter_torch_batches currently means the per-process batch size. With this change, iter_torch_batches and iter_jax_batches would diverge.
There was a problem hiding this comment.
sounds good. Reverted the last commit
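To make the two readings of batch_size concrete, here is a minimal sketch in plain Python (no JAX or Ray needed). The helper name `describe_batch_size` and the topology of 2 hosts with 4 chips each are assumptions for illustration; this is not the PR's code.

```python
def describe_batch_size(batch_size, num_processes, num_local_devices, per_process):
    """Return (global_rows, rows_per_process, rows_per_device) for one step.

    Hypothetical helper: per_process=True treats batch_size as the
    per-process batch size, per_process=False treats it as the global one.
    """
    if per_process:
        rows_per_process = batch_size
    else:
        if batch_size % num_processes != 0:
            raise ValueError("global batch size must divide evenly across processes")
        rows_per_process = batch_size // num_processes
    if rows_per_process % num_local_devices != 0:
        raise ValueError("per-process rows must divide evenly across local devices")
    return (
        rows_per_process * num_processes,
        rows_per_process,
        rows_per_process // num_local_devices,
    )


# Assumed topology: 2 hosts (processes) with 4 chips each.
print(describe_batch_size(16, num_processes=2, num_local_devices=4, per_process=True))
print(describe_batch_size(16, num_processes=2, num_local_devices=4, per_process=False))
```

Under the per-process reading, the same batch_size=16 yields twice as many rows per step as under the global reading, which is why keeping the meaning aligned with iter_torch_batches matters.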
num_local_devices = jax.local_device_count()

if batch_size is not None and batch_size % num_local_devices != 0:
There was a problem hiding this comment.
IIUC, this check is needed because the util first computes the physical layout from jax.devices() and views it as a flat 1D array.
I am a bit confused by this one. For example:
Say we have a v5e 2x4 slice with 2 hosts, each with 4 chips.
We would like to express the mesh as (data, model) = (2, 4) with global_batch_size = 2, which means each process gets its own shard of data from Ray Data with 1 row in it.
Since model parallelism is 4, the 4 chips on each host see the same 1 row of data.
This config should be valid, but it fails here because local_batch_size = 1 and num_local_devices = 4.
I think local_batch_size should be checked for divisibility against the data dimension or jax.process_count() instead.
Please correct me if I misunderstand!
There was a problem hiding this comment.
You are right, and that is the next improvement. But in the current implementation, we first load the data in pure data-parallel form and then reshard (Option 4 in https://docs.jax.dev/en/latest/distributed_data_loading.html). So it requires that the data can be sharded across devices.
In the long run, we would like to add advanced parallelism, but that would mean changing DataParallelTrainer or moving JaxTrainer away from it. We do plan to work on it, but it would take considerably more effort. In the meantime, can we add something easier to implement first?
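The disagreement above comes down to which divisor the batch-size check uses. The sketch below contrasts the two rules on the v5e 2x4 example from this thread. Both function names are hypothetical, and the relaxed rule is only the reviewer's suggestion written out, not something the PR implements.

```python
def check_pure_data_parallel(local_batch_size, num_local_devices):
    """Rule in the diff above: each local device must receive whole rows."""
    return local_batch_size % num_local_devices == 0


def check_data_axis_only(global_batch_size, data_axis_size):
    """Hypothetical relaxed rule: rows only need to divide across the data axis."""
    return global_batch_size % data_axis_size == 0


# v5e 2x4 slice: 2 hosts, 4 chips per host, mesh (data, model) = (2, 4).
num_hosts, num_local_devices = 2, 4
data_axis_size, model_axis_size = 2, 4
global_batch_size = 2
local_batch_size = global_batch_size // num_hosts  # 1 row per process

print(check_pure_data_parallel(local_batch_size, num_local_devices))  # False
print(check_data_axis_only(global_batch_size, data_axis_size))  # True
```

The first check rejects the configuration because the pure data-parallel load needs at least one row per chip; the second would accept it, but only a loader that understands the model axis could place the data that way.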
There was a problem hiding this comment.
In that case, do you think it is acceptable to split the utils into two in the short term, like below?
batch = iter_jax_batches(
*,
batch_size=...,
synchronize_batches=True,
...
)
from ray.data.util.jax_util import reshard_jax_batch
batch = reshard_jax_batch(batch, named_sharding)
This way we will have:
- iterator = loading/conversion
- helper = advanced reshaping
This also gives us the benefit of not passing named_sharding into iter_jax_batches().
WDYT?
There was a problem hiding this comment.
Removed the named_sharding argument.
For the helper function, I think the user can just call jax.device_put directly, as in the test example:
for batch in ds_shard.iter_jax_batches(
    batch_size=16,
    drop_last=drop_last,
):
    arr = jax.device_put(batch["features"], named_sharding)
    assert arr.sharding == named_sharding
def iter_jax_batches(
    self,
    *,
    named_sharding: "jax.sharding.NamedSharding" = None,  # noqa: F821
There was a problem hiding this comment.
I think in the long run, when we want to express advanced parallelism, we can pass replica/rank metadata in for both jax and torch batches.
I am wondering if we should leave this arg out at first and assume only DDP training on the JAX trainer, which is what torch does today and is the same behavior as passing None here.
But I am also OK with keeping it for now so users can try this experimental API, and maybe deprecating it later, by the end of March.
WDYT? @justinvyu @xinyuangui2
)
return

# Multi-host synchronization with lookahead
There was a problem hiding this comment.
Ty, I think this part makes sense to me; I just want to confirm my understanding.
IIUC, say we have 200 examples on a 2x4 TPU v5e with 2 JAX processes, global batch size = 32, and local batch size = 16.
Each process will have 6 full batches of 16 rows, plus a 6-row batch left over. That is fine for torch training, but not for JAX, since it requires the same shape.
Even if an 8-row leftover batch is valid, it will trigger a recompilation.
I think Ray Data also tries to distribute batches evenly, but I agree it is not as strict as JAX here. I am wondering how often this recompilation can happen, though. If Ray Data does a good job of distributing batches evenly and drop_last=True, how about we default to synchronize_batches=False and suggest flipping it to True if we see a ton of recompilation?
There was a problem hiding this comment.
Currently, if there is an 8-row batch left, it is kept because 8 is divisible by the 2x4 devices. The 6-row batch is dropped if drop_last=True, because JAX requires the same shape across devices.
I don't think recompilation would be a big deal, since it most likely happens only at the end.
If we can mostly trust Ray Data to distribute batches evenly, I have no problem with setting synchronize_batches=False.
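The tail-batch behavior discussed in this thread can be sketched in plain Python. This is a simplified model, not the PR's lookahead implementation: the helper names, the uneven 102/98 row split, and the rule "stop at the first step where processes disagree or the batch does not divide across local devices" are all assumptions for illustration.

```python
from itertools import zip_longest


def local_batch_sizes(num_rows, batch_size, drop_last):
    """Batch sizes one process would yield from its shard."""
    full, remainder = divmod(num_rows, batch_size)
    sizes = [batch_size] * full
    if remainder and not drop_last:
        sizes.append(remainder)
    return sizes


def synchronized_steps(per_process_sizes, num_local_devices):
    """Keep only the leading steps that every process can run with equal shapes."""
    steps = []
    for sizes in zip_longest(*per_process_sizes):
        if None in sizes:  # some process ran out of data
            break
        if len(set(sizes)) != 1:  # shapes differ across processes
            break
        if sizes[0] % num_local_devices != 0:  # cannot shard across local devices
            break
        steps.append(sizes[0])
    return steps


# Assumed uneven split of 200 rows across 2 processes, local batch size 16.
uneven = [local_batch_sizes(n, 16, drop_last=False) for n in (102, 98)]
print(uneven)  # tails of 6 and 2 rows
print(synchronized_steps(uneven, num_local_devices=4))  # six steps of 16

# An 8-row tail on both processes divides across 4 local devices and is kept.
even = [local_batch_sizes(104, 16, drop_last=False) for _ in range(2)]
print(synchronized_steps(even, num_local_devices=4))
```

In this model the kept 8-row tail has a different shape from the 16-row steps, which is the one-off recompilation at the end of an epoch mentioned above.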
Signed-off-by: siyuanfoundation <sizhang@google.com>
[data] add option to skip synchronize_batches for jax iterator and sync optimization with lookahead Signed-off-by: siyuanfoundation <sizhang@google.com>
## Description

This PR introduces the `iter_jax_batches` API for Ray Data, enabling seamless integration between Ray Datasets and JAX architectures in distributed training scenarios. This provides first-class support for processing data within `JaxTrainer` workloads.

## Related issues

ray-project#55162

## Additional information

This PR implements Option 4 in https://docs.jax.dev/en/latest/distributed_data_loading.html:
It loads ray data into devices using pure data parallel to be consistent with `JaxTrainer` (which is a `DataParallelTrainer`).
Users can reshard it to the desired shape after getting the DDP batch.

Here is an example of how to use it for jax inference

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

import ray
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from ray.util.tpu import slice_placement_group

# Connect to the Ray cluster.
if not ray.is_initialized():
    ray.init()

# 1. Reserve one v6e TPU slice (4x4, 16 chips, 4 VMs).
print("Reserving TPU slice...")
slice_handle = slice_placement_group(
    topology="4x4",
    accelerator_version="v6e",
    num_slices=1,
    resources_per_bundle={"TPU": 4},
)
slice_pg = slice_handle.placement_group
print("Waiting for placement group to be ready...")
ray.get(slice_pg.ready(), timeout=600)
print("Placement group ready.")

num_workers = slice_handle.num_bundles


@ray.remote(num_cpus=1, resources={"TPU": 4})
class JAXInferenceWorker:
    def __init__(self, rank, world_size):
        self.rank = rank
        self.world_size = world_size
        self.ip = ray.util.get_node_ip_address()

    def get_ip(self):
        return self.ip

    def initialize_jax(self, coordinator_address):
        jax.distributed.initialize(
            coordinator_address=f"{coordinator_address}:1234",
            num_processes=self.world_size,
            process_id=self.rank,
        )
        devices = jax.devices()
        self.mesh = Mesh(devices, axis_names=("data",))

        # Initialize parameters for binary prediction (Logistic Regression)
        input_dim = 8
        output_dim = 1
        key = jax.random.PRNGKey(42)
        k1, k2 = jax.random.split(key)

        self.param_sharding = NamedSharding(self.mesh, P())
        self.W = jax.device_put(
            jax.random.normal(k1, (input_dim, output_dim)), self.param_sharding
        )
        self.b = jax.device_put(
            jax.random.normal(k2, (output_dim,)), self.param_sharding
        )
        print(f"Worker {self.rank}: JAX initialized.")

    def run_inference(self, ds_shard):
        import jax
        import numpy as np

        # Logistic Regression Model
        @jax.jit
        def model_fn(x, W, b):
            logits = jnp.matmul(x, W) + b
            probs = jax.nn.sigmoid(logits)
            return (probs > 0.5).astype(jnp.int32)

        print(f"Worker {self.rank}: Starting inference...")

        # batch_size=4 rows per chip -> global batch size 64 (16 chips)
        batch_iterator = ds_shard.iter_jax_batches(
            batch_size=4, synchronize_batches=True, drop_last=True
        )
        data_sharding = NamedSharding(self.mesh, P("data"))

        local_predictions = []
        for i, batch in enumerate(batch_iterator):
            x = jax.device_put(batch["features"], data_sharding)
            preds = model_fn(x, self.W, self.b)

            # Each worker returns ONLY its local addressable data.
            # On a VM with 4 chips, preds.addressable_shards contains 4 shards.
            # We concatenate them to get the data corresponding to this host.
            host_shard = np.concatenate(
                [np.array(s.data) for s in preds.addressable_shards], axis=0
            )
            local_predictions.append(host_shard)

            if i % 2 == 0:
                print(f"Worker {self.rank}: Processed batch {i}...")

        return np.concatenate(local_predictions, axis=0)


# 2. Launch workers
workers = [
    JAXInferenceWorker.options(
        scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=slice_pg)
    ).remote(rank=i, world_size=num_workers)
    for i in range(num_workers)
]

# 3. Coordinate JAX setup
worker_ips = ray.get([w.get_ip.remote() for w in workers])
coordinator_address = worker_ips[0]
ray.get([w.initialize_jax.remote(coordinator_address) for w in workers])

# 4. Prepare data (8 features)
print("Preparing dataset...")


def generate_data(batch):
    return {"features": np.random.randn(len(batch["id"]), 8).astype(np.float32)}


# 64 rows total -> exactly 1 global batch of 64.
# Each of the 4 workers gets 16 rows.
ds = ray.data.range(64).map_batches(generate_data)
shards = ds.split(num_workers, equal=True)

# 5. Execute inference
print("Launching inference tasks...")
results = ray.get(
    [workers[i].run_inference.remote(shards[i]) for i in range(num_workers)]
)

# 6. Show results (Aggregated on the driver)
# results is a list of 4 numpy arrays, each of shape (16, 1)
all_preds = np.concatenate(results, axis=0)

print("\n--- Inference Results (Binary Predictions) ---")
print(f"Total predictions collected: {len(all_preds)}")
print(f"Worker result shapes: {[r.shape for r in results]}")
print("Sample predictions (first 20):")
print(all_preds[:20].flatten())

# Cleanup
slice_handle.shutdown()
```

---------

Signed-off-by: siyuanfoundation <sizhang@google.com>
## Description

#61630 added JAX support to Ray Data, specifically implementing an `iter_jax_batches` util to yield natively sharded `jax.Arrays`. This provides first-class support for processing data within JaxTrainer workloads.

This PR updates an existing GPT-2 guide using the `JaxTrainer` to showcase how this new util could simplify the Train code.

## Related issues

#55162

## Additional information

> Optional: Add implementation details, API changes, usage examples, screenshots, etc.

---------

Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>