Skip to content

[Docker] Update TPU Docker image base deps#63006

Merged
elliot-barn merged 9 commits into
ray-project:masterfrom
ryanaoleary:update-tpu-dependencies
May 7, 2026
Merged

[Docker] Update TPU Docker image base deps#63006
elliot-barn merged 9 commits into
ray-project:masterfrom
ryanaoleary:update-tpu-dependencies

Conversation

@ryanaoleary

Copy link
Copy Markdown
Contributor

Description

This PR resolves #62402 by ensuring a jax[tpu] version >= 0.8.2 is installed. More context for why this change is necessary is provided in the linked issue.

This PR also adds several lightweight packages that are commonly used with jax[tpu] and libtpu. The list of updated dependencies is based on common packages seen in JAX-based frameworks like: Maxtext and tpu_inference.

Related issues

#62402

@ryanaoleary

Copy link
Copy Markdown
Contributor Author

This is what the pip-compile output for this updated package list looks like:

#
# This file is autogenerated by pip-compile with Python 3.12
# by the following command:
#
#    pip-compile --output-file=docker/base-deps/requirements_tpu.txt docker/base-deps/requirements_tpu.in
#
--index-url https://us-python.pkg.dev/artifact-foundry-prod/ah-3p-staging-python/simple/

absl-py==2.4.0
    # via
    #   ml-collections
    #   optax
    #   orbax-checkpoint
aiofiles==25.1.0
    # via orbax-checkpoint
aiohappyeyeballs==2.6.1
    # via aiohttp
aiohttp==3.13.5
    # via gcsfs
aiosignal==1.4.0
    # via aiohttp
attrs==26.1.0
    # via aiohttp
certifi==2026.4.22
    # via requests
cffi==2.0.0
    # via cryptography
charset-normalizer==3.4.7
    # via requests
cheroot==11.1.2
    # via xprof
cloud-tpu-diagnostics==1.1.14
    # via -r docker/base-deps/requirements_tpu.in
cryptography==47.0.0
    # via google-auth
decorator==5.2.1
    # via gcsfs
etils[epath,epy]==1.14.0
    # via
    #   orbax-checkpoint
    #   xprof
flax==0.12.7
    # via -r docker/base-deps/requirements_tpu.in
frozenlist==1.8.0
    # via
    #   aiohttp
    #   aiosignal
fsspec==2026.3.0
    # via
    #   etils
    #   gcsfs
    #   xprof
gcsfs==2026.3.0
    # via xprof
google-api-core[grpc]==2.30.3
    # via
    #   google-cloud-appengine-logging
    #   google-cloud-core
    #   google-cloud-logging
    #   google-cloud-monitoring
    #   google-cloud-storage
    #   google-cloud-storage-control
    #   ml-goodput-measurement
google-auth==2.49.2
    # via
    #   gcsfs
    #   google-api-core
    #   google-auth-oauthlib
    #   google-cloud-appengine-logging
    #   google-cloud-core
    #   google-cloud-logging
    #   google-cloud-monitoring
    #   google-cloud-storage
    #   google-cloud-storage-control
google-auth-oauthlib==1.3.1
    # via gcsfs
google-cloud-appengine-logging==1.9.0
    # via google-cloud-logging
google-cloud-audit-log==0.5.0
    # via google-cloud-logging
google-cloud-core==2.5.1
    # via
    #   google-cloud-logging
    #   google-cloud-storage
google-cloud-logging==3.15.0
    # via ml-goodput-measurement
google-cloud-monitoring==2.30.0
    # via ml-goodput-measurement
google-cloud-storage==3.10.1
    # via gcsfs
google-cloud-storage-control==1.11.0
    # via gcsfs
google-crc32c==1.8.0
    # via
    #   google-cloud-storage
    #   google-resumable-media
google-resumable-media==2.8.2
    # via google-cloud-storage
