forked from scikit-learn/scikit-learn
-
Notifications
You must be signed in to change notification settings - Fork 6
76 lines (68 loc) · 2.5 KB
/
cuda-ci.yml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
name: CUDA GPU

# Least privilege: restrict the default GITHUB_TOKEN to read-only repository
# access. No step visible in this workflow needs to write to the repository;
# the token is only required to check out the code.
permissions:
  contents: read

# Only run this workflow when a Pull Request is labeled with the
# 'CUDA CI' label. The trigger fires for every `labeled` event; each job
# additionally checks that the pull request carries the 'CUDA CI' label.
on:
  pull_request:
    types:
      - labeled

jobs:
  # Build one CPython 3.12 manylinux x86_64 wheel from the pull request and
  # publish it as the `cibw-wheels` artifact consumed by the `tests` job.
  build_wheel:
    if: contains(github.event.pull_request.labels.*.name, 'CUDA CI')
    runs-on: "ubuntu-latest"
    name: Build wheel for Pull Request
    steps:
      - uses: actions/checkout@v4

      - name: Build wheels
        uses: pypa/cibuildwheel@v2.21.3
        env:
          CIBW_BUILD: cp312-manylinux_x86_64
          CIBW_MANYLINUX_X86_64_IMAGE: manylinux2014
          CIBW_BUILD_VERBOSITY: 1
          CIBW_ARCHS: x86_64

      - uses: actions/upload-artifact@v4
        with:
          name: cibw-wheels
          path: ./wheelhouse/*.whl

  # Install the wheel produced by `build_wheel` on the `cuda-gpu-runner-group`
  # runners and run the tests selected by `-k 'array_api'`.
  tests:
    if: contains(github.event.pull_request.labels.*.name, 'CUDA CI')
    needs: [build_wheel]
    runs-on:
      group: cuda-gpu-runner-group
    # Set this high enough so that the tests can run comfortably. We set a
    # timeout to make abusing this workflow less attractive.
    timeout-minutes: 20
    name: Run Array API unit tests
    steps:
      # The wheel is expected under ~/dist/cibw-wheels/ (the path used by the
      # install step below).
      - uses: actions/download-artifact@v4
        with:
          pattern: cibw-wheels
          path: ~/dist

      - uses: actions/setup-python@v5
        with:
          # XXX: The 3.12.4 release of Python on GitHub Actions is corrupted:
          # https://github.com/actions/setup-python/issues/886
          python-version: '3.12.3'

      - name: Checkout main repository
        uses: actions/checkout@v4

      # The cache key changes whenever the environment creation script or the
      # conda lock file changes, which forces the environment to be rebuilt.
      - name: Cache conda environment
        id: cache-conda
        uses: actions/cache@v4
        with:
          path: ~/conda
          key: ${{ runner.os }}-build-${{ hashFiles('build_tools/github/create_gpu_environment.sh') }}-${{ hashFiles('build_tools/github/pylatest_conda_forge_cuda_array-api_linux-64_conda.lock') }}

      # Only create the environment when it was not restored from the cache.
      - name: Install miniforge
        if: ${{ steps.cache-conda.outputs.cache-hit != 'true' }}
        run: bash build_tools/github/create_gpu_environment.sh

      - name: Install scikit-learn
        run: |
          source "${HOME}/conda/etc/profile.d/conda.sh"
          conda activate sklearn
          # Use a glob rather than parsing the output of `ls`: it selects only
          # wheel files and does not rely on unquoted word splitting.
          pip install ~/dist/cibw-wheels/*.whl

      - name: Run array API tests
        run: |
          source "${HOME}/conda/etc/profile.d/conda.sh"
          conda activate sklearn
          python -c "import sklearn; sklearn.show_versions()"
          SCIPY_ARRAY_API=1 pytest --pyargs sklearn -k 'array_api'
        # Run in /home/runner to not load sklearn from the checkout repo
        working-directory: /home/runner