Skip to content

Commit 5caecf2

Browse files
BanzaiTokyovfdev-5
andauthored
adds available_device to test_precision_recall_curve #3335 (#3368)
* adds available_device to test_precision_recall_curve #3335 * forces float32 when converting to tensor on mps * creates the data directly with torch tensors instead of numpy arrays * ensures compatibility with MPS by converting to float32 * comments on float32 conversion * makes sure that sklearn does not convert float32 to float64 (otherwize type error on on MPS) * another attempt of avoiding float64 * avoiding float64 for MPS * avoiding float64 for MPS * another attempt at avoiding float64 on MPS * moves conversion to float32 before assertions * conversion to float32 * more conversion to float32 * more conversion to float32 * more conversion to float32 * more conversion to float32 * in precision_recall_curve.py add dtype when creating tensors for precision, recall and thresholds * removes unnecessary conversions * move tensors to CPU before passing them to precision_recall_curve * move tensors to CPU before passing them to precision_recall_curve * move tensors to CPU before passing them to precision_recall_curve * replace np.testing.assert_array_almost_equal with pytest.approx * removes manual_seed --------- Co-authored-by: vfdev <[email protected]>
1 parent de11279 commit 5caecf2

File tree

2 files changed

+77
-65
lines changed

2 files changed

+77
-65
lines changed

ignite/metrics/precision_recall_curve.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,11 @@ def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # type: i
110110
if idist.get_rank() == 0:
111111
# Run compute_fn on zero rank only
112112
precision, recall, thresholds = cast(Tuple, self.compute_fn(_prediction_tensor, _target_tensor))
113-
precision = torch.tensor(precision, device=_prediction_tensor.device)
114-
recall = torch.tensor(recall, device=_prediction_tensor.device)
113+
precision = torch.tensor(precision, device=_prediction_tensor.device, dtype=self._double_dtype)
114+
recall = torch.tensor(recall, device=_prediction_tensor.device, dtype=self._double_dtype)
115115
# thresholds can have negative strides, not compatible with torch tensors
116116
# https://discuss.pytorch.org/t/negative-strides-in-tensor-error/134287/2
117-
thresholds = torch.tensor(thresholds.copy(), device=_prediction_tensor.device)
117+
thresholds = torch.tensor(thresholds.copy(), device=_prediction_tensor.device, dtype=self._double_dtype)
118118
else:
119119
precision, recall, thresholds = None, None, None
120120

tests/ignite/metrics/test_precision_recall_curve.py

Lines changed: 74 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from typing import Tuple
33
from unittest.mock import patch
44

5-
import numpy as np
65
import pytest
76
import sklearn
87
import torch
@@ -28,85 +27,97 @@ def test_no_sklearn(mock_no_sklearn):
2827
pr_curve.compute()
2928

3029

