Skip to content

Commit 5ca7283

Browse files
committed
Run CI using Triton CPU backend
stack-info: PR: #174, branch: oulgen/stack/8
1 parent b460e5f commit 5ca7283

File tree

3 files changed

+63
-4
lines changed

3 files changed

+63
-4
lines changed

.github/scripts/install_triton.sh

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,37 @@
11
#!/bin/bash
22
set -ex
3+
4+
# Parse command line arguments
5+
USE_CPU_BACKEND=false
6+
while [[ $# -gt 0 ]]; do
7+
case $1 in
8+
--cpu)
9+
USE_CPU_BACKEND=true
10+
shift
11+
;;
12+
*)
13+
echo "Unknown option: $1"
14+
exit 1
15+
;;
16+
esac
17+
done
18+
319
(
420
mkdir -p /tmp/$USER
521
pushd /tmp/$USER
622
pip uninstall -y triton pytorch-triton || true
723
rm -rf triton/ || true
8-
git clone https://github.com/triton-lang/triton.git # install triton latest main
24+
25+
# Clone the appropriate repository based on backend
26+
if [ "$USE_CPU_BACKEND" = true ]; then
27+
# Install triton-cpu from triton-cpu repository
28+
git clone --recursive https://github.com/triton-lang/triton-cpu.git triton
29+
else
30+
# Install triton from main repository for GPU backend
31+
git clone https://github.com/triton-lang/triton.git triton
32+
fi
33+
34+
# Shared build process for both backends
935
(
1036
pushd triton/
1137
conda config --set channel_priority strict
@@ -14,10 +40,14 @@ set -ex
1440
conda install -y -c conda-forge gcc_linux-64=13 gxx_linux-64=13 gcc=13 gxx=13
1541
pip install -r python/requirements.txt
1642
# Use TRITON_PARALLEL_LINK_JOBS=2 to avoid OOM on CPU CI machines
17-
MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 pip install . # install to conda site-packages/ folder
43+
if [ "$USE_CPU_BACKEND" = true ]; then
44+
MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 pip install -e python # install to conda site-packages/ folder
45+
else
46+
MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 pip install . # install to conda site-packages/ folder
47+
fi
1848
popd
1949
)
20-
rm -rf triton/
50+
#rm -rf triton/
2151
popd
2252
)
2353
exit 0

.github/workflows/test.yml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,31 @@ jobs:
3939
./.github/scripts/install_triton.sh
4040
pip install -r requirements.txt
4141
python -m unittest discover -s test/ -p "*.py" -v -t .
42+
43+
test_cpu_triton:
44+
name: test-cpu-py${{ matrix.python-version }}-triton-cpu
45+
strategy:
46+
fail-fast: true
47+
matrix:
48+
python-version: ["3.12"]
49+
include:
50+
- name: A10G
51+
runs-on: linux.g5.4xlarge.nvidia.gpu
52+
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu126'
53+
gpu-arch-type: "cuda"
54+
gpu-arch-version: "12.6"
55+
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
56+
with:
57+
timeout: 120
58+
runner: ${{ matrix.runs-on }}
59+
gpu-arch-type: ${{ matrix.gpu-arch-type }}
60+
gpu-arch-version: ${{ matrix.gpu-arch-version }}
61+
submodules: recursive
62+
script: |
63+
conda create -n venv python=${{ matrix.python-version }} -y
64+
conda activate venv
65+
python -m pip install --upgrade pip
66+
pip install ${{ matrix.torch-spec }}
67+
./.github/scripts/install_triton.sh --cpu
68+
pip install -r requirements.txt
69+
TRITON_CPU_BACKEND=1 python -m unittest discover -s test/ -p "*.py" -v -t .

helion/_testing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import importlib
4+
import os
45
import sys
56
from typing import TYPE_CHECKING
67

@@ -15,7 +16,7 @@
1516
from .runtime.kernel import Kernel
1617

1718

18-
DEVICE = torch.device("cuda")
19+
DEVICE = torch.device("cuda" if os.environ.get("TRITON_CPU_BACKEND") != "1" else "cpu")
1920

2021

2122
def import_path(filename: Path) -> types.ModuleType:

0 commit comments

Comments (0)