Skip to content

Commit 9444581

Browse files
committed
Fixes for pytorch<2.0 in average precision
1 parent 8a34103 commit 9444581

File tree

3 files changed

+32
-15
lines changed

3 files changed

+32
-15
lines changed

ignite/metrics/mean_average_precision.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Callable, cast, List, Optional, Sequence, Tuple, Union
33

44
import torch
5+
from packaging.version import Version
56
from typing_extensions import Literal
67

78
import ignite.distributed as idist
@@ -11,6 +12,9 @@
1112
from ignite.utils import to_onehot
1213

1314

15+
_torch_version_lt_113 = Version(torch.__version__) < Version("1.13.0")
16+
17+
1418
class _BaseAveragePrecision:
1519
def __init__(
1620
self,
@@ -97,9 +101,12 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens
97101
if self.rec_thresholds is not None:
98102
rec_thresholds = self.rec_thresholds.repeat((*recall.shape[:-1], 1))
99103
rec_thresh_indices = torch.searchsorted(recall, rec_thresholds)
100-
precision = precision.take_along_dim(
101-
rec_thresh_indices.where(rec_thresh_indices != recall.size(-1), 0), dim=-1
102-
).where(rec_thresh_indices != recall.size(-1), 0)
104+
rec_mask = rec_thresh_indices != recall.size(-1)
105+
precision = torch.where(
106+
rec_mask,
107+
precision.take_along_dim(torch.where(rec_mask, rec_thresh_indices, 0), dim=-1),
108+
0.0,
109+
)
103110
recall = rec_thresholds
104111
recall_differential = recall.diff(
105112
dim=-1, prepend=torch.zeros((*recall.shape[:-1], 1), device=recall.device, dtype=recall.dtype)
@@ -335,9 +342,10 @@ def _compute_recall_and_precision(
335342
Returns:
336343
`(recall, precision)`
337344
"""
338-
indices = torch.argsort(y_pred, stable=True, descending=True)
345+
kwargs = {} if _torch_version_lt_113 else {"stable": True}
346+
indices = torch.argsort(y_pred, descending=True, **kwargs)
339347
tp_summation = y_true[indices].cumsum(dim=0)
340-
if tp_summation.device != torch.device("mps"):
348+
if tp_summation.device.type != "mps":
341349
tp_summation = tp_summation.double()
342350

343351
# Adopted from Scikit-learn's implementation
@@ -371,7 +379,7 @@ def compute(self) -> Union[torch.Tensor, float]:
371379
torch.long if self._type == "multiclass" else torch.uint8,
372380
self._device,
373381
)
374-
fp_precision = torch.double if self._device != torch.device("mps") else torch.float32
382+
fp_precision = torch.double if self._device.type != "mps" else torch.float32
375383
y_pred = _cat_and_agg_tensors(self._y_pred, (num_classes,), fp_precision, self._device)
376384

377385
if self._type == "multiclass":

ignite/metrics/vision/object_detection_average_precision_recall.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
22

33
import torch
4+
from packaging.version import Version
45
from typing_extensions import Literal
56

67
from ignite.metrics import MetricGroup
@@ -9,6 +10,9 @@
910
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
1011

1112

13+
_torch_version_lt_113 = Version(torch.__version__) < Version("1.13.0")
14+
15+
1216
def coco_tensor_list_to_dict_list(
1317
output: Tuple[
1418
Union[List[torch.Tensor], List[Dict[str, torch.Tensor]]],
@@ -213,7 +217,8 @@ def _compute_recall_and_precision(
213217
Returns:
214218
`(recall, precision)`
215219
"""
216-
indices = torch.argsort(scores, dim=-1, stable=True, descending=True)
220+
kwargs = {} if _torch_version_lt_113 else {"stable": True}
221+
indices = torch.argsort(scores, descending=True, **kwargs)
217222
tp = TP[..., indices]
218223
tp_summation = tp.cumsum(dim=-1)
219224
if tp_summation.device.type != "mps":
@@ -226,7 +231,7 @@ def _compute_recall_and_precision(
226231

227232
recall = tp_summation / y_true_count
228233
predicted_positive = tp_summation + fp_summation
229-
precision = tp_summation / torch.where(predicted_positive == 0, 1, predicted_positive)
234+
precision = tp_summation / torch.where(predicted_positive == 0, 1.0, predicted_positive)
230235

231236
return recall, precision
232237

@@ -258,9 +263,12 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens
258263
if recall.size(-1) != 0
259264
else torch.LongTensor([], device=self._device)
260265
)
261-
precision_integrand = precision_integrand.take_along_dim(
262-
rec_thresh_indices.where(rec_thresh_indices != recall.size(-1), 0), dim=-1
263-
).where(rec_thresh_indices != recall.size(-1), 0)
266+
recall_mask = rec_thresh_indices != recall.size(-1)
267+
precision_integrand = torch.where(
268+
recall_mask,
269+
precision_integrand.take_along_dim(torch.where(recall_mask, rec_thresh_indices, 0), dim=-1),
270+
0.0,
271+
)
264272
return torch.sum(precision_integrand, dim=-1) / len(cast(torch.Tensor, self.rec_thresholds))
265273

266274
@reinit__is_reduced
@@ -298,6 +306,7 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor
298306
This key is optional.
299307
========= ================= =================================================
300308
"""
309+
kwargs = {} if _torch_version_lt_113 else {"stable": True}
301310
self._check_matching_input(output)
302311
for pred, target in zip(*output):
303312
labels = target["labels"]
@@ -312,7 +321,7 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor
312321

313322
# Matching logic of object detection mAP, according to COCO reference implementation.
314323
if len(pred["labels"]):
315-
best_detections_index = torch.argsort(pred["scores"], stable=True, descending=True)
324+
best_detections_index = torch.argsort(pred["scores"], descending=True, **kwargs)
316325
max_best_detections_index = torch.cat(
317326
[
318327
best_detections_index[pred["labels"][best_detections_index] == c][

tests/ignite/metrics/vision/test_object_detection_map.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -864,7 +864,7 @@ def test__compute_recall_and_precision():
864864
def test_compute(sample):
865865
device = idist.device()
866866

867-
if device == torch.device("mps"):
867+
if device.type == "mps":
868868
pytest.skip("Due to MPS backend out of memory")
869869

870870
# [email protected], [email protected], [email protected], AP-S, AP-M, AP-L, AR-1, AR-10, AR-100, AR-S, AR-M, AR-L
@@ -924,7 +924,7 @@ def test_integration(sample):
924924
bs = 3
925925

926926
device = idist.device()
927-
if device == torch.device("mps"):
927+
if device.type == "mps":
928928
pytest.skip("Due to MPS backend out of memory")
929929

930930
def update(engine, i):
@@ -995,7 +995,7 @@ def test_distrib_update_compute(distributed, sample):
995995

996996
device = idist.device()
997997

998-
if device == torch.device("mps"):
998+
if device.type == "mps":
999999
pytest.skip("Due to MPS backend out of memory")
10001000

10011001
metric_device = "cpu" if device.type == "xla" else device

0 commit comments

Comments
 (0)