Skip to content

Commit aba1cc9

Browse files
build: extras (cpu, gpu) and github action workflow with gpu runners
1 parent 14fce99 commit aba1cc9

File tree

7 files changed

+317
-69
lines changed

7 files changed

+317
-69
lines changed

.github/workflows/run-tests-cpu.yml

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
name: "[CPU] mostlyai-qa Tests"
2+
3+
on:
4+
workflow_call:
5+
6+
7+
env:
8+
PYTHON_KEYRING_BACKEND: keyring.backends.null.Keyring
9+
FORCE_COLOR: "1"
10+
11+
jobs:
12+
run-tests-cpu:
13+
runs-on: ubuntu-latest
14+
permissions:
15+
contents: read
16+
packages: write
17+
steps:
18+
- name: Setup | Checkout
19+
uses: actions/checkout@v4
20+
21+
- name: Setup | Python
22+
uses: actions/setup-python@v5
23+
with:
24+
python-version: '3.10'
25+
26+
- name: Setup | uv
27+
uses: astral-sh/setup-uv@v5
28+
with:
29+
enable-cache: false
30+
python-version: '3.10'
31+
32+
- name: Setup | dependencies
33+
run: uv sync --frozen --extra cpu
34+
- name: Run tests
35+
run: uv run pytest tests/
36+
37+
- name: Build mkdocs
38+
run: uv run mkdocs build --strict

.github/workflows/run-tests-gpu.yml

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
name: "[GPU] mostlyai-qa Tests"
2+
3+
on:
4+
workflow_call:
5+
6+
7+
env:
8+
PYTHON_KEYRING_BACKEND: keyring.backends.null.Keyring
9+
FORCE_COLOR: "1"
10+
11+
jobs:
12+
run-tests-gpu:
13+
runs-on: gha-gpu-public
14+
permissions:
15+
contents: read
16+
packages: write
17+
steps:
18+
- name: Setup | Checkout
19+
uses: actions/checkout@v4
20+
21+
- name: Setup | Python
22+
uses: actions/setup-python@v5
23+
with:
24+
python-version: '3.10'
25+
26+
- name: Setup | uv
27+
uses: astral-sh/setup-uv@v5
28+
with:
29+
enable-cache: false
30+
python-version: '3.10'
31+
32+
- name: Setup | dependencies
33+
run: uv sync --frozen --extra gpu
34+
35+
- name: Run tests
36+
run: uv run pytest tests/
37+
38+
- name: Build mkdocs
39+
run: uv run mkdocs build --strict

.github/workflows/run-tests.yml

-25
This file was deleted.

.github/workflows/workflow.yml

+9-2
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,16 @@ jobs:
1313
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name != github.repository)
1414
uses: ./.github/workflows/pre-commit-check.yml
1515
secrets: inherit
16-
run-tests:
16+
run-tests-cpu:
1717
if: |
1818
github.event_name == 'push' ||
1919
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name != github.repository)
20-
uses: ./.github/workflows/run-tests.yml
20+
uses: ./.github/workflows/run-tests-cpu.yml
2121
secrets: inherit
22+
run-tests-gpu:
23+
if: |
24+
github.ref == 'refs/heads/main' ||
25+
startsWith(github.ref, 'refs/tags/') ||
26+
contains(github.event.head_commit.message, '[gpu]')
27+
uses: ./.github/workflows/run-tests-gpu.yml
28+
secrets: inherit

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ The latest release of `mostlyai-qa` can be installed via pip:
2020
pip install -U mostlyai-qa
2121
```
2222

23+
On Linux, one can explicitly install `mostlyai-qa[cpu]` or `mostlyai-qa[gpu]`, for CPU-only or CUDA support respectively
24+
2325
## Quick Start
2426

2527
```python

pyproject.toml

+26
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,16 @@ dependencies = [
4141
"sentence-transformers>=3.1.0",
4242
"rich>=13.9.4,<14",
4343
"skops>=0.11.0",
44+
"torch>=2.5.1",
45+
]
46+
47+
[project.optional-dependencies]
48+
gpu = [
49+
"torch>=2.5.1",
50+
]
51+
cpu = [
52+
"torch==2.5.1+cpu; sys_platform == 'linux'",
53+
"torch>=2.5.1; sys_platform != 'linux'",
4454
]
4555

4656
[project.urls]
@@ -68,6 +78,22 @@ docs = [
6878

6979
[tool.uv]
7080
default-groups = ["dev", "docs"]
81+
conflicts = [
82+
[
83+
{ extra = "cpu" },
84+
{ extra = "gpu" },
85+
],
86+
]
87+
88+
[[tool.uv.index]]
89+
name = "pytorch-cpu"
90+
url = "https://download.pytorch.org/whl/cpu"
91+
explicit = true
92+
93+
[tool.uv.sources]
94+
torch = [
95+
{ index = "pytorch-cpu", extra = "cpu", marker = "sys_platform == 'linux'"},
96+
]
7197

7298
[tool.hatch.build.targets.sdist]
7399
include = ["mostlyai/qa"]

0 commit comments

Comments
 (0)