Skip to content

[Data] Optimize local shuffle with incremental index method and configurable compaction threshold#62539

Merged
richardliaw merged 1 commit into
ray-project:masterfrom
xinyuangui2:xgui/optimize-local-shuffle
Apr 15, 2026
Merged

[Data] Optimize local shuffle with incremental index method and configurable compaction threshold#62539
richardliaw merged 1 commit into
ray-project:masterfrom
xinyuangui2:xgui/optimize-local-shuffle

Conversation

@xinyuangui2

@xinyuangui2 xinyuangui2 commented Apr 12, 2026

Copy link
Copy Markdown
Contributor

Summary

  • Replace the full-buffer shuffle in ShufflingBatcher with an incremental index method that generates a permutation array once per compaction and gathers small slices via take per batch. This uses less memory (no second copy of the full buffer) and has smoother per-batch latency that is easier to hide behind prefetch threads.

Benchmark

Setup

  • Dataset: ray.data.range_tensor(81_920_000, shape=(512,)) (~4KB/row, ~320 GB total)
  • Model: Linear(512, 10) -- trivial, measures data pipeline not model
  • Batch size: 4096 per worker
  • Workers: 4 GPU workers
  • Steps: 200 total, first 100 warmup (excluded from steady throughput)
  • Metric: Steady-state throughput (rows/sec) after warmup
  • Environment: RAY_DATA_MIN_NUM_CHUNKS_TO_TRIGGER_COMBINE_CHUNKS=1

Results

| Buffer Size | With PR (rows/s) | Pre-PR (rows/s) | Speedup |
|---|---|---|---|
| no shuffle (baseline) | 1,759,282 | 1,762,696 | 1.0x |
| 1 GB (244,140 rows) | 225,181 | 67,659 | 3.3x |
| 2 GB (488,281 rows) | 220,644 | 66,801 | 3.3x |
| 3 GB (732,421 rows) | 153,256 | 79,889 | 1.9x |

Key findings

  1. The PR is 1.9-3.3x faster across all tested buffer sizes.
  2. Both methods show significant overhead vs. the no-shuffle baseline (in the best
    case, throughput is ~13% of the baseline). This overhead comes from the compaction
    cycle itself -- building, concatenating, and reshuffling the buffer.

Script

"""Benchmark: local buffer shuffle throughput on synthetic workload.

Measures steady-state training throughput (rows/sec) with Ray Data's local
buffer shuffle at different buffer sizes.

Setup:
    - Dataset: ray.data.range_tensor(81_920_000, shape=(512,))  (~4KB/row)
    - Model: Linear(512, 10) -- trivial, measures data pipeline not model
    - Batch size: 4096 per worker
    - 4 GPU workers, 200 steps per run, first 100 warmup

Usage:
    python benchmark_local_shuffle.py                # full benchmark
    python benchmark_local_shuffle.py --reps 1       # quick single pass
    python benchmark_local_shuffle.py --num_workers 2
"""

import json
import logging
import os
import statistics
import time
import uuid

import ray
import ray.data
import ray.train
import ray.train.torch
import torch
from ray.train.torch import TorchTrainer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# -- Constants --
NUM_ROWS = 81_920_000  # rows in the synthetic range_tensor dataset
TENSOR_DIM = 512  # elements per row tensor; also the model's input width
ROW_BYTES = TENSOR_DIM * 8  # 4096 bytes per row (int64)
BATCH_SIZE = 4096  # per-worker batch size passed to iter_torch_batches
MAX_STEPS = 200  # passed to train_fn as "max_steps"
WARMUP_STEPS = 100  # passed to train_fn as "warmup_steps"
NUM_WORKERS_DEFAULT = 4  # not referenced in the visible part of the script
METRICS_DIR = "/mnt/cluster_storage"  # where rank 0's metrics JSON is written
WAIT_BETWEEN_RUNS = 30  # not referenced here; presumably seconds -- TODO confirm


def gb_to_rows(gb: float) -> int:
    """Convert a buffer size in decimal gigabytes into a whole number of rows.

    One GB is taken as 1e9 bytes and each row as ``ROW_BYTES`` bytes; any
    fractional row is dropped by the ``int`` conversion.
    """
    total_bytes = gb * 1e9
    rows = total_bytes / ROW_BYTES
    return int(rows)