googleapis-common-protos[grpc]==1.74.0
    # via
    #   google-api-core
    #   google-cloud-audit-log
    #   grpc-google-iam-v1
    #   grpcio-status
grpc-google-iam-v1==0.14.4
    # via
    #   google-cloud-logging
    #   google-cloud-storage-control
grpcio==1.80.0
    # via
    #   google-api-core
    #   google-cloud-appengine-logging
    #   google-cloud-logging
    #   google-cloud-monitoring
    #   google-cloud-storage-control
    #   googleapis-common-protos
    #   grpc-google-iam-v1
    #   grpcio-status
    #   tpu-info
grpcio-status==1.80.0
    # via google-api-core
gviz-api==1.10.0
    # via xprof
humanize==4.15.0
    # via orbax-checkpoint
idna==3.13
    # via
    #   requests
    #   yarl
importlib-metadata==8.7.1
    # via opentelemetry-api
jaraco-functools==4.4.0
    # via cheroot
jax[tpu]==0.10.0
    # via
    #   -r docker/base-deps/requirements_tpu.in
    #   flax
    #   optax
    #   orbax-checkpoint
jaxlib==0.10.0
    # via
    #   jax
    #   optax
libtpu==0.0.40
    # via jax
markdown-it-py==4.0.0
    # via rich
markupsafe==3.0.3
    # via werkzeug
mdurl==0.1.2
    # via markdown-it-py
ml-collections==1.1.0
    # via -r docker/base-deps/requirements_tpu.in
ml-dtypes==0.5.4
    # via
    #   jax
    #   jaxlib
    #   tensorstore
ml-goodput-measurement==0.0.16
    # via -r docker/base-deps/requirements_tpu.in
more-itertools==11.0.2
    # via
    #   cheroot
    #   jaraco-functools
msgpack==1.1.2
    # via
    #   flax
    #   orbax-checkpoint
multidict==6.7.1
    # via
    #   aiohttp
    #   yarl
numpy==2.2.6
    # via
    #   -r docker/base-deps/requirements_tpu.in
    #   flax
    #   jax
    #   jaxlib
    #   ml-dtypes
    #   ml-goodput-measurement
    #   optax
    #   orbax-checkpoint
    #   scipy
    #   tensorboardx
    #   tensorstore
    #   treescope
oauthlib==3.3.1
    # via requests-oauthlib
opentelemetry-api==1.41.1
    # via google-cloud-logging
opt-einsum==3.4.0
    # via jax
optax==0.2.8
    # via
    #   -r docker/base-deps/requirements_tpu.in
    #   flax
orbax-checkpoint==0.11.36
    # via
    #   -r docker/base-deps/requirements_tpu.in
    #   flax
packaging==26.2
    # via
    #   tensorboardx
    #   tpu-info
propcache==0.4.1
    # via
    #   aiohttp
    #   yarl
proto-plus==1.27.2
    # via
    #   google-api-core
    #   google-cloud-appengine-logging
    #   google-cloud-logging
    #   google-cloud-monitoring
    #   google-cloud-storage-control
protobuf==6.33.6
    # via
    #   google-api-core
    #   google-cloud-appengine-logging
    #   google-cloud-audit-log
    #   google-cloud-logging
    #   google-cloud-monitoring
    #   google-cloud-storage-control
    #   googleapis-common-protos
    #   grpc-google-iam-v1
    #   grpcio-status
    #   orbax-checkpoint
    #   proto-plus
    #   tensorboardx
    #   tpu-info
    #   xprof
psutil==7.2.2
    # via orbax-checkpoint
pyasn1==0.6.3
    # via pyasn1-modules
pyasn1-modules==0.4.2
    # via google-auth
pycparser==3.0
    # via cffi
pygments==2.20.0
    # via rich
pyyaml==6.0.3
    # via
    #   flax
    #   ml-collections
    #   orbax-checkpoint
