@@ -94,26 +94,20 @@ def test_amd_gpu_mac(monkeypatch):
94
94
assert get_torch_platform (gpu_infos ) == "mps"
95
95
96
96
97
- def test_nvidia_gpu_windows (monkeypatch , capsys ):
97
+ def test_nvidia_gpu_windows (monkeypatch ):
98
98
monkeypatch .setattr ("torchruntime.platform_detection.os_name" , "Windows" )
99
99
monkeypatch .setattr ("torchruntime.platform_detection.arch" , "amd64" )
100
100
gpu_infos = [GPU (NVIDIA , "NVIDIA" , 0x1234 , "GeForce" , True )]
101
101
expected = "cu124" if py_version < (3 , 9 ) else "cu128"
102
102
assert get_torch_platform (gpu_infos ) == expected
103
- if py_version < (3 , 9 ):
104
- captured = capsys .readouterr ()
105
- assert "Support for Python 3.8 was dropped in torch 2.5" in captured .out
106
103
107
104
108
- def test_nvidia_gpu_linux (monkeypatch , capsys ):
105
+ def test_nvidia_gpu_linux (monkeypatch ):
109
106
monkeypatch .setattr ("torchruntime.platform_detection.os_name" , "Linux" )
110
107
monkeypatch .setattr ("torchruntime.platform_detection.arch" , "x86_64" )
111
108
gpu_infos = [GPU (NVIDIA , "NVIDIA" , 0x1234 , "GeForce" , True )]
112
109
expected = "cu124" if py_version < (3 , 9 ) else "cu128"
113
110
assert get_torch_platform (gpu_infos ) == expected
114
- if py_version < (3 , 9 ):
115
- captured = capsys .readouterr ()
116
- assert "Support for Python 3.8 was dropped in torch 2.5" in captured .out
117
111
118
112
119
113
def test_nvidia_gpu_mac (monkeypatch ):
@@ -124,6 +118,52 @@ def test_nvidia_gpu_mac(monkeypatch):
124
118
get_torch_platform (gpu_infos )
125
119
126
120
121
+ def test_nvidia_7xx_gpu_windows (monkeypatch ):
122
+ monkeypatch .setattr ("torchruntime.platform_detection.os_name" , "Windows" )
123
+ monkeypatch .setattr ("torchruntime.platform_detection.arch" , "amd64" )
124
+ gpu_infos = [GPU (NVIDIA , "NVIDIA" , "1004" , "GK110 [GeForce GTX 780]" , True )]
125
+ assert get_torch_platform (gpu_infos ) == "cu118"
126
+
127
+
128
+ def test_nvidia_10xx_gpu_windows (monkeypatch ):
129
+ monkeypatch .setattr ("torchruntime.platform_detection.os_name" , "Windows" )
130
+ monkeypatch .setattr ("torchruntime.platform_detection.arch" , "amd64" )
131
+ gpu_infos = [GPU (NVIDIA , "NVIDIA" , "1c02" , "GP106 [GeForce GTX 1060 3GB]" , True )]
132
+ assert get_torch_platform (gpu_infos ) == "cu124"
133
+
134
+
135
+ def test_nvidia_16xx_gpu_windows (monkeypatch ):
136
+ monkeypatch .setattr ("torchruntime.platform_detection.os_name" , "Windows" )
137
+ monkeypatch .setattr ("torchruntime.platform_detection.arch" , "amd64" )
138
+ gpu_infos = [GPU (NVIDIA , "NVIDIA" , "21c4" , "TU116 [GeForce GTX 1660 SUPER]" , True )]
139
+ expected = "cu124" if py_version < (3 , 9 ) else "cu128"
140
+ assert get_torch_platform (gpu_infos ) == expected
141
+
142
+
143
+ def test_nvidia_20xx_gpu_windows (monkeypatch ):
144
+ monkeypatch .setattr ("torchruntime.platform_detection.os_name" , "Windows" )
145
+ monkeypatch .setattr ("torchruntime.platform_detection.arch" , "amd64" )
146
+ gpu_infos = [GPU (NVIDIA , "NVIDIA" , "1f11" , "TU106M [GeForce RTX 2060 Mobile]" , True )]
147
+ expected = "cu124" if py_version < (3 , 9 ) else "cu128"
148
+ assert get_torch_platform (gpu_infos ) == expected
149
+
150
+
151
+ def test_nvidia_30xx_gpu_windows (monkeypatch ):
152
+ monkeypatch .setattr ("torchruntime.platform_detection.os_name" , "Windows" )
153
+ monkeypatch .setattr ("torchruntime.platform_detection.arch" , "amd64" )
154
+ gpu_infos = [GPU (NVIDIA , "NVIDIA" , "2489" , "GA104 [GeForce RTX 3060 Ti Lite Hash Rate]" , True )]
155
+ expected = "cu124" if py_version < (3 , 9 ) else "cu128"
156
+ assert get_torch_platform (gpu_infos ) == expected
157
+
158
+
159
+ def test_nvidia_40xx_gpu_windows (monkeypatch ):
160
+ monkeypatch .setattr ("torchruntime.platform_detection.os_name" , "Windows" )
161
+ monkeypatch .setattr ("torchruntime.platform_detection.arch" , "amd64" )
162
+ gpu_infos = [GPU (NVIDIA , "NVIDIA" , "2705" , "AD103 [GeForce RTX 4070 Ti SUPER]" , True )]
163
+ expected = "cu124" if py_version < (3 , 9 ) else "cu128"
164
+ assert get_torch_platform (gpu_infos ) == expected
165
+
166
+
127
167
def test_nvidia_5xxx_gpu_windows (monkeypatch ):
128
168
monkeypatch .setattr ("torchruntime.platform_detection.os_name" , "Windows" )
129
169
monkeypatch .setattr ("torchruntime.platform_detection.arch" , "amd64" )
0 commit comments