Skip to content

Commit 0da79f2

Browse files
committed
Run CI using Triton CPU backend
stack-info: PR: #174, branch: oulgen/stack/8
1 parent b460e5f commit 0da79f2

File tree

3 files changed

+58
-3
lines changed

3 files changed

+58
-3
lines changed

.github/scripts/install_triton.sh

Lines changed: 32 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,37 @@
11
#!/bin/bash
22
set -ex
3+
4+
# Parse command line arguments
5+
USE_CPU_BACKEND=false
6+
while [[ $# -gt 0 ]]; do
7+
case $1 in
8+
--cpu)
9+
USE_CPU_BACKEND=true
10+
shift
11+
;;
12+
*)
13+
echo "Unknown option: $1"
14+
exit 1
15+
;;
16+
esac
17+
done
18+
319
(
420
mkdir -p /tmp/$USER
521
pushd /tmp/$USER
622
pip uninstall -y triton pytorch-triton || true
723
rm -rf triton/ || true
8-
git clone https://github.com/triton-lang/triton.git # install triton latest main
24+
25+
# Clone the appropriate repository based on backend
26+
if [ "$USE_CPU_BACKEND" = true ]; then
27+
# Install triton-cpu from triton-cpu repository
28+
git clone --recursive https://github.com/triton-lang/triton-cpu.git triton
29+
else
30+
# Install triton from main repository for GPU backend
31+
git clone https://github.com/triton-lang/triton.git triton
32+
fi
33+
34+
# Shared build process for both backends
935
(
1036
pushd triton/
1137
conda config --set channel_priority strict
@@ -14,7 +40,11 @@ set -ex
1440
conda install -y -c conda-forge gcc_linux-64=13 gxx_linux-64=13 gcc=13 gxx=13
1541
pip install -r python/requirements.txt
1642
# Use TRITON_PARALLEL_LINK_JOBS=2 to avoid OOM on CPU CI machines
17-
MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 pip install . # install to conda site-packages/ folder
43+
if [ "$USE_CPU_BACKEND" = true ]; then
44+
MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 pip install ./python # install to conda site-packages/ folder
45+
else
46+
MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 pip install . # install to conda site-packages/ folder
47+
fi
1848
popd
1949
)
2050
rm -rf triton/

.github/workflows/test.yml

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -39,3 +39,27 @@ jobs:
3939
./.github/scripts/install_triton.sh
4040
pip install -r requirements.txt
4141
python -m unittest discover -s test/ -p "*.py" -v -t .
42+
43+
test_cpu_triton:
44+
name: test-cpu-py${{ matrix.python-version }}-triton-cpu
45+
strategy:
46+
fail-fast: true
47+
matrix:
48+
python-version: ["3.12"]
49+
include:
50+
- name: CPU
51+
runs-on: linux.2xlarge
52+
torch-spec: 'torch'
53+
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
54+
with:
55+
timeout: 120
56+
runner: ${{ matrix.runs-on }}
57+
submodules: recursive
58+
script: |
59+
conda create -n venv python=${{ matrix.python-version }} -y
60+
conda activate venv
61+
python -m pip install --upgrade pip
62+
pip install ${{ matrix.torch-spec }}
63+
./.github/scripts/install_triton.sh --cpu
64+
pip install -r requirements.txt
65+
TRITON_CPU_BACKEND=1 python -m unittest discover -s test/ -p "*.py" -v -t .

helion/_testing.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import importlib
4+
import os
45
import sys
56
from typing import TYPE_CHECKING
67

@@ -15,7 +16,7 @@
1516
from .runtime.kernel import Kernel
1617

1718

18-
DEVICE = torch.device("cuda")
19+
DEVICE = torch.device("cuda" if os.environ.get("TRITON_CPU_BACKEND") != "1" else "cpu")
1920

2021

2122
def import_path(filename: Path) -> types.ModuleType:

0 commit comments

Comments (0)