Commit 1d48963

NVIDIA arch-specific cuda and torch versions
1 parent 1768cd9 commit 1d48963

File tree

3 files changed (+90, -20 lines)

README.md

Lines changed: 4 additions & 2 deletions
@@ -77,9 +77,11 @@ The list of platforms on which `torchruntime` can install a working variant of P
 | 40xx | ✅ Yes | Win/Linux | Uses CUDA 12.8 |
 | 30xx | ✅ Yes | Win/Linux | Uses CUDA 12.8 |
 | 20xx | ✅ Yes | Win/Linux | Uses CUDA 12.8 |
-| 10xx/16xx | ✅ Yes | Win/Linux | Uses CUDA 12.8. Full-precision required on 16xx series |
+| 16xx | ✅ Yes | Win/Linux | Uses CUDA 12.8. Requires full-precision for image generation |
+| 10xx | ✅ Yes | Win/Linux | Uses CUDA 12.4 |
+| 7xx | ✅ Yes | Win/Linux | Uses CUDA 11.8 |
 
-**Note:** We use CUDA 12.4 for Python 3.8, since torch dropped support for Python 3.8 after torch 2.4.
+**Note:** Torch dropped support for Python 3.8 starting with torch 2.5, so torchruntime falls back to CUDA 12.4 when Python 3.8 is used.
 
 ### AMD
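As a quick illustration of the behavior documented by this table and note, here is a small standalone sketch. `pick_cuda_variant` is a hypothetical helper, not part of torchruntime; the real selection logic lives in `torchruntime/platform_detection.py`, changed below.

```python
import sys

def pick_cuda_variant(series: str) -> str:
    """Illustrative only: GPU series -> CUDA build tag, per the README table."""
    if series == "7xx":
        return "cu118"  # Kepler-era cards stay on CUDA 11.8
    if series == "10xx":
        return "cu124"  # Pascal cards use CUDA 12.4
    # 16xx and newer use CUDA 12.8, except on Python 3.8: torch >= 2.5 dropped
    # Python 3.8, so torchruntime falls back to CUDA 12.4 there.
    # (50xx additionally requires Python >= 3.9 altogether.)
    if sys.version_info < (3, 9):
        return "cu124"
    return "cu128"

print(pick_cuda_variant("30xx"))  # "cu128" on Python >= 3.9
```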

tests/test_platform_detection.py

Lines changed: 48 additions & 8 deletions
@@ -94,26 +94,20 @@ def test_amd_gpu_mac(monkeypatch):
     assert get_torch_platform(gpu_infos) == "mps"
 
 
-def test_nvidia_gpu_windows(monkeypatch, capsys):
+def test_nvidia_gpu_windows(monkeypatch):
     monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows")
     monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64")
     gpu_infos = [GPU(NVIDIA, "NVIDIA", 0x1234, "GeForce", True)]
     expected = "cu124" if py_version < (3, 9) else "cu128"
     assert get_torch_platform(gpu_infos) == expected
-    if py_version < (3, 9):
-        captured = capsys.readouterr()
-        assert "Support for Python 3.8 was dropped in torch 2.5" in captured.out
 
 
-def test_nvidia_gpu_linux(monkeypatch, capsys):
+def test_nvidia_gpu_linux(monkeypatch):
     monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
     monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")
     gpu_infos = [GPU(NVIDIA, "NVIDIA", 0x1234, "GeForce", True)]
     expected = "cu124" if py_version < (3, 9) else "cu128"
     assert get_torch_platform(gpu_infos) == expected
-    if py_version < (3, 9):
-        captured = capsys.readouterr()
-        assert "Support for Python 3.8 was dropped in torch 2.5" in captured.out
 
 
 def test_nvidia_gpu_mac(monkeypatch):
@@ -124,6 +118,52 @@ def test_nvidia_gpu_mac(monkeypatch):
         get_torch_platform(gpu_infos)
 
 
+def test_nvidia_7xx_gpu_windows(monkeypatch):
+    monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows")
+    monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64")
+    gpu_infos = [GPU(NVIDIA, "NVIDIA", "1004", "GK110 [GeForce GTX 780]", True)]
+    assert get_torch_platform(gpu_infos) == "cu118"
+
+
+def test_nvidia_10xx_gpu_windows(monkeypatch):
+    monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows")
+    monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64")
+    gpu_infos = [GPU(NVIDIA, "NVIDIA", "1c02", "GP106 [GeForce GTX 1060 3GB]", True)]
+    assert get_torch_platform(gpu_infos) == "cu124"
+
+
+def test_nvidia_16xx_gpu_windows(monkeypatch):
+    monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows")
+    monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64")
+    gpu_infos = [GPU(NVIDIA, "NVIDIA", "21c4", "TU116 [GeForce GTX 1660 SUPER]", True)]
+    expected = "cu124" if py_version < (3, 9) else "cu128"
+    assert get_torch_platform(gpu_infos) == expected
+
+
+def test_nvidia_20xx_gpu_windows(monkeypatch):
+    monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows")
+    monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64")
+    gpu_infos = [GPU(NVIDIA, "NVIDIA", "1f11", "TU106M [GeForce RTX 2060 Mobile]", True)]
+    expected = "cu124" if py_version < (3, 9) else "cu128"
+    assert get_torch_platform(gpu_infos) == expected
+
+
+def test_nvidia_30xx_gpu_windows(monkeypatch):
+    monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows")
+    monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64")
+    gpu_infos = [GPU(NVIDIA, "NVIDIA", "2489", "GA104 [GeForce RTX 3060 Ti Lite Hash Rate]", True)]
+    expected = "cu124" if py_version < (3, 9) else "cu128"
+    assert get_torch_platform(gpu_infos) == expected
+
+
+def test_nvidia_40xx_gpu_windows(monkeypatch):
+    monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows")
+    monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64")
+    gpu_infos = [GPU(NVIDIA, "NVIDIA", "2705", "AD103 [GeForce RTX 4070 Ti SUPER]", True)]
+    expected = "cu124" if py_version < (3, 9) else "cu128"
+    assert get_torch_platform(gpu_infos) == expected
+
+
 def test_nvidia_5xxx_gpu_windows(monkeypatch):
     monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows")
     monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64")

torchruntime/platform_detection.py

Lines changed: 38 additions & 10 deletions
@@ -9,7 +9,26 @@
 arch = platform.machine().lower()
 py_version = sys.version_info
 
-BLACKWELL_DEVICES = re.compile(r"\b(?:5060|5070|5080|5090)\b")
+# https://www.techpowerup.com/gpu-specs/?architecture=Kepler&sort=generation and so on (change the arch field)
+KEPLER_DEVICES = re.compile(r"\b(gk1\d{2}\w*)\b", re.IGNORECASE)  # sm3.7
+MAXWELL_DEVICES = re.compile(r"\b(gm10\d\w*)\b", re.IGNORECASE)  # sm5
+PASCAL_DEVICES = re.compile(r"\b(gp10\d\w*)\b", re.IGNORECASE)  # sm6
+VOLTA_DEVICES = re.compile(r"\b(gv100\w*)\b", re.IGNORECASE)  # sm7
+TURING_DEVICES = re.compile(r"\b(tu1\d{2}\w*)\b", re.IGNORECASE)  # sm7.5
+AMPERE_DEVICES = re.compile(r"\b(ga10\d\w*)\b", re.IGNORECASE)  # sm8.6
+ADA_LOVELACE_DEVICES = re.compile(r"\b(ad10\d\w*)\b", re.IGNORECASE)  # sm8.9
+BLACKWELL_DEVICES = re.compile(r"\b(?:5060|5070|5080|5090)\b", re.IGNORECASE)  # sm10, sm12
+
+NVIDIA_ARCH_MAP = {
+    BLACKWELL_DEVICES: 12,
+    ADA_LOVELACE_DEVICES: 8.9,
+    AMPERE_DEVICES: 8.6,
+    TURING_DEVICES: 7.5,
+    VOLTA_DEVICES: 7,
+    PASCAL_DEVICES: 6,
+    MAXWELL_DEVICES: 5,
+    KEPLER_DEVICES: 3.7,
+}
 
 
 def get_torch_platform(gpu_infos):
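As a sanity check of the patterns above (illustrative only, assuming these names are in scope; the actual lookup is the `get_nvidia_arch()` helper added later in this diff), the chip codename embedded in each PCI device name is what gets matched:

```python
# Device names as they appear in the PCI database, e.g. "GA104 [GeForce RTX 3060 Ti ...]"
print(bool(AMPERE_DEVICES.search("GA104 [GeForce RTX 3060 Ti Lite Hash Rate]")))  # True
print(bool(KEPLER_DEVICES.search("GK110 [GeForce GTX 780]")))                     # True (sm3.7)

# NVIDIA_ARCH_MAP is ordered newest-first, so the first matching pattern decides the arch
for pattern, arch in NVIDIA_ARCH_MAP.items():
    if pattern.search("TU116 [GeForce GTX 1660 SUPER]"):
        print(arch)  # 7.5 (Turing)
        break
```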
@@ -109,16 +128,17 @@ def _get_platform_for_discrete(gpu_infos):
         return "mps"
     elif vendor_id == NVIDIA:
         if os_name in ("Windows", "Linux"):
-            if py_version < (3, 9):
-                device_names = set(gpu.device_name for gpu in gpu_infos)
-                if any(BLACKWELL_DEVICES.search(device_name) for device_name in device_names):
-                    raise NotImplementedError(
-                        f"Torch does not support NVIDIA 50xx series of GPUs on Python 3.8. Please switch to a newer Python version to use the latest version of torch!"
-                    )
-
-                print(
-                    "[WARNING] Support for Python 3.8 was dropped in torch 2.5. torchruntime will default to using torch 2.4 instead, but consider switching to a newer Python version to use the latest version of torch!"
+            device_names = set(gpu.device_name for gpu in gpu_infos)
+            arch_version = get_nvidia_arch(device_names)
+            if py_version < (3, 9) and arch_version == 12:
+                raise NotImplementedError(
+                    f"Torch does not support NVIDIA 50xx series of GPUs on Python 3.8. Please switch to a newer Python version to use the latest version of torch!"
                 )
+
+            # https://github.com/pytorch/pytorch/blob/0b6ea0b959f65d53ea8a34c1fa1c46446dfe3603/.ci/manywheel/build_cuda.sh#L54
+            if arch_version == 3.7:
+                return "cu118"
+            if (arch_version > 3.7 and arch_version < 7.5) or py_version < (3, 9):
                 return "cu124"
 
         return "cu128"
@@ -151,6 +171,14 @@ def _get_platform_for_discrete(gpu_infos):
     return "cpu"
 
 
+def get_nvidia_arch(device_names):
+    for arch_regex, arch in NVIDIA_ARCH_MAP.items():
+        if any(arch_regex.search(device_name) for device_name in device_names):
+            return arch
+
+    return 0
+
+
 def _get_platform_for_integrated(gpu_infos):
     gpu = gpu_infos[0]
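Putting the two hunks together, the new lookup can be exercised directly; a minimal sketch, assuming torchruntime with this commit applied:

```python
from torchruntime.platform_detection import get_nvidia_arch

# get_nvidia_arch() scans device names against NVIDIA_ARCH_MAP and returns the SM version,
# which _get_platform_for_discrete() then maps to a CUDA build tag.
print(get_nvidia_arch({"GK110 [GeForce GTX 780]"}))         # 3.7 -> cu118
print(get_nvidia_arch({"GP106 [GeForce GTX 1060 3GB]"}))    # 6   -> cu124
print(get_nvidia_arch({"TU116 [GeForce GTX 1660 SUPER]"}))  # 7.5 -> cu128 (cu124 on Python 3.8)
print(get_nvidia_arch({"Unrecognized device"}))             # 0   -> cu128 (cu124 on Python 3.8)
```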
