3 files changed: +62 -4 lines.

.github/scripts/install_triton.sh: the install script now accepts a --cpu flag that clones and builds the triton-cpu backend instead of upstream triton.
@@ -1,11 +1,37 @@
 #!/bin/bash
 set -ex
+
+# Parse command line arguments
+USE_CPU_BACKEND=false
+while [[ $# -gt 0 ]]; do
+  case $1 in
+    --cpu)
+      USE_CPU_BACKEND=true
+      shift
+      ;;
+    *)
+      echo "Unknown option: $1"
+      exit 1
+      ;;
+  esac
+done
+
 (
   mkdir -p /tmp/$USER
   pushd /tmp/$USER
   pip uninstall -y triton pytorch-triton || true
   rm -rf triton/ || true
-  git clone https://github.com/triton-lang/triton.git  # install triton latest main
+
+  # Clone the appropriate repository based on backend
+  if [ "$USE_CPU_BACKEND" = true ]; then
+    # Install triton-cpu from triton-cpu repository
+    git clone --recursive https://github.com/triton-lang/triton-cpu.git triton
+  else
+    # Install triton from main repository for GPU backend
+    git clone https://github.com/triton-lang/triton.git triton
+  fi
+
+  # Shared build process for both backends
   (
     pushd triton/
     conda config --set channel_priority strict
@@ -14,10 +40,14 @@ set -ex
     conda install -y -c conda-forge gcc_linux-64=13 gxx_linux-64=13 gcc=13 gxx=13
     pip install -r python/requirements.txt
     # Use TRITON_PARALLEL_LINK_JOBS=2 to avoid OOM on CPU CI machines
-    MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 pip install .  # install to conda site-packages/ folder
+    if [ "$USE_CPU_BACKEND" = true ]; then
+      MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 pip install -e python  # install to conda site-packages/ folder
+    else
+      MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 pip install .  # install to conda site-packages/ folder
+    fi
     popd
   )
-  rm -rf triton/
+  # rm -rf triton/
   popd
 )
 exit 0
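Both branches build inside the active conda environment; the GPU path keeps the plain ./.github/scripts/install_triton.sh invocation, while the CPU path is invoked as ./.github/scripts/install_triton.sh --cpu. The CPU branch uses an editable install (pip install -e python), and the final rm -rf triton/ is commented out because an editable install keeps importing from the cloned source tree. A minimal post-install sanity check (a sketch, not part of this PR) shows which build is active:

# Minimal post-install sanity check (a sketch, not part of this PR).
import triton

print(triton.__version__)  # version string of the freshly built package
print(triton.__file__)     # an editable install resolves back into the cloned triton/ tree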
CI workflow: a new test_cpu_triton job mirrors the existing GPU test job, but installs the CPU backend and runs the suite with TRITON_CPU_BACKEND=1.

@@ -39,3 +39,30 @@
         ./.github/scripts/install_triton.sh
         pip install -r requirements.txt
         python -m unittest discover -s test/ -p "*.py" -v -t .
+
+  test_cpu_triton:
+    name: test-cpu-py${{ matrix.python-version }}-triton-cpu
+    strategy:
+      fail-fast: true
+      matrix:
+        python-version: ["3.12"]
+        include:
+          - name: A10G
+            runs-on: linux.g5.4xlarge.nvidia.gpu
+            torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu126'
+            gpu-arch-type: "cuda"
+            gpu-arch-version: "12.6"
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    with:
+      runner: ${{ matrix.runs-on }}
+      gpu-arch-type: ${{ matrix.gpu-arch-type }}
+      gpu-arch-version: ${{ matrix.gpu-arch-version }}
+      submodules: recursive
+      script: |
+        conda create -n venv python=${{ matrix.python-version }} -y
+        conda activate venv
+        python -m pip install --upgrade pip
+        pip install ${{ matrix.torch-spec }}
+        ./.github/scripts/install_triton.sh --cpu
+        pip install -r requirements.txt
+        TRITON_CPU_BACKEND=1 python -m unittest discover -s test/ -p "*.py" -v -t .
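The CPU job runs the same unittest discovery as the GPU job; only the --cpu install flag and the TRITON_CPU_BACKEND=1 environment variable differ. Tests that genuinely require CUDA can be gated on the same variable; a small illustrative sketch (the skip_on_cpu_backend helper is hypothetical, not from this PR):

# Hypothetical helper gating CUDA-only tests on the env var the new CI job sets.
import os
import unittest

import torch

skip_on_cpu_backend = unittest.skipIf(
    os.environ.get("TRITON_CPU_BACKEND") == "1",
    "test requires the CUDA backend",
)


class ExampleTest(unittest.TestCase):
    @skip_on_cpu_backend
    def test_cuda_only_path(self) -> None:
        # Only runs when the suite is exercising the GPU backend.
        self.assertTrue(torch.cuda.is_available())


if __name__ == "__main__":
    unittest.main()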
Test harness module (the Python file defining the module-level DEVICE constant): device selection now honors the same environment variable.

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import importlib
+import os
 import sys
 from typing import TYPE_CHECKING
 
@@ -15,7 +16,7 @@
 from .runtime.kernel import Kernel
 
 
-DEVICE = torch.device("cuda")
+DEVICE = torch.device("cuda" if os.environ.get("TRITON_CPU_BACKEND") != "1" else "cpu")
 
 
 def import_path(filename: Path) -> types.ModuleType:
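Because DEVICE is evaluated once at import time, TRITON_CPU_BACKEND must be set before this module is imported, which is why the CI job sets it on the command line rather than from within the tests. A standalone sketch of the selection logic (DEVICE here mirrors the constant above; the surrounding module is not shown in the diff):

# Standalone sketch mirroring the DEVICE selection added above.
import os

os.environ["TRITON_CPU_BACKEND"] = "1"  # must be set before the constant is evaluated

import torch

DEVICE = torch.device("cuda" if os.environ.get("TRITON_CPU_BACKEND") != "1" else "cpu")

x = torch.ones(4, device=DEVICE)  # tensors land on cpu when the variable is "1"
print(DEVICE, x.device)           # -> cpu cpu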