Run CI using Triton CPU backend #174

Draft · wants to merge 1 commit into base: main
36 changes: 33 additions & 3 deletions .github/scripts/install_triton.sh
@@ -1,11 +1,37 @@
 #!/bin/bash
 set -ex
 
+# Parse command line arguments
+USE_CPU_BACKEND=false
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        --cpu)
+            USE_CPU_BACKEND=true
+            shift
+            ;;
+        *)
+            echo "Unknown option: $1"
+            exit 1
+            ;;
+    esac
+done
+
 (
 mkdir -p /tmp/$USER
 pushd /tmp/$USER
 pip uninstall -y triton pytorch-triton || true
 rm -rf triton/ || true
-git clone https://github.com/triton-lang/triton.git  # install triton latest main
+
+# Clone the appropriate repository based on backend
+if [ "$USE_CPU_BACKEND" = true ]; then
+    # Install triton-cpu from the triton-cpu repository
+    git clone --recursive https://github.com/triton-lang/triton-cpu.git triton
+else
+    # Install triton from the main repository for the GPU backend
+    git clone https://github.com/triton-lang/triton.git triton
+fi
+
+# Shared build process for both backends
 (
 pushd triton/
 conda config --set channel_priority strict
@@ -14,10 +40,14 @@ set -ex
 conda install -y -c conda-forge gcc_linux-64=13 gxx_linux-64=13 gcc=13 gxx=13
 pip install -r python/requirements.txt
 # Use TRITON_PARALLEL_LINK_JOBS=2 to avoid OOM on CPU CI machines
-MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 pip install .  # install to conda site-packages/ folder
+if [ "$USE_CPU_BACKEND" = true ]; then
+    MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 pip install -e python  # editable install; site-packages/ links back to this checkout
+else
+    MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 pip install .  # install to conda site-packages/ folder
+fi
 popd
 )
-rm -rf triton/
+#rm -rf triton/  # keep the checkout: the editable CPU install needs it
 popd
 )
 exit 0
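A quick way to sanity-check the `--cpu` install path locally is a stock Triton vector add run entirely on CPU tensors. The sketch below is not part of this PR; it assumes the triton-cpu fork dispatches `@triton.jit` kernels launched with CPU tensors, so no CUDA device is required:

```python
# Smoke test for the triton-cpu install: standard Triton vector add on CPU tensors.
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


n = 1024
x, y = torch.randn(n), torch.randn(n)  # CPU tensors; no .cuda() anywhere
out = torch.empty_like(x)
add_kernel[(triton.cdiv(n, 256),)](x, y, out, n, BLOCK_SIZE=256)
torch.testing.assert_close(out, x + y)
print("triton-cpu vector add OK")
```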
29 changes: 29 additions & 0 deletions .github/workflows/test.yml
@@ -39,3 +39,32 @@ jobs:
         ./.github/scripts/install_triton.sh
         pip install -r requirements.txt
         python -m unittest discover -s test/ -p "*.py" -v -t .
+
+  test_cpu_triton:
+    name: test-cpu-py${{ matrix.python-version }}-triton-cpu
+    strategy:
+      fail-fast: true
+      matrix:
+        python-version: ["3.12"]
+        include:
+          - name: A10G
+            runs-on: linux.g5.4xlarge.nvidia.gpu
+            torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu126'
+            gpu-arch-type: "cuda"
+            gpu-arch-version: "12.6"
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    with:
+      timeout: 120
+      runner: ${{ matrix.runs-on }}
+      gpu-arch-type: ${{ matrix.gpu-arch-type }}
+      gpu-arch-version: ${{ matrix.gpu-arch-version }}
+      submodules: recursive
+      script: |
+        conda create -n venv python=${{ matrix.python-version }} -y
+        conda activate venv
+        python -m pip install --upgrade pip
+        pip install ${{ matrix.torch-spec }}
+        time ./.github/scripts/install_triton.sh --cpu
+        pip install -r requirements.txt
+        pip install pytest pytest-timeout
+        TRITON_CPU_BACKEND=1 pytest --timeout 60 test
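The `--timeout 60` in the final step comes from the pytest-timeout plugin installed just before it; per the plugin's documented behavior, an individual test can claim a larger budget with the marker, which takes precedence over the command-line value. A hypothetical example (the test name and sleep are illustrative only):

```python
# Hypothetical per-test override of the suite-wide `pytest --timeout 60`.
import time

import pytest


@pytest.mark.timeout(300)  # e.g. a kernel whose CPU-backend compile is slow
def test_slow_cpu_compile():
    time.sleep(1)  # stand-in for a long-running compilation
    assert True
```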
18 changes: 17 additions & 1 deletion helion/_testing.py
@@ -1,8 +1,12 @@
 from __future__ import annotations
 
+import functools
 import importlib
+import os
 import sys
 from typing import TYPE_CHECKING
+from typing import Callable
+import unittest
 
 import torch
 
@@ -15,7 +19,19 @@
 from .runtime.kernel import Kernel
 
 
-DEVICE = torch.device("cuda")
+USE_TRITON_CPU_BACKEND: bool = os.environ.get("TRITON_CPU_BACKEND", "0") == "1"
+
+if USE_TRITON_CPU_BACKEND:
+    DEVICE = torch.device("cpu")
+else:
+    DEVICE = torch.device("cuda")
+
+
+skipIfTritonCpu: Callable[[Callable[..., object]], Callable[..., object]] = (
+    functools.partial(
+        unittest.skipIf, USE_TRITON_CPU_BACKEND, "does not work with triton cpu"
+    )
+)
 
 
 def import_path(filename: Path) -> types.ModuleType:
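These helpers are consumed by the test changes below: `skipIfTritonCpu()` expands to `unittest.skipIf(USE_TRITON_CPU_BACKEND, "does not work with triton cpu")`, and `DEVICE` lets the same test body run on either backend. A minimal illustration (the test class and bodies are hypothetical):

```python
# Hypothetical test module showing the intended use of the helpers above.
import unittest

import torch

from helion._testing import DEVICE
from helion._testing import skipIfTritonCpu


class ExampleTest(unittest.TestCase):
    def test_runs_on_either_backend(self):
        # DEVICE is cpu when TRITON_CPU_BACKEND=1, cuda otherwise.
        x = torch.randn(16, device=DEVICE)
        self.assertEqual(x.device.type, DEVICE.type)

    @skipIfTritonCpu()  # skipped when the CPU backend is active
    def test_gpu_only(self):
        self.assertTrue(torch.cuda.is_available())


if __name__ == "__main__":
    unittest.main()
```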
3 changes: 3 additions & 0 deletions test/test_loops.py
@@ -11,6 +11,7 @@
 from helion._testing import DEVICE
 from helion._testing import code_and_output
 from helion._testing import import_path
+from helion._testing import skipIfTritonCpu
 import helion.language as hl
 
 datadir = Path(__file__).parent / "data"
@@ -154,6 +155,7 @@ def _device_loop_3d_make_precompiler(x: torch.Tensor):
     return make_precompiler(_device_loop_3d_kernel)(x, out, out.stride(0), out.stride(1), out.stride(2), out.stride(3), x.stride(0), x.stride(1), x.stride(2), x.stride(3), b, c, d, _BLOCK_SIZE_3, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
         )
 
+    @skipIfTritonCpu()
     def test_3d_device_loop1(self):
         args = (torch.randn([128, 128, 128, 128], device=DEVICE),)
         code, result = code_and_output(
@@ -263,6 +265,7 @@ def _device_loop_3d_make_precompiler(x: torch.Tensor):
     return make_precompiler(_device_loop_3d_kernel)(x, out, out.stride(0), out.stride(1), out.stride(2), out.stride(3), x.stride(0), x.stride(1), x.stride(2), x.stride(3), a, b, c, d, _BLOCK_SIZE_0, _BLOCK_SIZE_1_2_3, num_warps=4, num_stages=3)""",
         )
 
+    @skipIfTritonCpu()
     def test_3d_device_loop3(self):
         args = (torch.randn([128, 128, 128, 128], device=DEVICE),)
         code, result = code_and_output(
2 changes: 2 additions & 0 deletions test/test_print.py
@@ -12,6 +12,7 @@
 import helion
 from helion._testing import DEVICE
 from helion._testing import code_and_output
+from helion._testing import skipIfTritonCpu
 import helion.language as hl
 
 
@@ -107,6 +108,7 @@ def run_test_with_and_without_triton_interpret_envvar(self, test_func):
         else:
             os.environ["TRITON_INTERPRET"] = original_env
 
+    @skipIfTritonCpu()
     def test_basic_print(self):
         """Test basic print with prefix and tensor values"""
 
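The context above shows this file restoring `TRITON_INTERPRET` after each run. For reference, the save/restore pattern it appears to implement looks like the sketch below (`run_with_env` is a hypothetical name, not the helper in this file):

```python
# Sketch of an environment-variable save/restore wrapper, assuming the
# helper above toggles TRITON_INTERPRET around a test function.
import os
from typing import Callable


def run_with_env(name: str, value: str, fn: Callable[[], None]) -> None:
    original = os.environ.get(name)
    os.environ[name] = value
    try:
        fn()
    finally:
        if original is None:
            os.environ.pop(name, None)  # variable was unset before
        else:
            os.environ[name] = original  # restore the prior value
```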