requests==2.33.1
    # via
    #   gcsfs
    #   google-api-core
    #   google-cloud-storage
    #   jax
    #   ml-goodput-measurement
    #   requests-oauthlib
requests-oauthlib==2.0.0
    # via google-auth-oauthlib
rich==15.0.0
    # via
    #   flax
    #   tpu-info
scipy==1.17.1
    # via
    #   jax
    #   jaxlib
    #   ml-goodput-measurement
simplejson==4.1.1
    # via orbax-checkpoint
six==1.17.0
    # via
    #   gviz-api
    #   xprof
tensorboard-plugin-profile==2.22.1
    # via -r docker/base-deps/requirements_tpu.in
tensorboardx==2.6.5
    # via ml-goodput-measurement
tensorstore==0.1.82
    # via
    #   flax
    #   orbax-checkpoint
tpu-info==0.11.0
    # via -r docker/base-deps/requirements_tpu.in
treescope==0.1.10
    # via flax
typing-extensions==4.15.0
    # via
    #   aiosignal
    #   etils
    #   flax
    #   grpcio
    #   opentelemetry-api
    #   orbax-checkpoint
urllib3==2.6.3
    # via
    #   ml-goodput-measurement
    #   requests
uvloop==0.22.1
    # via orbax-checkpoint
werkzeug==3.1.8
    # via xprof
xprof==2.22.1
    # via tensorboard-plugin-profile
yarl==1.23.0
    # via aiohttp
zipp==3.23.1
    # via
    #   etils
    #   importlib-metadata

# The following packages are considered to be unsafe in a requirements file:
# setuptools

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request updates the dependency configuration for TPU environments, including switching to --find-links in the depset configuration and expanding the requirements_tpu.in file with necessary JAX ecosystem packages. Feedback was provided regarding a potential version typo for JAX and the need for a lower bound on the numpy version constraint to ensure compatibility.

Comment thread docker/base-deps/requirements_tpu.in Outdated
Comment thread docker/base-deps/requirements_tpu.in Outdated
@ray-gardener ray-gardener Bot added core Issues that should be addressed in Ray Core community-contribution Contributed by the community labels Apr 29, 2026

@elliot-barn elliot-barn left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you regenerate the depsets. Run bazel run ci/raydepsets:raydepsets -- build ci/raydepsets/configs/rayimg.depsets.yaml

Comment thread docker/base-deps/requirements_tpu.in Outdated
@elliot-barn

Copy link
Copy Markdown
Collaborator

For some more context: We have a custom python dependency management tool called raydepsets. Its a uv wrapper that uses configs to determine requirements, constraints for rays dependencies and expands them for different workloads/use cases

Comment thread docker/base-deps/requirements_tpu.in Outdated
@ryanaoleary

Copy link
Copy Markdown
Contributor Author

For some more context: We have a custom python dependency management tool called raydepsets. Its a uv wrapper that uses configs to determine requirements, constraints for rays dependencies and expands them for different workloads/use cases

For some more context: We have a custom python dependency management tool called raydepsets. Its a uv wrapper that uses configs to determine requirements, constraints for rays dependencies and expands them for different workloads/use cases

Sounds good, ran the depset file and changed it to a lower bound in 9346d34, Now this version gets installed:

numpy==2.4.4
    # via
    #   -r docker/base-deps/requirements_tpu.in
    #   flax
    #   jax
    #   jaxlib
    #   ml-dtypes
    #   ml-goodput-measurement
    #   optax
    #   orbax-checkpoint
    #   scipy
    #   tensorboardx
    #   tensorstore
    #   treescope
@ryanaoleary ryanaoleary requested a review from elliot-barn April 30, 2026 01:59
Comment thread docker/base-deps/requirements_tpu.in Outdated
@ryanaoleary

ryanaoleary commented Apr 30, 2026

Copy link
Copy Markdown
Contributor Author

