Skip to content

Fixes for pytorch<2.0 in average precision #3356

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

vfdev-5
Copy link
Collaborator

@vfdev-5 vfdev-5 commented Mar 27, 2025

Description:

  • Fixes for pytorch<2.0 in average precision

https://github.com/pytorch/ignite/actions/runs/14102362634

@github-actions github-actions bot added the module: metrics Metrics module label Mar 27, 2025
@vfdev-5 vfdev-5 force-pushed the fix-pth-versions-ci branch from e1ff650 to 9444581 Compare March 27, 2025 00:54
@vfdev-5 vfdev-5 requested a review from Copilot March 27, 2025 00:58
Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR addresses fixes for PyTorch versions below 2.0 in the computation of average precision and recall for object detection metrics. Key changes include conditional use of the "stable" parameter in torch.argsort based on the torch version, updated device type comparisons (using device.type instead of direct torch.device comparisons), and adapting precision aggregation logic to mitigate type issues on different backends.

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.

File Description
ignite/metrics/vision/object_detection_average_precision_recall.py Adds a torch version check and conditionally passes the "stable" flag to torch.argsort; adjusts precision aggregation logic.
ignite/metrics/mean_average_precision.py Introduces a similar torch version check and updates device type checks for consistency.
tests/ignite/metrics/vision/test_object_detection_map.py Updates device comparison to use device.type for handling the MPS backend.
Comments suppressed due to low confidence (4)

ignite/metrics/vision/object_detection_average_precision_recall.py:221

  • Confirm that the conditional kwargs usage for torch.argsort maintains consistent sorting behavior across different torch versions without performance regressions.
indices = torch.argsort(scores, descending=True, **kwargs)

ignite/metrics/vision/object_detection_average_precision_recall.py:267

  • Ensure that the fallback value 0.0 is of the same dtype as precision_integrand to avoid potential type mismatches in the average precision calculation.
precision_integrand = torch.where(
            recall_mask,
            precision_integrand.take_along_dim(torch.where(recall_mask, rec_thresh_indices, 0), dim=-1),
            0.0,
        )

ignite/metrics/mean_average_precision.py:348

  • The updated device type check using device.type is more robust; verify that this pattern is applied consistently across similar device comparisons.
if tp_summation.device.type != "mps":

tests/ignite/metrics/vision/test_object_detection_map.py:867

  • The test now correctly uses device.type for checking the MPS backend; ensure that all MPS-specific skips use this updated checking method.
if device.type == "mps":

@vfdev-5 vfdev-5 force-pushed the fix-pth-versions-ci branch from 9444581 to a7ff713 Compare March 27, 2025 08:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module: metrics Metrics module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant