2
2
from typing import Tuple
3
3
from unittest .mock import patch
4
4
5
- import numpy as np
6
5
import pytest
7
6
import sklearn
8
7
import torch
@@ -28,85 +27,97 @@ def test_no_sklearn(mock_no_sklearn):
28
27
pr_curve .compute ()
29
28
30
29
31
- def test_precision_recall_curve ():
30
+ def test_precision_recall_curve (available_device ):
32
31
size = 100
33
- np_y_pred = np .random .rand (size , 1 )
34
- np_y = np .zeros ((size ,))
35
- np_y [size // 2 :] = 1
36
- sk_precision , sk_recall , sk_thresholds = precision_recall_curve (np_y , np_y_pred )
32
+ y_pred = torch .rand (size , 1 , dtype = torch .float32 , device = available_device )
33
+ y_true = torch .zeros (size , dtype = torch .float32 , device = available_device )
34
+ y_true [size // 2 :] = 1.0
35
+ expected_precision , expected_recall , expected_thresholds = precision_recall_curve (
36
+ y_true .cpu ().numpy (), y_pred .cpu ().numpy ()
37
+ )
37
38
38
- precision_recall_curve_metric = PrecisionRecallCurve ()
39
- y_pred = torch .from_numpy (np_y_pred )
40
- y = torch .from_numpy (np_y )
39
+ precision_recall_curve_metric = PrecisionRecallCurve (device = available_device )
40
+ assert precision_recall_curve_metric ._device == torch .device (available_device )
41
41
42
- precision_recall_curve_metric .update ((y_pred , y ))
42
+ precision_recall_curve_metric .update ((y_pred , y_true ))
43
43
precision , recall , thresholds = precision_recall_curve_metric .compute ()
44
- precision = precision .numpy ()
45
- recall = recall .numpy ()
46
- thresholds = thresholds .numpy ()
47
44
48
- assert pytest .approx (precision ) == sk_precision
49
- assert pytest .approx (recall ) == sk_recall
50
- # assert thresholds almost equal, due to numpy->torch->numpy conversion
51
- np .testing .assert_array_almost_equal (thresholds , sk_thresholds )
45
+ precision = precision .cpu ().numpy ()
46
+ recall = recall .cpu ().numpy ()
47
+ thresholds = thresholds .cpu ().numpy ()
48
+
49
+ assert pytest .approx (precision ) == expected_precision
50
+ assert pytest .approx (recall ) == expected_recall
51
+ assert thresholds == pytest .approx (expected_thresholds , rel = 1e-6 )
52
52
53
53
54
- def test_integration_precision_recall_curve_with_output_transform ():
55
- np .random .seed (1 )
54
+ def test_integration_precision_recall_curve_with_output_transform (available_device ):
56
55
size = 100
57
- np_y_pred = np .random .rand (size , 1 )
58
- np_y = np .zeros ((size ,))
59
- np_y [size // 2 :] = 1
60
- np .random .shuffle (np_y )
56
+ y_pred = torch .rand (size , 1 , dtype = torch .float32 , device = available_device )
57
+ y_true = torch .zeros (size , dtype = torch .float32 , device = available_device )
58
+ y_true [size // 2 :] = 1.0
59
+ perm = torch .randperm (size )
60
+ y_pred = y_pred [perm ]
61
+ y_true = y_true [perm ]
61
62
62
- sk_precision , sk_recall , sk_thresholds = precision_recall_curve (np_y , np_y_pred )
63
+ expected_precision , expected_recall , expected_thresholds = precision_recall_curve (
64
+ y_true .cpu ().numpy (), y_pred .cpu ().numpy ()
65
+ )
63
66
64
67
batch_size = 10
65
68
66
69
def update_fn (engine , batch ):
67
70
idx = (engine .state .iteration - 1 ) * batch_size
68
- y_true_batch = np_y [idx : idx + batch_size ]
69
- y_pred_batch = np_y_pred [idx : idx + batch_size ]
70
- return idx , torch . from_numpy ( y_pred_batch ), torch . from_numpy ( y_true_batch )
71
+ y_true_batch = y_true [idx : idx + batch_size ]
72
+ y_pred_batch = y_pred [idx : idx + batch_size ]
73
+ return idx , y_pred_batch , y_true_batch
71
74
72
75
engine = Engine (update_fn )
73
76
74
- precision_recall_curve_metric = PrecisionRecallCurve (output_transform = lambda x : (x [1 ], x [2 ]))
77
+ precision_recall_curve_metric = PrecisionRecallCurve (
78
+ output_transform = lambda x : (x [1 ], x [2 ]), device = available_device
79
+ )
80
+ assert precision_recall_curve_metric ._device == torch .device (available_device )
75
81
precision_recall_curve_metric .attach (engine , "precision_recall_curve" )
76
82
77
83
data = list (range (size // batch_size ))
78
84
precision , recall , thresholds = engine .run (data , max_epochs = 1 ).metrics ["precision_recall_curve" ]
79
- precision = precision .numpy ()
80
- recall = recall .numpy ()
81
- thresholds = thresholds .numpy ()
82
- assert pytest .approx (precision ) == sk_precision
83
- assert pytest .approx (recall ) == sk_recall
84
- # assert thresholds almost equal, due to numpy->torch->numpy conversion
85
- np .testing .assert_array_almost_equal (thresholds , sk_thresholds )
85
+ precision = precision .cpu ().numpy ()
86
+ recall = recall .cpu ().numpy ()
87
+ thresholds = thresholds .cpu ().numpy ()
88
+ assert pytest .approx (precision ) == expected_precision
89
+ assert pytest .approx (recall ) == expected_recall
90
+ assert thresholds == pytest .approx (expected_thresholds , rel = 1e-6 )
86
91
87
92
88
- def test_integration_precision_recall_curve_with_activated_output_transform ():
89
- np .random .seed (1 )
93
+ def test_integration_precision_recall_curve_with_activated_output_transform (available_device ):
90
94
size = 100
91
- np_y_pred = np .random .rand (size , 1 )
92
- np_y_pred_sigmoid = torch .sigmoid (torch .from_numpy (np_y_pred )).numpy ()
93
- np_y = np .zeros ((size ,))
94
- np_y [size // 2 :] = 1
95
- np .random .shuffle (np_y )
96
-
97
- sk_precision , sk_recall , sk_thresholds = precision_recall_curve (np_y , np_y_pred_sigmoid )
95
+ y_pred = torch .rand (size , 1 , dtype = torch .float32 , device = available_device )
96
+ y_true = torch .zeros (size , dtype = torch .float32 , device = available_device )
97
+ y_true [size // 2 :] = 1.0
98
+ perm = torch .randperm (size )
99
+ y_pred = y_pred [perm ]
100
+ y_true = y_true [perm ]
101
+
102
+ sigmoid_y_pred = torch .sigmoid (y_pred ).cpu ().numpy ()
103
+ expected_precision , expected_recall , expected_thresholds = precision_recall_curve (
104
+ y_true .cpu ().numpy (), sigmoid_y_pred
105
+ )
98
106
99
107
batch_size = 10
100
108
101
109
def update_fn (engine , batch ):
102
110
idx = (engine .state .iteration - 1 ) * batch_size
103
- y_true_batch = np_y [idx : idx + batch_size ]
104
- y_pred_batch = np_y_pred [idx : idx + batch_size ]
105
- return idx , torch . from_numpy ( y_pred_batch ), torch . from_numpy ( y_true_batch )
111
+ y_true_batch = y_true [idx : idx + batch_size ]
112
+ y_pred_batch = y_pred [idx : idx + batch_size ]
113
+ return idx , y_pred_batch , y_true_batch
106
114
107
115
engine = Engine (update_fn )
108
116
109
- precision_recall_curve_metric = PrecisionRecallCurve (output_transform = lambda x : (torch .sigmoid (x [1 ]), x [2 ]))
117
+ precision_recall_curve_metric = PrecisionRecallCurve (
118
+ output_transform = lambda x : (torch .sigmoid (x [1 ]), x [2 ]), device = available_device
119
+ )
120
+ assert precision_recall_curve_metric ._device == torch .device (available_device )
110
121
precision_recall_curve_metric .attach (engine , "precision_recall_curve" )
111
122
112
123
data = list (range (size // batch_size ))
@@ -115,25 +126,26 @@ def update_fn(engine, batch):
115
126
recall = recall .cpu ().numpy ()
116
127
thresholds = thresholds .cpu ().numpy ()
117
128
118
- assert pytest .approx (precision ) == sk_precision
119
- assert pytest .approx (recall ) == sk_recall
120
- # assert thresholds almost equal, due to numpy->torch->numpy conversion
121
- np .testing .assert_array_almost_equal (thresholds , sk_thresholds )
129
+ assert pytest .approx (precision ) == expected_precision
130
+ assert pytest .approx (recall ) == expected_recall
131
+ assert thresholds == pytest .approx (expected_thresholds , rel = 1e-6 )
122
132
123
133
124
- def test_check_compute_fn ():
134
+ def test_check_compute_fn (available_device ):
125
135
y_pred = torch .zeros ((8 , 13 ))
126
136
y_pred [:, 1 ] = 1
127
137
y_true = torch .zeros_like (y_pred )
128
138
output = (y_pred , y_true )
129
139
130
- em = PrecisionRecallCurve (check_compute_fn = True )
140
+ em = PrecisionRecallCurve (check_compute_fn = True , device = available_device )
141
+ assert em ._device == torch .device (available_device )
131
142
132
143
em .reset ()
133
144
with pytest .warns (EpochMetricWarning , match = r"Probably, there can be a problem with `compute_fn`" ):
134
145
em .update (output )
135
146
136
- em = PrecisionRecallCurve (check_compute_fn = False )
147
+ em = PrecisionRecallCurve (check_compute_fn = False , device = available_device )
148
+ assert em ._device == torch .device (available_device )
137
149
em .update (output )
138
150
139
151
@@ -225,14 +237,14 @@ def update(engine, i):
225
237
np_y_true = y_true .cpu ().numpy ().ravel ()
226
238
np_y_preds = y_preds .cpu ().numpy ().ravel ()
227
239
228
- sk_precision , sk_recall , sk_thresholds = precision_recall_curve (np_y_true , np_y_preds )
240
+ expected_precision , expected_recall , expected_thresholds = precision_recall_curve (np_y_true , np_y_preds )
229
241
230
- assert precision .shape == sk_precision .shape
231
- assert recall .shape == sk_recall .shape
232
- assert thresholds .shape == sk_thresholds .shape
233
- assert pytest .approx (precision .cpu ().numpy ()) == sk_precision
234
- assert pytest .approx (recall .cpu ().numpy ()) == sk_recall
235
- assert pytest .approx (thresholds .cpu ().numpy ()) == sk_thresholds
242
+ assert precision .shape == expected_precision .shape
243
+ assert recall .shape == expected_recall .shape
244
+ assert thresholds .shape == expected_thresholds .shape
245
+ assert pytest .approx (precision .cpu ().numpy ()) == expected_precision
246
+ assert pytest .approx (recall .cpu ().numpy ()) == expected_recall
247
+ assert pytest .approx (thresholds .cpu ().numpy ()) == expected_thresholds
236
248
237
249
metric_devices = ["cpu" ]
238
250
if device .type != "xla" :
0 commit comments