# Benchmark matrix: an unshuffled baseline (buffer_rows == 0) followed by
# 1, 2 and 3 GB buffers expressed as row counts via gb_to_rows.
BUFFER_CONFIGS = [
    {"label": label, "buffer_rows": gb_to_rows(size_gb) if size_gb else 0}
    for label, size_gb in (
        ("no_shuffle", 0),
        ("1GB_buffer", 1),
        ("2GB_buffer", 2),
        ("3GB_buffer", 3),
    )
]


def train_fn(config):
    """Per-worker training loop that measures data-pipeline throughput.

    Trains a ``Linear(TENSOR_DIM, 10)`` model on batches from the "train"
    dataset shard, reports overall and steady-state (post-warmup) throughput
    through ``ray.train.report``, and has rank 0 write the same metrics as
    JSON to ``config["metrics_path"]``.

    Args:
        config: Dict with the keys
            ``"warmup_steps"`` -- leading steps excluded from the steady-state
            measurement,
            ``"max_steps"`` -- stop after this many steps (<= 0: no limit),
            ``"metrics_path"`` -- file rank 0 writes the metrics JSON to,
            ``"buffer_size"`` -- local shuffle buffer size in rows (<= 0: the
            ``local_shuffle_buffer_size`` argument is not passed at all).
    """
    warmup_steps = config["warmup_steps"]
    max_steps = config["max_steps"]
    metrics_path = config["metrics_path"]
    buffer_size = config["buffer_size"]

    device = ray.train.torch.get_device()
    model = ray.train.torch.prepare_model(torch.nn.Linear(TENSOR_DIM, 10))
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    ds_iter = ray.train.get_dataset_shard("train")
    iter_kwargs = dict(
        batch_size=BATCH_SIZE,
        prefetch_batches=4,
        drop_last=True,
    )
    if buffer_size > 0:
        iter_kwargs["local_shuffle_buffer_size"] = buffer_size
    dataloader = ds_iter.iter_torch_batches(**iter_kwargs)

    # Look the training context up once instead of on every step.
    train_context = ray.train.get_context()
    is_rank_zero = train_context.get_world_rank() == 0
    # Rows are counted as BATCH_SIZE * world_size per local step, i.e. this
    # worker's step count scaled to the whole worker group.
    global_batch_size = BATCH_SIZE * train_context.get_world_size()

    total_rows = 0
    steady_rows = 0
    step = 0

    t0 = time.perf_counter()
    # The steady-state clock starts when the last warmup step finishes, so the
    # time of every step counted in steady_rows is part of steady_elapsed.
    # (Starting it after the first counted step would credit that step's rows
    # without its time and overstate throughput.)
    steady_start = t0 if warmup_steps <= 0 else None
    for batch in dataloader:
        data = batch["data"].float().to(device)
        # Synthetic labels derived from the data; the model is only a sink.
        labels = data[:, 0].long() % 10

        optimizer.zero_grad()
        loss = loss_fn(model(data), labels)
        loss.backward()
        optimizer.step()

        total_rows += global_batch_size
        step += 1

        if step == warmup_steps:
            steady_start = time.perf_counter()
        elif step > warmup_steps:
            steady_rows += global_batch_size

        if step % 100 == 0 and is_rank_zero:
            elapsed_so_far = time.perf_counter() - t0
            logger.info(
                f"step={step}  throughput={total_rows / elapsed_so_far:,.0f} rows/s  "
                f"elapsed={elapsed_so_far:.1f}s"
            )

        if max_steps > 0 and step >= max_steps:
            break

    # Sample the clock once so both durations share the same end point.
    end = time.perf_counter()
    elapsed = end - t0
    # "is not None" rather than truthiness: a timestamp of 0.0 is still valid.
    steady_elapsed = (end - steady_start) if steady_start is not None else 0

    metrics = {
        "throughput": total_rows / elapsed if elapsed > 0 else 0,
        "steady_throughput": steady_rows / steady_elapsed if steady_elapsed > 0 else 0,
        "elapsed": elapsed,
        "steps": step,
    }
    ray.train.report(metrics)

    if is_rank_zero:
        with open(metrics_path, "w") as f:
            json.dump(metrics, f)


