-
Notifications
You must be signed in to change notification settings - Fork 3.7k
131 lines (124 loc) · 5.19 KB
/
Copy pathci-build.yaml
File metadata and controls
131 lines (124 loc) · 5.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
name: Pytest CPU
on:
  # Run for pull requests that target main and for pushes to main;
  # no other branch triggers this workflow.
  pull_request:
    branches: [main]
  push:
    branches: [main]
# Grant the workflow's GITHUB_TOKEN no permissions at all.
permissions: {}
concurrency:
  # One group per workflow and PR head branch (or ref, for pushes), so a newer
  # run for the same PR supersedes an older one still in progress.
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  # Don't cancel in-progress jobs for the main branch.
  # `github.ref` is fully qualified (e.g. `refs/heads/main`,
  # `refs/pull/<n>/merge`), so it must be compared against `refs/heads/main`.
  # Comparing against the bare name 'main' is always unequal and would cancel
  # in-progress runs on main as well.
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
env:
  # Both uv and pip resolve packages from the same PyPI mirror.
  UV_DEFAULT_INDEX: 'https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple'
  PIP_INDEX_URL: 'https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple'
  # Ensure `uv` uses the Python installations it manages itself.
  UV_MANAGED_PYTHON: 'true'
  # Have setup-uv-python report cache hits/misses for Python and packages,
  # plus timing, in the job summaries. Package install time is not covered,
  # so workflow steps echo it into the summary themselves.
  CI_UV_DEBUG: 'true'
defaults:
  run:
    shell: bash
jobs:
  tests:
    # Don't execute in fork due to runner type.
    if: github.repository == 'jax-ml/jax'
    name: "py${{ matrix.python-version }}, x64=${{ matrix.enable-x64 }}, ubuntu-22.04"
    runs-on: linux-x86-n4-32
    container:
      image: index.docker.io/library/ubuntu@sha256:4e0171b9275e12d375863f2b3ae9ce00a4c53ddda176bd55868df97ac6f21a6e # ratchet:ubuntu:22.04
    timeout-minutes: 60
    strategy:
      matrix:
        # Test the oldest and newest supported Python versions here.
        include:
          - name-prefix: "with 3.12"
            python-version: "3.12"
            enable-x64: 1
            prng-upgrade: 1
            num_generated_cases: 1
          - name-prefix: "with 3.14"
            python-version: "3.14"
            enable-x64: 0
            prng-upgrade: 0
            num_generated_cases: 1
    steps:
      - uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
        with:
          persist-credentials: false
      - name: Set up Python ${{ matrix.python-version }}
        id: setup_python
        uses: google-ml-infra/actions/setup-uv-python@a1817800cb84c752378772c2a02781cf8309a399
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        env:
          PYTHON_BIN: ${{ steps.setup_python.outputs.python-bin }}
        run: |
          PKG_INSTALL_START_TIME=$(date +%s)
          uv pip install --python "$PYTHON_BIN" '.[minimum-jaxlib]' -r 'build/test-requirements.txt'
          # Package install time is not reported by setup-uv-python, so record it here.
          echo "### build dependency install (${{ matrix.python-version }})" >> "$GITHUB_STEP_SUMMARY"
          echo "- duration: $(( $(date +%s) - PKG_INSTALL_START_TIME ))s" >> "$GITHUB_STEP_SUMMARY"
      - name: Run tests
        env:
          PYTHON_BIN: ${{ steps.setup_python.outputs.python-bin }}
          JAX_NUM_GENERATED_CASES: ${{ matrix.num_generated_cases }}
          JAX_ENABLE_X64: ${{ matrix.enable-x64 }}
          # A single matrix flag drives both the custom-PRNG and the
          # partitionable-threefry settings.
          JAX_ENABLE_CUSTOM_PRNG: ${{ matrix.prng-upgrade }}
          JAX_THREEFRY_PARTITIONABLE: ${{ matrix.prng-upgrade }}
          JAX_ENABLE_CHECKS: true
          JAX_SKIP_SLOW_TESTS: true
          PY_COLORS: 1
        run: |
          # Log the JAX configuration this run uses, for easier triage.
          echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
          echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
          echo "JAX_ENABLE_CUSTOM_PRNG=$JAX_ENABLE_CUSTOM_PRNG"
          echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE"
          echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
          echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
          "$PYTHON_BIN" -m pytest -n auto --tb=short --maxfail=20 tests examples
  ffi:
    # Only run under the jax-ml organization (skips forks owned by other
    # accounts); this job needs a GPU runner.
    if: github.repository_owner == 'jax-ml'
    name: FFI example
    runs-on: linux-x86-g2-16-l4-1gpu
    container:
      image: index.docker.io/tensorflow/build:latest-python3.12@sha256:48e99608fe9434ada5b14e19fdfd8e64f4cfc83aacd328b9c2101b210e984295 # ratchet:index.docker.io/tensorflow/build:latest-python3.12
    timeout-minutes: 30
    steps:
      - uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
        with:
          persist-credentials: false
      - name: Set up Python 3.12
        id: setup_python
        uses: google-ml-infra/actions/setup-uv-python@a1817800cb84c752378772c2a02781cf8309a399
        with:
          python-version: '3.12'
      - name: Install JAX and FFI example
        env:
          PYTHON_BIN: ${{ steps.setup_python.outputs.python-bin }}
          # We test building using GCC instead of clang. All other JAX builds
          # use clang, but it is useful to make sure that FFI users can compile
          # using a different toolchain. The compiler is set explicitly rather
          # than relying on whatever default the build container provides.
          CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
        run: |
          PKG_INSTALL_START_TIME=$(date +%s)
          uv pip install --python "$PYTHON_BIN" '.[cuda12]'
          uv pip install --python "$PYTHON_BIN" './examples/ffi[test]'
          # Package install time is not reported by setup-uv-python, so record it here.
          echo "### ffi dependency install" >> "$GITHUB_STEP_SUMMARY"
          echo "- duration: $(( $(date +%s) - PKG_INSTALL_START_TIME ))s" >> "$GITHUB_STEP_SUMMARY"
      - name: Run CPU tests
        run: |
          "$PYTHON_BIN" -m pytest examples/ffi/tests
        env:
          PYTHON_BIN: ${{ steps.setup_python.outputs.python-bin }}
          JAX_PLATFORM_NAME: cpu
      # NOTE(review): unlike the CPU step, no platform is pinned here, so JAX
      # uses its default backend selection, presumably the runner's GPU.
      # Confirm the tests really exercise CUDA.
      - name: Run GPU tests
        run: |
          "$PYTHON_BIN" -m pytest examples/ffi/tests
        env:
          PYTHON_BIN: ${{ steps.setup_python.outputs.python-bin }}