 from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
+from packaging.version import Version
 from typing_extensions import Literal
 
 from ignite.metrics import MetricGroup
 from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
 
 
+_torch_version_lt_113 = Version(torch.__version__) < Version("1.13.0")
+
+
 def coco_tensor_list_to_dict_list(
     output: Tuple[
         Union[List[torch.Tensor], List[Dict[str, torch.Tensor]]],
@@ -213,7 +217,8 @@ def _compute_recall_and_precision(
         Returns:
             `(recall, precision)`
         """
-        indices = torch.argsort(scores, dim=-1, stable=True, descending=True)
+        kwargs = {} if _torch_version_lt_113 else {"stable": True}
+        indices = torch.argsort(scores, descending=True, **kwargs)
         tp = TP[..., indices]
         tp_summation = tp.cumsum(dim=-1)
         if tp_summation.device.type != "mps":
@@ -226,7 +231,7 @@ def _compute_recall_and_precision(
 
         recall = tp_summation / y_true_count
         predicted_positive = tp_summation + fp_summation
-        precision = tp_summation / torch.where(predicted_positive == 0, 1, predicted_positive)
+        precision = tp_summation / torch.where(predicted_positive == 0, 1.0, predicted_positive)
 
         return recall, precision
 
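For reference, a minimal standalone sketch of the zero-guard pattern in the precision line above (tensor values are illustrative, not taken from the change): substituting 1.0 for zero denominators keeps the division well defined, and since the numerator is 0 at those positions the precision stays 0 instead of becoming NaN.

import torch

# Illustrative cumulative true-positive / false-positive counts; the first
# position has no predictions at all (tp + fp == 0), which would give 0/0.
tp_summation = torch.tensor([0.0, 1.0, 2.0])
fp_summation = torch.tensor([0.0, 1.0, 1.0])

predicted_positive = tp_summation + fp_summation
# Replace zero denominators with 1.0; the numerator is 0 there, so the
# resulting precision is 0 rather than NaN.
precision = tp_summation / torch.where(predicted_positive == 0, 1.0, predicted_positive)
print(precision)  # tensor([0.0000, 0.5000, 0.6667])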
@@ -258,9 +263,12 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens
             if recall.size(-1) != 0
             else torch.LongTensor([], device=self._device)
         )
-        precision_integrand = precision_integrand.take_along_dim(
-            rec_thresh_indices.where(rec_thresh_indices != recall.size(-1), 0), dim=-1
-        ).where(rec_thresh_indices != recall.size(-1), 0)
+        recall_mask = rec_thresh_indices != recall.size(-1)
+        precision_integrand = torch.where(
+            recall_mask,
+            precision_integrand.take_along_dim(torch.where(recall_mask, rec_thresh_indices, 0), dim=-1),
+            0.0,
+        )
         return torch.sum(precision_integrand, dim=-1) / len(cast(torch.Tensor, self.rec_thresholds))
 
     @reinit__is_reduced
@@ -298,6 +306,7 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor
             This key is optional.
         ========= ================= =================================================
         """
+        kwargs = {} if _torch_version_lt_113 else {"stable": True}
         self._check_matching_input(output)
         for pred, target in zip(*output):
             labels = target["labels"]
@@ -312,7 +321,7 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor
 
             # Matching logic of object detection mAP, according to COCO reference implementation.
             if len(pred["labels"]):
-                best_detections_index = torch.argsort(pred["scores"], stable=True, descending=True)
+                best_detections_index = torch.argsort(pred["scores"], descending=True, **kwargs)
                 max_best_detections_index = torch.cat(
                     [
                         best_detections_index[pred["labels"][best_detections_index] == c][
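As a standalone illustration of the version gate used throughout the change (a minimal sketch; the helper name is hypothetical, only the kwargs construction comes from the diff): per the gate above, stable=True is passed to torch.argsort only when the installed PyTorch is 1.13 or newer.

import torch
from packaging.version import Version

_torch_version_lt_113 = Version(torch.__version__) < Version("1.13.0")

def descending_argsort(scores: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: request a stable sort only when the installed
    # torch accepts the keyword argument.
    kwargs = {} if _torch_version_lt_113 else {"stable": True}
    return torch.argsort(scores, descending=True, **kwargs)

scores = torch.tensor([0.9, 0.3, 0.9, 0.5])
print(descending_argsort(scores))  # tensor([0, 2, 3, 1]) when the sort is stable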