def run_once(buffer_rows, num_workers):
    """Run one benchmark pass and return the metrics written by rank 0.

    Args:
        buffer_rows: Local shuffle buffer size in rows, forwarded to
            ``train_fn`` as ``"buffer_size"``.
        num_workers: Number of GPU training workers.

    Returns:
        The metrics dict read back from the JSON file ``train_fn`` writes, or
        an empty dict if the file could not be read after 10 attempts.
    """
    ds = ray.data.range_tensor(NUM_ROWS, shape=(TENSOR_DIM,))
    metrics_path = f"{METRICS_DIR}/bench_{uuid.uuid4().hex[:8]}.json"

    trainer = TorchTrainer(
        train_loop_per_worker=train_fn,
        train_loop_config={
            "warmup_steps": WARMUP_STEPS,
            "max_steps": MAX_STEPS,
            "metrics_path": metrics_path,
            "buffer_size": buffer_rows,
        },
        scaling_config=ray.train.ScalingConfig(
            num_workers=num_workers,
            use_gpu=True,
        ),
        datasets={"train": ds},
    )
    trainer.fit()

    # The file is written by a training worker, so poll for it: up to 10
    # attempts, one second apart.
    last_error = None
    for _ in range(10):
        try:
            with open(metrics_path) as f:
                metrics = json.load(f)
        except (OSError, ValueError) as exc:
            # Missing, unreadable or partially written file: wait and retry.
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            last_error = exc
            time.sleep(1)
            continue

        # Cleanup is best-effort; a leftover file must not discard the result.
        try:
            os.remove(metrics_path)
        except OSError as exc:
            logger.warning("Could not remove metrics file %s: %r", metrics_path, exc)
        return metrics

    logger.warning(
        "No readable metrics file at %s after 10 attempts (last error: %r)",
        metrics_path,
        last_error,
    )
    return {}

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new "index" shuffle method to the ShufflingBatcher as an alternative to the default "full" method, aiming to reduce memory overhead by shuffling indices and gathering batches via take. It also adds a shuffle_buffer_compaction_threshold parameter to tune the frequency of compaction and re-shuffling. Review feedback suggests optimizing the random number generator usage by avoiding repeated RandomState initializations and cleaning up redundant BlockAccessor calls for better efficiency.

Comment thread python/ray/data/_internal/batcher.py Outdated
Comment thread python/ray/data/_internal/batcher.py Outdated
@xinyuangui2 xinyuangui2 changed the title optimize local shuffle Apr 13, 2026
Comment thread python/ray/data/_internal/batcher.py Outdated
Comment thread python/ray/data/_internal/batcher.py Outdated
Comment thread python/ray/data/tests/test_batcher.py
Comment thread python/ray/data/iterator.py
Comment thread python/ray/data/tests/test_batcher.py Outdated
@xinyuangui2 xinyuangui2 marked this pull request as ready for review April 13, 2026 18:24
@xinyuangui2 xinyuangui2 requested a review from a team as a code owner April 13, 2026 18:24
Comment thread python/ray/data/_internal/batcher.py
@ray-gardener ray-gardener Bot added the data Ray Data-related issues label Apr 13, 2026

@cursor cursor Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes and found 1 potential issue.

There are 2 total unresolved issues (including 1 from previous review).

Fix All in Cursor

Reviewed by Cursor Bugbot for commit 49fa9ed61e392ab6901a9115d2323f1fe521c9c2. Configure here.

Comment thread python/ray/data/context.py Outdated
@xinyuangui2 xinyuangui2 added the go add ONLY when ready to merge, run all tests label Apr 14, 2026
@xinyuangui2 xinyuangui2 requested a review from richardliaw April 14, 2026 18:17
Comment thread python/ray/data/tests/test_batcher.py Outdated
Comment thread python/ray/data/tests/test_batcher.py Outdated

@richardliaw richardliaw left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit comments, pls address

@xinyuangui2 xinyuangui2 force-pushed the xgui/optimize-local-shuffle branch from 19ee3f3 to 796ed5e Compare April 15, 2026 04:32
Replace the full-buffer shuffle in ShufflingBatcher with an incremental
index method that generates a permutation array once per compaction and
gathers small slices via `take` per batch.

- Memory-efficient: data buffer kept as-is, only a lightweight int64
  index array allocated on top.
- Smooth per-batch latency: each `take` is a small slice, easy to hide
  behind prefetch threads.