31-
def test_precision_recall_curve():
30+
def test_precision_recall_curve(available_device):
3231
size = 100
33-
np_y_pred = np.random.rand(size, 1)
34-
np_y = np.zeros((size,))
35-
np_y[size // 2 :] = 1
36-
sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred)
32+
y_pred = torch.rand(size, 1, dtype=torch.float32, device=available_device)
33+
y_true = torch.zeros(size, dtype=torch.float32, device=available_device)
34+
y_true[size // 2 :] = 1.0
35+
expected_precision, expected_recall, expected_thresholds = precision_recall_curve(
36+
y_true.cpu().numpy(), y_pred.cpu().numpy()
37+
)
3738

38-
precision_recall_curve_metric = PrecisionRecallCurve()
39-
y_pred = torch.from_numpy(np_y_pred)
40-
y = torch.from_numpy(np_y)
39+
precision_recall_curve_metric = PrecisionRecallCurve(device=available_device)
40+
assert precision_recall_curve_metric._device == torch.device(available_device)
4141

42-
precision_recall_curve_metric.update((y_pred, y))
42+
precision_recall_curve_metric.update((y_pred, y_true))
4343
precision, recall, thresholds = precision_recall_curve_metric.compute()
44-
precision = precision.numpy()
45-
recall = recall.numpy()
46-
thresholds = thresholds.numpy()
4744

48-
assert pytest.approx(precision) == sk_precision
49-
assert pytest.approx(recall) == sk_recall
50-
# assert thresholds almost equal, due to numpy->torch->numpy conversion
51-
np.testing.assert_array_almost_equal(thresholds, sk_thresholds)
45+
precision = precision.cpu().numpy()
46+
recall = recall.cpu().numpy()
47+
thresholds = thresholds.cpu().numpy()
48+
49+
assert pytest.approx(precision) == expected_precision
50+
assert pytest.approx(recall) == expected_recall
51+
assert thresholds == pytest.approx(expected_thresholds, rel=1e-6)
5252

5353

54-
def test_integration_precision_recall_curve_with_output_transform():
55-
np.random.seed(1)
54+
def test_integration_precision_recall_curve_with_output_transform(available_device):
5655
size = 100
57-
np_y_pred = np.random.rand(size, 1)
58-
np_y = np.zeros((size,))
59-
np_y[size // 2 :] = 1
60-
np.random.shuffle(np_y)
56+
y_pred = torch.rand(size, 1, dtype=torch.float32, device=available_device)
57+
y_true = torch.zeros(size, dtype=torch.float32, device=available_device)
58+
y_true[size // 2 :] = 1.0
59+
perm = torch.randperm(size)
60+
y_pred = y_pred[perm]
61+
y_true = y_true[perm]
6162

62-
sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred)
63+
expected_precision, expected_recall, expected_thresholds = precision_recall_curve(
64+
y_true.cpu().numpy(), y_pred.cpu().numpy()
65+
)
6366

6467
batch_size = 10
6568

6669
def update_fn(engine, batch):
6770
idx = (engine.state.iteration - 1) * batch_size
68-
y_true_batch = np_y[idx : idx + batch_size]
69-
y_pred_batch = np_y_pred[idx : idx + batch_size]
70-
return idx, torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
71+
y_true_batch = y_true[idx : idx + batch_size]
72+
y_pred_batch = y_pred[idx : idx + batch_size]
73+
return idx, y_pred_batch, y_true_batch
7174

7275
engine = Engine(update_fn)
7376

74-
precision_recall_curve_metric = PrecisionRecallCurve(output_transform=lambda x: (x[1], x[2]))
77+
precision_recall_curve_metric = PrecisionRecallCurve(
78+
output_transform=lambda x: (x[1], x[2]), device=available_device
79+
)
80+
assert precision_recall_curve_metric._device == torch.device(available_device)
7581
precision_recall_curve_metric.attach(engine, "precision_recall_curve")
7682

7783
data = list(range(size // batch_size))
7884
precision, recall, thresholds = engine.run(data, max_epochs=1).metrics["precision_recall_curve"]
79-
precision = precision.numpy()
80-
recall = recall.numpy()
81-
thresholds = thresholds.numpy()
82-
assert pytest.approx(precision) == sk_precision
83-
assert pytest.approx(recall) == sk_recall
84-
# assert thresholds almost equal, due to numpy->torch->numpy conversion
85-
np.testing.assert_array_almost_equal(thresholds, sk_thresholds)
85+
precision = precision.cpu().numpy()
86+
recall = recall.cpu().numpy()
87+
thresholds = thresholds.cpu().numpy()
88+
assert pytest.approx(precision) == expected_precision
89+
assert pytest.approx(recall) == expected_recall
90+
assert thresholds == pytest.approx(expected_thresholds, rel=1e-6)
8691

8792

88-
def test_integration_precision_recall_curve_with_activated_output_transform():
89-
np.random.seed(1)
93+
def test_integration_precision_recall_curve_with_activated_output_transform(available_device):
9094
size = 100
91-
np_y_pred = np.random.rand(size, 1)
92-
np_y_pred_sigmoid = torch.sigmoid(torch.from_numpy(np_y_pred)).numpy()
93-
np_y = np.zeros((size,))
94-
np_y[size // 2 :] = 1
95-
np.random.shuffle(np_y)
96-
97-
sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred_sigmoid)
95+
y_pred = torch.rand(size, 1, dtype=torch.float32, device=available_device)
96+
y_true = torch.zeros(size, dtype=torch.float32, device=available_device)
97+
y_true[size // 2 :] = 1.0
98+
perm = torch.randperm(size)
99+
y_pred = y_pred[perm]
100+
y_true = y_true[perm]
101+
102+
sigmoid_y_pred = torch.sigmoid(y_pred).cpu().numpy()
103+
expected_precision, expected_recall, expected_thresholds = precision_recall_curve(
104+
y_true.cpu().numpy(), sigmoid_y_pred
105+
)
98106

99107
batch_size = 10
100108

101109
def update_fn(engine, batch):
102110
idx = (engine.state.iteration - 1) * batch_size
103-
y_true_batch = np_y[idx : idx + batch_size]
104-
y_pred_batch = np_y_pred[idx : idx + batch_size]
105-
return idx, torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
111+
y_true_batch = y_true[idx : idx + batch_size]
112+
y_pred_batch = y_pred[idx : idx + batch_size]
113+
return idx, y_pred_batch, y_true_batch
106114

107115
engine = Engine(update_fn)
108116

109-
precision_recall_curve_metric = PrecisionRecallCurve(output_transform=lambda x: (torch.sigmoid(x[1]), x[2]))
117+
precision_recall_curve_metric = PrecisionRecallCurve(
118+
output_transform=lambda x: (torch.sigmoid(x[1]), x[2]), device=available_device
119+
)
120+
assert precision_recall_curve_metric._device == torch.device(available_device)
110121
precision_recall_curve_metric.attach(engine, "precision_recall_curve")
111122

112123
data = list(range(size // batch_size))
@@ -115,25 +126,26 @@ def update_fn(engine, batch):
115126
recall = recall.cpu().numpy()
116127
thresholds = thresholds.cpu().numpy()
117128

118-
assert pytest.approx(precision) == sk_precision
119-
assert pytest.approx(recall) == sk_recall
120-
# assert thresholds almost equal, due to numpy->torch->numpy conversion
121-
np.testing.assert_array_almost_equal(thresholds, sk_thresholds)
129+
assert pytest.approx(precision) == expected_precision
130+
assert pytest.approx(recall) == expected_recall
131+
assert thresholds == pytest.approx(expected_thresholds, rel=1e-6)
122132

123133

124-
def test_check_compute_fn():
134+
def test_check_compute_fn(available_device):
125135
y_pred = torch.zeros((8, 13))
126136
y_pred[:, 1] = 1
127137
y_true = torch.zeros_like(y_pred)
128138
output = (y_pred, y_true)
129139

130-
em = PrecisionRecallCurve(check_compute_fn=True)
140+
em = PrecisionRecallCurve(check_compute_fn=True, device=available_device)
141+
assert em._device == torch.device(available_device)
131142

132143
em.reset()
133144
with pytest.warns(EpochMetricWarning, match=r"Probably, there can be a problem with `compute_fn`"):
134145
em.update(output)
135146

136-
em = PrecisionRecallCurve(check_compute_fn=False)
147+
em = PrecisionRecallCurve(check_compute_fn=False, device=available_device)
148+
assert em._device == torch.device(available_device)
137149
em.update(output)
138150

139151

@@ -225,14 +237,14 @@ def update(engine, i):
225237
np_y_true = y_true.cpu().numpy().ravel()
226238
np_y_preds = y_preds.cpu().numpy().ravel()
227239

228-
sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y_true, np_y_preds)
240+
expected_precision, expected_recall, expected_thresholds = precision_recall_curve(np_y_true, np_y_preds)
229241

230-
assert precision.shape == sk_precision.shape
231-
assert recall.shape == sk_recall.shape
232-
assert thresholds.shape == sk_thresholds.shape
233-
assert pytest.approx(precision.cpu().numpy()) == sk_precision
234-
assert pytest.approx(recall.cpu().numpy()) == sk_recall
235-
assert pytest.approx(thresholds.cpu().numpy()) == sk_thresholds
242+
assert precision.shape == expected_precision.shape
243+
assert recall.shape == expected_recall.shape
244+
assert thresholds.shape == expected_thresholds.shape
245+
assert pytest.approx(precision.cpu().numpy()) == expected_precision
246+
assert pytest.approx(recall.cpu().numpy()) == expected_recall
247+
assert pytest.approx(thresholds.cpu().numpy()) == expected_thresholds
236248

237249
metric_devices = ["cpu"]
238250
if device.type != "xla":

0 commit comments

Comments
 (0)