File tree Expand file tree Collapse file tree 3 files changed +58
-3
lines changed Expand file tree Collapse file tree 3 files changed +58
-3
lines changed Original file line number Diff line number Diff line change 1
1
#!/bin/bash
# Install Triton from source into the active conda environment.
#
# Usage: install_triton.sh [--cpu]
#   --cpu   build the triton-cpu backend (triton-lang/triton-cpu) instead of
#           the default GPU backend (triton-lang/triton)
set -ex

# Parse command line arguments
USE_CPU_BACKEND=false
while [[ $# -gt 0 ]]; do
  case $1 in
    --cpu)
      USE_CPU_BACKEND=true
      shift
      ;;
    *)
      echo "Unknown option: $1"
      exit 1
      ;;
  esac
done

(
  mkdir -p "/tmp/$USER"
  pushd "/tmp/$USER"
  pip uninstall -y triton pytorch-triton || true
  rm -rf triton/ || true

  # Clone the appropriate repository based on backend.
  # Both clones land in ./triton so the shared build steps below work unchanged.
  if [ "$USE_CPU_BACKEND" = true ]; then
    # Install triton-cpu from the triton-cpu repository (needs submodules)
    git clone --recursive https://github.com/triton-lang/triton-cpu.git triton
  else
    # Install triton from the main repository for the GPU backend
    git clone https://github.com/triton-lang/triton.git triton
  fi

  # Shared build process for both backends
  (
    pushd triton/
    conda config --set channel_priority strict
    # NOTE(review): two lines are elided here by the diff view (hunk header
    # @@ -14,7 +40,11 @@ hides old lines 12-13 / new lines 38-39) -- restore
    # them from the original file; they are likely build prerequisites.
    conda install -y -c conda-forge gcc_linux-64=13 gxx_linux-64=13 gcc=13 gxx=13
    pip install -r python/requirements.txt
    # Use TRITON_PARALLEL_LINK_JOBS=2 to avoid OOM on CPU CI machines
    if [ "$USE_CPU_BACKEND" = true ]; then
      # triton-cpu still builds from the python/ subdirectory layout
      MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 pip install ./python # install to conda site-packages/ folder
    else
      MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 pip install . # install to conda site-packages/ folder
    fi
    popd
  )
  rm -rf triton/
)
Original file line number Diff line number Diff line change 39
39
./.github/scripts/install_triton.sh
40
40
pip install -r requirements.txt
41
41
python -m unittest discover -s test/ -p "*.py" -v -t .
42
+
43
# CI job: build triton-cpu via install_triton.sh --cpu and run the unit test
# suite on a CPU-only runner, with TRITON_CPU_BACKEND=1 steering the tests to
# torch.device("cpu").
test_cpu_triton:
  name: test-cpu-py${{ matrix.python-version }}-triton-cpu
  strategy:
    fail-fast: true
    matrix:
      python-version: ["3.12"]
      include:
        - name: CPU
          runs-on: linux.2xlarge
          torch-spec: 'torch'
  uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
  with:
    timeout: 120
    runner: ${{ matrix.runs-on }}
    submodules: recursive
    script: |
      conda create -n venv python=${{ matrix.python-version }} -y
      conda activate venv
      python -m pip install --upgrade pip
      pip install ${{ matrix.torch-spec }}
      ./.github/scripts/install_triton.sh --cpu
      pip install -r requirements.txt
      TRITON_CPU_BACKEND=1 python -m unittest discover -s test/ -p "*.py" -v -t .
Original file line number Diff line number Diff line change 1
1
from __future__ import annotations
2
2
3
3
import importlib
4
+ import os
4
5
import sys
5
6
from typing import TYPE_CHECKING
6
7
15
16
from .runtime .kernel import Kernel
16
17
17
18
18
# Default device for kernels/tests: CUDA unless the CPU CI job opts in by
# exporting TRITON_CPU_BACKEND=1 (any other value, or unset, keeps CUDA).
DEVICE = torch.device("cuda" if os.environ.get("TRITON_CPU_BACKEND") != "1" else "cpu")
19
20
20
21
21
22
def import_path (filename : Path ) -> types .ModuleType :
You can’t perform that action at this time.
0 commit comments