Skip to content

Add warning for fused_tasks_and_states compute mode for 2D/List state tensors #2899

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from

Conversation

ge0405
Copy link
Contributor

@ge0405 ge0405 commented Apr 19, 2025

Summary:
I visited all child classes that use RecMetricComputation to see if any is incompatible with the added FUSED_TASKS_AND_STATES_COMPUTATION in D72010614.

As of Apr 16, 2025, searching "(RecMetricComputation" in fbcode resulted in 47 results.

  • 1 is rec_metric in comment (not show in the following table)
  • 42 metric classes defined in torchrec, 29 in OSS, 13 in FB
  • 4 metric classes in customer codebase, e.g. MVAI, admarket (see row 43-46 in the table below)

RecMetricComputation uses state tensors to compute/update results. I looked through all following metrics' state tensors and see if any if them are (1) not Tensor (e.g. List) or (2) 2D Tensor. If so, when these metrics use FUSED_TASKS_AND_STATES_COMPUTATION, the init should show warning (and simply FUSED_TASKS_COMPUTATION will be used) or raise Exception (if these metrics can't allow any fuse mode).

| | dir | metric | type of state tensors | tests/warning |
|1 | oss | auc.py | List | warning |
|2 | oss | tower_qps.py | | |
|3 | oss | precision_session.py | not allow fuse, code | both fuse modes raise exception |
|4 | oss | serving_ne.py | | |
|5 | oss | recall_session.py | not allow fuse, code | both fuse modes raise exception |
|6 | oss | calibration.py | | |
|7 | oss | multiclass_recall.py | 2D tensor| warning |
|8 | oss | ndcg.py | | |
|9 | oss | gauc.py | | |
|10 | oss | tensor_weighted_avg.py | | |
|11 | oss | ne.py | | |
|12 | oss | serving_calibration.py | | |
|13 | oss | segmented_ne.py | 2D tensor| warning |
|14 | oss | scalar.py | | |
|15 | oss | mae.py | | |
|16 | oss | ne_positive.py | | |
|17 | oss | weighted_avg.py | | |
|18 | oss | output.py | | |
|20 | oss | cali_free_ne.py | | |
|21 | oss | unweighted_ne.py | | |
|22 | oss | hindsight_target_pr.py | 1D but not n_tasks, code | can still fuse states |
|23 | oss | rauc.py | List | warning |
|24 | oss | precision.py | | |
|25 | oss | recall.py | | |
|26 | oss | auprc.py | List| warning |
|27 | oss | mse.py | | |
|28 | oss | ctr.py | | |
|29 | oss | accuracy.py | | |
|30 | fb | log_normal_cnll.py | | |
|31 | fb | coarse_grained_multiclass_ne.py | 2D tensor| warning |
|32 | fb | regression_huber.py | | |
|33 | fb | modified_poisson_nll.py | | |
|34 | fb | res_ne.py | | |
|35 | fb | dist_shift.py | | |
|36 | fb | serving_ne.py | | |
|37 | fb | unjoined_calibration.py | | |
|38 | fb | unjoined_ne.py | | |
|39 | fb | serving_calibration.py | | |
|40 | fb | multiclass_ne.py | 2D tensor | warning |
|41 | fb | bucket_metric.py [1] | not allow fuse, code | both fuse modes raise exception |
|42 | fb | bucket_weighted_average_metric.py [2] | shouldn't allow fuse, code | both fuse modes should raise exception |
|43 | MVAI | metrics.py | List | |
|44 | MVAI | ndcg_metrics.py | | |
|45 | admarket | metrics.py | 2D or 3D? | |
|46 | mrs/fm | metrics.py | | |

[1] There are 7 bucket metrics (bucket_calibration, bucket_ctr, bucket_hindsight_target_pr, bucket_mse, bucket_ne, bucket_precision, bucket_recall) inherit BucketMetricComputation defined in bucket_metric.py. All of them have 2D tensors and most shapes are (n_tasks, num_buckets).

[2] There is 1 bucket metric (bucket_weighted_average_logloss) inherit BucketWeightedAverageMetricComputation defined in bucket_weighted_average_metric.py. All the state tensors are 2D (n_tasks, num_buckets).

Differential Revision: D73293593

… tensors

Summary:
I visited all child classes that use `RecMetricComputation` to see if any is incompatible with the added `FUSED_TASKS_AND_STATES_COMPUTATION` in D72010614.

As of Apr 16, 2025, searching "`(RecMetricComputation`" in fbcode resulted in 47 results. 
- 1 is rec_metric in comment (not show in the following table)
- 42 metric classes defined in torchrec, 29 in OSS, 13 in FB 
- 4 metric classes in customer codebase, e.g. MVAI, admarket (see row 43-46 in the table below)

RecMetricComputation uses `state` tensors to compute/update results. I looked through all following metrics' state tensors and see if any if them are (1) not Tensor (e.g. List) or (2) 2D Tensor. If so, when these metrics use `FUSED_TASKS_AND_STATES_COMPUTATION`, the init should show warning (and simply FUSED_TASKS_COMPUTATION will be used) or raise Exception (if these metrics can't allow any fuse mode). 