- Use a single np.random.default_rng instance instead of creating
  RandomState per compaction.
- Add SHUFFLE_BUFFER_COMPACTION_THRESHOLD (0.5) constant to control
  when re-shuffling is triggered.
- Fix has_batch() to require num_rows >= batch_size in streaming mode,
  preventing partial batches mid-stream.
- Add tests: incremental index validated against full-shuffle reference,
  partial batch mid-stream regression test.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@xinyuangui2 xinyuangui2 force-pushed the xgui/optimize-local-shuffle branch from 796ed5e to 9b6dfce Compare April 15, 2026 04:40
@richardliaw richardliaw merged commit bfa2f1c into ray-project:master Apr 15, 2026
6 checks passed
HLDKNotFound pushed a commit to chichic21039/ray that referenced this pull request Apr 22, 2026
…gurable compaction threshold (ray-project#62539)

## Summary
- Replace the full-buffer shuffle in `ShufflingBatcher` with an
incremental index method that generates a permutation array once per
compaction and gathers small slices via `take` per batch. This uses less
memory (no second copy of the full buffer) and has smoother per-batch
latency that is easier to hide behind prefetch threads.
- Add `local_shuffle_compaction_threshold` to `DataContext` (default
0.5, env var `RAY_DATA_LOCAL_SHUFFLE_COMPACTION_THRESHOLD`) to control
how aggressively compaction/re-shuffling is triggered. Lower values
reduce shuffle frequency at the cost of randomness.

## Benchmark

### Setup

- **Dataset**: `ray.data.range_tensor(81_920_000, shape=(512,))`
(~4KB/row, ~320 GB total)
- **Model**: `Linear(512, 10)` -- trivial, measures data pipeline not
model
- **Batch size**: 4096 per worker
- **Workers**: 4 GPU workers
- **Steps**: 200 total, first 100 warmup (excluded from steady
throughput)
- **Metric**: Steady-state throughput (rows/sec) after warmup
- **Environment**: `RAY_DATA_MIN_NUM_CHUNKS_TO_TRIGGER_COMBINE_CHUNKS=1`

### Results

| Buffer Size | With PR (rows/s) | Pre-PR (rows/s) | Speedup |
|---|---|---|---|
| **no shuffle** (baseline) | 1,759,282 | 1,762,696 | 1.0x |
| **1 GB** (244,140 rows) | 225,181 | 67,659 | **3.3x** |
| **2 GB** (488,281 rows) | 220,644 | 66,801 | **3.3x** |
| **3 GB** (732,421 rows) | 153,256 | 79,889 | **1.9x** |


## Key findings

1. **The PR is 1.9-3.3x faster** across all tested buffer sizes.
2. Both methods show significant overhead vs. the no-shuffle baseline (in
   the best case, throughput is ~13% of the baseline). This overhead comes
   from the compaction cycle itself -- building, concatenating, and
   reshuffling the buffer.

## Script

```python
"""Benchmark: local buffer shuffle throughput on synthetic workload.

Measures steady-state training throughput (rows/sec) with Ray Data's local
buffer shuffle at different buffer sizes.

Setup:
    - Dataset: ray.data.range_tensor(81_920_000, shape=(512,))  (~4KB/row)
    - Model: Linear(512, 10) -- trivial, measures data pipeline not model
    - Batch size: 4096 per worker
    - 4 GPU workers, 200 steps per run, first 100 warmup

Usage:
    python benchmark_local_shuffle.py                # full benchmark
    python benchmark_local_shuffle.py --reps 1       # quick single pass
    python benchmark_local_shuffle.py --num_workers 2
"""

import json
import logging
import os
import statistics
import time
import uuid

import ray
import ray.data
import ray.train
import ray.train.torch
import torch
from ray.train.torch import TorchTrainer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# -- Constants --
NUM_ROWS = 81_920_000  # rows in the synthetic range_tensor dataset
TENSOR_DIM = 512  # elements per row tensor; also the model's input width
ROW_BYTES = TENSOR_DIM * 8  # 4096 bytes per row (int64)
BATCH_SIZE = 4096  # per-worker batch size passed to iter_torch_batches
MAX_STEPS = 200  # passed to train_fn as "max_steps"
WARMUP_STEPS = 100  # passed to train_fn as "warmup_steps"
NUM_WORKERS_DEFAULT = 4  # not referenced in the visible part of the script
METRICS_DIR = "/mnt/cluster_storage"  # where rank 0's metrics JSON is written
WAIT_BETWEEN_RUNS = 30  # not referenced here; presumably seconds -- TODO confirm


def gb_to_rows(gb: float) -> int:
    """Convert a buffer size in decimal gigabytes into a whole number of rows.

    One GB is taken as 1e9 bytes and each row as ``ROW_BYTES`` bytes; any
    fractional row is dropped by the ``int`` conversion.
    """
    total_bytes = gb * 1e9
    rows = total_bytes / ROW_BYTES
    return int(rows)


# Benchmark matrix: an unshuffled baseline (buffer_rows == 0) followed by
# 1, 2 and 3 GB buffers expressed as row counts via gb_to_rows.
BUFFER_CONFIGS = [
    {"label": label, "buffer_rows": gb_to_rows(size_gb) if size_gb else 0}
    for label, size_gb in (
        ("no_shuffle", 0),
        ("1GB_buffer", 1),
        ("2GB_buffer", 2),
        ("3GB_buffer", 3),
    )
]


def train_fn(config):
    """Per-worker training loop that measures data-pipeline throughput.

    Trains a ``Linear(TENSOR_DIM, 10)`` model on batches from the "train"
    dataset shard, reports overall and steady-state (post-warmup) throughput
    through ``ray.train.report``, and has rank 0 write the same metrics as
    JSON to ``config["metrics_path"]``.

    Args:
        config: Dict with the keys
            ``"warmup_steps"`` -- leading steps excluded from the steady-state
            measurement,
            ``"max_steps"`` -- stop after this many steps (<= 0: no limit),
            ``"metrics_path"`` -- file rank 0 writes the metrics JSON to,
            ``"buffer_size"`` -- local shuffle buffer size in rows (<= 0: the
            ``local_shuffle_buffer_size`` argument is not passed at all).
    """
    warmup_steps = config["warmup_steps"]
    max_steps = config["max_steps"]
    metrics_path = config["metrics_path"]
    buffer_size = config["buffer_size"]

    device = ray.train.torch.get_device()
    model = ray.train.torch.prepare_model(torch.nn.Linear(TENSOR_DIM, 10))
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    ds_iter = ray.train.get_dataset_shard("train")
    iter_kwargs = dict(
        batch_size=BATCH_SIZE,
        prefetch_batches=4,
        drop_last=True,
    )
    if buffer_size > 0:
        iter_kwargs["local_shuffle_buffer_size"] = buffer_size
    dataloader = ds_iter.iter_torch_batches(**iter_kwargs)

    # Look the training context up once instead of on every step.
    train_context = ray.train.get_context()
    is_rank_zero = train_context.get_world_rank() == 0
    # Rows are counted as BATCH_SIZE * world_size per local step, i.e. this
    # worker's step count scaled to the whole worker group.
    global_batch_size = BATCH_SIZE * train_context.get_world_size()

    total_rows = 0
    steady_rows = 0
    step = 0

    t0 = time.perf_counter()
    # The steady-state clock starts when the last warmup step finishes, so the
    # time of every step counted in steady_rows is part of steady_elapsed.
    # (Starting it after the first counted step would credit that step's rows
    # without its time and overstate throughput.)
    steady_start = t0 if warmup_steps <= 0 else None
    for batch in dataloader:
        data = batch["data"].float().to(device)
        # Synthetic labels derived from the data; the model is only a sink.
        labels = data[:, 0].long() % 10

        optimizer.zero_grad()
        loss = loss_fn(model(data), labels)
        loss.backward()
        optimizer.step()

        total_rows += global_batch_size
        step += 1

        if step == warmup_steps:
            steady_start = time.perf_counter()
        elif step > warmup_steps:
            steady_rows += global_batch_size

        if step % 100 == 0 and is_rank_zero:
            elapsed_so_far = time.perf_counter() - t0
            logger.info(
                f"step={step}  throughput={total_rows / elapsed_so_far:,.0f} rows/s  "
                f"elapsed={elapsed_so_far:.1f}s"
            )

        if max_steps > 0 and step >= max_steps:
            break

    # Sample the clock once so both durations share the same end point.
    end = time.perf_counter()
    elapsed = end - t0
    # "is not None" rather than truthiness: a timestamp of 0.0 is still valid.
    steady_elapsed = (end - steady_start) if steady_start is not None else 0

    metrics = {
        "throughput": total_rows / elapsed if elapsed > 0 else 0,
        "steady_throughput": steady_rows / steady_elapsed if steady_elapsed > 0 else 0,
        "elapsed": elapsed,
        "steps": step,
    }
    ray.train.report(metrics)

    if is_rank_zero:
        with open(metrics_path, "w") as f:
            json.dump(metrics, f)


def run_once(buffer_rows, num_workers):
    """Run one benchmark pass and return the metrics written by rank 0.

    Args:
        buffer_rows: Local shuffle buffer size in rows, forwarded to
            ``train_fn`` as ``"buffer_size"``.
        num_workers: Number of GPU training workers.

    Returns:
        The metrics dict read back from the JSON file ``train_fn`` writes, or
        an empty dict if the file could not be read after 10 attempts.
    """
    ds = ray.data.range_tensor(NUM_ROWS, shape=(TENSOR_DIM,))
    metrics_path = f"{METRICS_DIR}/bench_{uuid.uuid4().hex[:8]}.json"

    trainer = TorchTrainer(
        train_loop_per_worker=train_fn,
        train_loop_config={
            "warmup_steps": WARMUP_STEPS,
            "max_steps": MAX_STEPS,
            "metrics_path": metrics_path,
            "buffer_size": buffer_rows,
        },
        scaling_config=ray.train.ScalingConfig(
            num_workers=num_workers,
            use_gpu=True,
        ),
        datasets={"train": ds},
    )
    trainer.fit()

    # The file is written by a training worker, so poll for it: up to 10
    # attempts, one second apart.
    last_error = None
    for _ in range(10):
        try:
            with open(metrics_path) as f:
                metrics = json.load(f)
        except (OSError, ValueError) as exc:
            # Missing, unreadable or partially written file: wait and retry.
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            last_error = exc
            time.sleep(1)
            continue

        # Cleanup is best-effort; a leftover file must not discard the result.
        try:
            os.remove(metrics_path)
        except OSError as exc:
            logger.warning("Could not remove metrics file %s: %r", metrics_path, exc)
        return metrics

    logger.warning(
        "No readable metrics file at %s after 10 attempts (last error: %r)",
        metrics_path,
        last_error,
    )
    return {}
```

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Lucas61000 pushed a commit to Lucas61000/ray that referenced this pull request May 15, 2026
…gurable compaction threshold (ray-project#62539)

## Summary
- Replace the full-buffer shuffle in `ShufflingBatcher` with an
incremental index method that generates a permutation array once per
compaction and gathers small slices via `take` per batch. This uses less
memory (no second copy of the full buffer) and has smoother per-batch
latency that is easier to hide behind prefetch threads.
- Add `local_shuffle_compaction_threshold` to `DataContext` (default
0.5, env var `RAY_DATA_LOCAL_SHUFFLE_COMPACTION_THRESHOLD`) to control
how aggressively compaction/re-shuffling is triggered. Lower values
reduce shuffle frequency at the cost of randomness.

## Benchmark

### Setup

- **Dataset**: `ray.data.range_tensor(81_920_000, shape=(512,))`
(~4KB/row, ~320 GB total)
- **Model**: `Linear(512, 10)` -- trivial, measures data pipeline not
model
- **Batch size**: 4096 per worker
- **Workers**: 4 GPU workers
- **Steps**: 200 total, first 100 warmup (excluded from steady
throughput)
- **Metric**: Steady-state throughput (rows/sec) after warmup
- **Environment**: `RAY_DATA_MIN_NUM_CHUNKS_TO_TRIGGER_COMBINE_CHUNKS=1`

### Results

| Buffer Size | With PR (rows/s) | Pre-PR (rows/s) | Speedup |
|---|---|---|---|
| **no shuffle** (baseline) | 1,759,282 | 1,762,696 | 1.0x |
| **1 GB** (244,140 rows) | 225,181 | 67,659 | **3.3x** |
| **2 GB** (488,281 rows) | 220,644 | 66,801 | **3.3x** |
| **3 GB** (732,421 rows) | 153,256 | 79,889 | **1.9x** |


## Key findings

1. **The PR is 1.9-3.3x faster** across all tested buffer sizes.
2. Both methods show significant overhead vs. the no-shuffle baseline (in
   the best case, throughput is ~13% of the baseline). This overhead comes
   from the compaction cycle itself -- building, concatenating, and
   reshuffling the buffer.

## Script

```python
"""Benchmark: local buffer shuffle throughput on synthetic workload.

Measures steady-state training throughput (rows/sec) with Ray Data's local
buffer shuffle at different buffer sizes.

Setup:
    - Dataset: ray.data.range_tensor(81_920_000, shape=(512,))  (~4KB/row)
    - Model: Linear(512, 10) -- trivial, measures data pipeline not model
    - Batch size: 4096 per worker
    - 4 GPU workers, 200 steps per run, first 100 warmup

Usage:
    python benchmark_local_shuffle.py                # full benchmark
    python benchmark_local_shuffle.py --reps 1       # quick single pass
    python benchmark_local_shuffle.py --num_workers 2
"""

import json
import logging
import os
import statistics
import time
import uuid

import ray
import ray.data
import ray.train
import ray.train.torch
import torch
from ray.train.torch import TorchTrainer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# -- Constants --
NUM_ROWS = 81_920_000  # rows in the synthetic range_tensor dataset
TENSOR_DIM = 512  # elements per row tensor; also the model's input width
ROW_BYTES = TENSOR_DIM * 8  # 4096 bytes per row (int64)
BATCH_SIZE = 4096  # per-worker batch size passed to iter_torch_batches
MAX_STEPS = 200  # passed to train_fn as "max_steps"
WARMUP_STEPS = 100  # passed to train_fn as "warmup_steps"
NUM_WORKERS_DEFAULT = 4  # not referenced in the visible part of the script
METRICS_DIR = "/mnt/cluster_storage"  # where rank 0's metrics JSON is written
WAIT_BETWEEN_RUNS = 30  # not referenced here; presumably seconds -- TODO confirm


def gb_to_rows(gb: float) -> int:
    """Convert a buffer size in decimal gigabytes into a whole number of rows.

    One GB is taken as 1e9 bytes and each row as ``ROW_BYTES`` bytes; any
    fractional row is dropped by the ``int`` conversion.
    """
    total_bytes = gb * 1e9
    rows = total_bytes / ROW_BYTES
    return int(rows)


# Benchmark matrix: an unshuffled baseline (buffer_rows == 0) followed by
# 1, 2 and 3 GB buffers expressed as row counts via gb_to_rows.
BUFFER_CONFIGS = [
    {"label": label, "buffer_rows": gb_to_rows(size_gb) if size_gb else 0}
    for label, size_gb in (
        ("no_shuffle", 0),
        ("1GB_buffer", 1),
        ("2GB_buffer", 2),
        ("3GB_buffer", 3),
    )
]


def train_fn(config):
    """Per-worker training loop that measures data-pipeline throughput.

    Trains a ``Linear(TENSOR_DIM, 10)`` model on batches from the "train"
    dataset shard, reports overall and steady-state (post-warmup) throughput
    through ``ray.train.report``, and has rank 0 write the same metrics as
    JSON to ``config["metrics_path"]``.

    Args:
        config: Dict with the keys
            ``"warmup_steps"`` -- leading steps excluded from the steady-state
            measurement,
            ``"max_steps"`` -- stop after this many steps (<= 0: no limit),
            ``"metrics_path"`` -- file rank 0 writes the metrics JSON to,
            ``"buffer_size"`` -- local shuffle buffer size in rows (<= 0: the
            ``local_shuffle_buffer_size`` argument is not passed at all).
    """
    warmup_steps = config["warmup_steps"]
    max_steps = config["max_steps"]
    metrics_path = config["metrics_path"]
    buffer_size = config["buffer_size"]

    device = ray.train.torch.get_device()
    model = ray.train.torch.prepare_model(torch.nn.Linear(TENSOR_DIM, 10))
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    ds_iter = ray.train.get_dataset_shard("train")
    iter_kwargs = dict(
        batch_size=BATCH_SIZE,
        prefetch_batches=4,
        drop_last=True,
    )
    if buffer_size > 0:
        iter_kwargs["local_shuffle_buffer_size"] = buffer_size
    dataloader = ds_iter.iter_torch_batches(**iter_kwargs)

    # Look the training context up once instead of on every step.
    train_context = ray.train.get_context()
    is_rank_zero = train_context.get_world_rank() == 0
    # Rows are counted as BATCH_SIZE * world_size per local step, i.e. this
    # worker's step count scaled to the whole worker group.
    global_batch_size = BATCH_SIZE * train_context.get_world_size()

    total_rows = 0
    steady_rows = 0
    step = 0

    t0 = time.perf_counter()
    # The steady-state clock starts when the last warmup step finishes, so the
    # time of every step counted in steady_rows is part of steady_elapsed.
    # (Starting it after the first counted step would credit that step's rows
    # without its time and overstate throughput.)
    steady_start = t0 if warmup_steps <= 0 else None
    for batch in dataloader:
        data = batch["data"].float().to(device)
        # Synthetic labels derived from the data; the model is only a sink.
        labels = data[:, 0].long() % 10

        optimizer.zero_grad()
        loss = loss_fn(model(data), labels)
        loss.backward()
        optimizer.step()

        total_rows += global_batch_size
        step += 1

        if step == warmup_steps:
            steady_start = time.perf_counter()
        elif step > warmup_steps:
            steady_rows += global_batch_size

        if step % 100 == 0 and is_rank_zero:
            elapsed_so_far = time.perf_counter() - t0
            logger.info(
                f"step={step}  throughput={total_rows / elapsed_so_far:,.0f} rows/s  "
                f"elapsed={elapsed_so_far:.1f}s"
            )

        if max_steps > 0 and step >= max_steps:
            break

    # Sample the clock once so both durations share the same end point.
    end = time.perf_counter()
    elapsed = end - t0
    # "is not None" rather than truthiness: a timestamp of 0.0 is still valid.
    steady_elapsed = (end - steady_start) if steady_start is not None else 0

    metrics = {
        "throughput": total_rows / elapsed if elapsed > 0 else 0,
        "steady_throughput": steady_rows / steady_elapsed if steady_elapsed > 0 else 0,
        "elapsed": elapsed,
        "steps": step,
    }
    ray.train.report(metrics)

    if is_rank_zero:
        with open(metrics_path, "w") as f:
            json.dump(metrics, f)


def run_once(buffer_rows, num_workers):
    """Run one benchmark pass and return the metrics written by rank 0.

    Args:
        buffer_rows: Local shuffle buffer size in rows, forwarded to
            ``train_fn`` as ``"buffer_size"``.
        num_workers: Number of GPU training workers.

    Returns:
        The metrics dict read back from the JSON file ``train_fn`` writes, or
        an empty dict if the file could not be read after 10 attempts.
    """
    ds = ray.data.range_tensor(NUM_ROWS, shape=(TENSOR_DIM,))
    metrics_path = f"{METRICS_DIR}/bench_{uuid.uuid4().hex[:8]}.json"

    trainer = TorchTrainer(
        train_loop_per_worker=train_fn,
        train_loop_config={
            "warmup_steps": WARMUP_STEPS,
            "max_steps": MAX_STEPS,
            "metrics_path": metrics_path,
            "buffer_size": buffer_rows,
        },
        scaling_config=ray.train.ScalingConfig(
            num_workers=num_workers,
            use_gpu=True,
        ),
        datasets={"train": ds},
    )
    trainer.fit()

    # The file is written by a training worker, so poll for it: up to 10
    # attempts, one second apart.
    last_error = None
    for _ in range(10):
        try:
            with open(metrics_path) as f:
                metrics = json.load(f)
        except (OSError, ValueError) as exc:
            # Missing, unreadable or partially written file: wait and retry.
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            last_error = exc
            time.sleep(1)
            continue

        # Cleanup is best-effort; a leftover file must not discard the result.
        try:
            os.remove(metrics_path)
        except OSError as exc:
            logger.warning("Could not remove metrics file %s: %r", metrics_path, exc)
        return metrics

    logger.warning(
        "No readable metrics file at %s after 10 attempts (last error: %r)",
        metrics_path,
        last_error,
    )
    return {}
```

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

data Ray Data-related issues go add ONLY when ready to merge, run all tests

2 participants