@elliot-barn should be updated now accordingly, I added a python dependency check in the TPU requirements file since the JAX version for Ironwood TPU requires Python >=3.11

@elliot-barn elliot-barn added the go add ONLY when ready to merge, run all tests label May 2, 2026

@cursor cursor Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes and found 1 potential issue.

Fix All in Cursor

Reviewed by Cursor Bugbot for commit d633d6ad0093bcd9d2d87d3189af9ec6bd7bacdd. Configure here.

Comment thread ci/raydepsets/configs/rayimg.depsets.yaml Outdated
@ryanaoleary

Copy link
Copy Markdown
Contributor Author

had to change the tpu image to build from - --python-platform=x86_64-manylinux_2_31 to resolve this failure in CI:

[2026-05-02T06:48:10Z]      hint: `libtpu` was requested with a pre-release marker (e.g., all of:
[2026-05-02T06:48:10Z]          libtpu>=0.0.32.dev0,<0.0.32
[2026-05-02T06:48:10Z]          libtpu>0.0.32,<0.0.32.1
[2026-05-02T06:48:10Z]          libtpu>0.0.32.1,<0.0.33.dev0
[2026-05-02T06:48:10Z]      ), but pre-releases weren't enabled (try: `--prerelease=allow`)

in d633d6a

@ryanaoleary

Copy link
Copy Markdown
Contributor Author

@elliot-barn wondering if there's any concern with the change here: #63006 (comment), if not I'll go ahead and merge this. Thanks again!!

@elliot-barn

Copy link
Copy Markdown
Collaborator

@elliot-barn wondering if there's any concern with the change here: #63006 (comment), if not I'll go ahead and merge this. Thanks again!!

This looks great can you fix the merge conflicts and i can merge for you

@ryanaoleary ryanaoleary force-pushed the update-tpu-dependencies branch from fa3101f to 41c0e53 Compare May 7, 2026 03:42
Signed-off-by: ryanaoleary <ryanaoleary@google.com>
Signed-off-by: ryanaoleary <ryanaoleary@google.com>
… python version

Signed-off-by: ryanaoleary <ryanaoleary@google.com>
Signed-off-by: ryanaoleary <ryanaoleary@google.com>
Signed-off-by: ryanaoleary <ryanaoleary@google.com>
Signed-off-by: ryanaoleary <ryanaoleary@google.com>
Signed-off-by: ryanaoleary <ryanaoleary@google.com>
Signed-off-by: ryanaoleary <ryanaoleary@google.com>
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
@ryanaoleary ryanaoleary force-pushed the update-tpu-dependencies branch from 41c0e53 to 6b584ea Compare May 7, 2026 03:50
@ryanaoleary

Copy link
Copy Markdown
Contributor Author

@elliot-barn wondering if there's any concern with the change here: #63006 (comment), if not I'll go ahead and merge this. Thanks again!!

This looks great can you fix the merge conflicts and i can merge for you

Sounds good thanks!! rebased on master and re-ran the build depsets script

@elliot-barn elliot-barn merged commit a243e07 into ray-project:master May 7, 2026
6 checks passed
Lucas61000 pushed a commit to Lucas61000/ray that referenced this pull request May 15, 2026
## Description
This PR resolves ray-project#62402 by
ensuring a `jax[tpu]` version >= `0.8.2` is installed. More context for
why this change is necessary is provided in the linked issue.

This PR also adds several lightweight packages that are commonly used
with `jax[tpu]` and `libtpu`. The list of updated dependencies is based
on common packages seen in JAX-based frameworks like:
[Maxtext](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/dependencies/requirements/requirements.txt)
and
[tpu_inference](https://github.com/vllm-project/tpu-inference/blob/main/requirements.txt).

## Related issues
ray-project#62402

---------

Signed-off-by: ryanaoleary <ryanaoleary@google.com>
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution Contributed by the community core Issues that should be addressed in Ray Core go add ONLY when ready to merge, run all tests

2 participants