| | dir | metric | type of state tensors | tests/warning |
|1 | oss | auc.py | [List](https://www.internalfb.com/code/fbsource/[0cd67a62f39734525b63e9fc054f9d169e48b793]/fbcode/torchrec/metrics/auc.py?lines=192 ) | warning |
|2 | oss | tower_qps.py | | |
|3 | oss | precision_session.py | not allow fuse, [code](https://www.internalfb.com/code/fbsource/[390e95a4e30dfd1f2c7ad23b2de92b1b7edbdf15]/fbcode/torchrec/metrics/precision_session.py?lines=203) | both fuse modes raise exception |
|4 | oss | serving_ne.py | | |
|5 | oss | recall_session.py | not allow fuse, [code](https://www.internalfb.com/code/fbsource/[390e95a4e30dfd1f2c7ad23b2de92b1b7edbdf15]/fbcode/torchrec/metrics/recall_session.py?lines=242) | both fuse modes raise exception |
|6 | oss | calibration.py | | |
|7 | oss | multiclass_recall.py | [2D tensor](https://www.internalfb.com/code/fbsource/[0cd67a62f39734525b63e9fc054f9d169e48b793]/fbcode/torchrec/metrics/multiclass_recall.py?lines=99)| warning |
|8 | oss | ndcg.py | | |
|9 | oss | gauc.py | | |
|10 | oss | tensor_weighted_avg.py | | |
|11 | oss | ne.py | | |
|12 | oss | serving_calibration.py | | |
|13 | oss | segmented_ne.py | [2D tensor](https://www.internalfb.com/code/fbsource/[0cd67a62f39734525b63e9fc054f9d169e48b793]/fbcode/torchrec/metrics/segmented_ne.py?lines=187)| warning |
|14 | oss | scalar.py | | |
|15 | oss | mae.py | | |
|16 | oss | ne_positive.py | | |
|17 | oss | weighted_avg.py | | |
|18 | oss | output.py | | |
|20 | oss | cali_free_ne.py | | |
|21 | oss | unweighted_ne.py | | |
|22 | oss | hindsight_target_pr.py | 1D but not n_tasks, [code](https://www.internalfb.com/code/fbsource/[0cd67a62f39734525b63e9fc054f9d169e48b793]/fbcode/torchrec/metrics/hindsight_target_pr.py?lines=131) | can still fuse states |
|23 | oss | rauc.py | [List](https://www.internalfb.com/code/fbsource/[0cd67a62f39734525b63e9fc054f9d169e48b793]/fbcode/torchrec/metrics/rauc.py?lines=238) | warning |
|24 | oss | precision.py | | |
|25 | oss | recall.py | | |
|26 | oss | auprc.py | [List](https://www.internalfb.com/code/fbsource/[0cd67a62f39734525b63e9fc054f9d169e48b793]/fbcode/torchrec/metrics/auprc.py?lines=186)| warning |
|27 | oss | mse.py | | |
|28 | oss | ctr.py | | |
|29 | oss | accuracy.py | | |
|30 | fb | log_normal_cnll.py | | |
|31 | fb | coarse_grained_multiclass_ne.py | [2D tensor](https://www.internalfb.com/code/fbsource/[8670f823b23ca84165106e7a7bc236d066a03c7d]/fbcode/torchrec/fb/metrics/coarse_grained_multiclass_ne.py?lines=62)| warning |
|32 | fb | regression_huber.py | | |
|33 | fb | modified_poisson_nll.py | | |
|34 | fb | res_ne.py | | |
|35 | fb | dist_shift.py | | |
|36 | fb | serving_ne.py | | |
|37 | fb | unjoined_calibration.py | | |
|38 | fb | unjoined_ne.py | | |
|39 | fb | serving_calibration.py | | |
|40 | fb | multiclass_ne.py | [2D tensor](https://www.internalfb.com/code/fbsource/[0cd67a62f39734525b63e9fc054f9d169e48b793]/fbcode/torchrec/fb/metrics/multiclass_ne.py?lines=148) | warning |
|41 | fb | bucket_metric.py [1] | not allow fuse, [code](https://www.internalfb.com/code/fbsource/[0cc90263a7ae86e5327b153ccfe2f79b5956c69d]/fbcode/torchrec/fb/metrics/bucket_metric.py?lines=151-157) | both fuse modes raise exception |
|42 | fb | bucket_weighted_average_metric.py [2] | shouldn't allow fuse, [code](https://www.internalfb.com/code/fbsource/[42cb13ac942ea4d2cb504d051b35419ccc6760f8]/fbcode/torchrec/fb/metrics/bucket_weighted_average_metric.py?lines=303) | both fuse modes should raise exception |
|43 | MVAI | metrics.py | [List](https://www.internalfb.com/code/fbsource/[0cd67a62f39734525b63e9fc054f9d169e48b793]/fbcode/minimal_viable_ai/models/blue_reels_true_interest/metrics.py?lines=84) | |
|44 | MVAI | ndcg_metrics.py | | |
|45 | admarket | metrics.py | [2D or 3D?](https://www.internalfb.com/code/fbsource/[0cd67a62f39734525b63e9fc054f9d169e48b793]/fbcode/admarket/targeting/lookalike_nextgen_trainer/lal_lr_trainer/utils/metrics.py?lines=157) | |
|46 | mrs/fm | metrics.py | | |

[1] There are 7 bucket metrics (bucket_calibration, bucket_ctr, bucket_hindsight_target_pr, bucket_mse, bucket_ne, bucket_precision, bucket_recall) inherit `BucketMetricComputation` defined in bucket_metric.py. All of them have 2D tensors and most shapes are (n_tasks, num_buckets). 

[2] There is 1 bucket metric (bucket_weighted_average_logloss) inherit `BucketWeightedAverageMetricComputation` defined in bucket_weighted_average_metric.py. All the state tensors are 2D (n_tasks, num_buckets).

Differential Revision: D73293593
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 19, 2025
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D73293593

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants