Commit 9eaec09

ge0405 authored and facebook-github-bot committed
Add warning/error for fused_tasks_and_states compute mode for 2D/List state tensors (#2899)
Summary:
Pull Request resolved: #2899

I visited all child classes that use `RecMetricComputation` to see if any is incompatible with the `FUSED_TASKS_AND_STATES_COMPUTATION` mode added in D72010614. As of Apr 16, 2025, searching "`(RecMetricComputation`" in fbcode returned 47 results:
- 1 is rec_metric in a comment (not shown in the following table)
- 42 metric classes defined in torchrec, 29 in OSS, 13 in FB
- 4 metric classes in customer codebases, e.g. MVAI, admarket (see rows 43-46 in the table below)

RecMetricComputation uses `state` tensors to compute/update results. I looked through the state tensors of all the following metrics to see if any of them are (1) not a Tensor (e.g. List) or (2) a 2D Tensor. If so, when these metrics use `FUSED_TASKS_AND_STATES_COMPUTATION`, the init should show a warning (and simply FUSED_TASKS_COMPUTATION will be used) or raise an Exception (if these metrics can't allow any fuse mode).

| # | dir | metric | type of state tensors [1] | tests/warning |
|---|---|---|---|---|
| 1 | oss | auc.py | [List](https://www.internalfb.com/code/fbsource/[0cd67a62f39734525b63e9fc054f9d169e48b793]/fbcode/torchrec/metrics/auc.py?lines=192) | warning |
| 2 | oss | tower_qps.py | | |
| 3 | oss | precision_session.py | not allow fuse, [code](https://www.internalfb.com/code/fbsource/[390e95a4e30dfd1f2c7ad23b2de92b1b7edbdf15]/fbcode/torchrec/metrics/precision_session.py?lines=203) | both fuse modes raise exception |
| 4 | oss | serving_ne.py | | |
| 5 | oss | recall_session.py | not allow fuse, [code](https://www.internalfb.com/code/fbsource/[390e95a4e30dfd1f2c7ad23b2de92b1b7edbdf15]/fbcode/torchrec/metrics/recall_session.py?lines=242) | both fuse modes raise exception |
| 6 | oss | calibration.py | | |
| 7 | oss | multiclass_recall.py | [2D tensor](https://www.internalfb.com/code/fbsource/[0cd67a62f39734525b63e9fc054f9d169e48b793]/fbcode/torchrec/metrics/multiclass_recall.py?lines=99) | warning |
| 8 | oss | ndcg.py | | |
| 9 | oss | gauc.py | | |
| 10 | oss | tensor_weighted_avg.py | | |
| 11 | oss | ne.py | | |
| 12 | oss | serving_calibration.py | | |
| 13 | oss | segmented_ne.py | [2D tensor](https://www.internalfb.com/code/fbsource/[0cd67a62f39734525b63e9fc054f9d169e48b793]/fbcode/torchrec/metrics/segmented_ne.py?lines=187) | warning |
| 14 | oss | scalar.py | | |
| 15 | oss | mae.py | | |
| 16 | oss | ne_positive.py | | |
| 17 | oss | weighted_avg.py | | |
| 18 | oss | output.py | | |
| 20 | oss | cali_free_ne.py | | |
| 21 | oss | unweighted_ne.py | | |
| 22 | oss | hindsight_target_pr.py | 1D but not n_tasks, [code](https://www.internalfb.com/code/fbsource/[0cd67a62f39734525b63e9fc054f9d169e48b793]/fbcode/torchrec/metrics/hindsight_target_pr.py?lines=131) | can still fuse states |
| 23 | oss | rauc.py | [List](https://www.internalfb.com/code/fbsource/[0cd67a62f39734525b63e9fc054f9d169e48b793]/fbcode/torchrec/metrics/rauc.py?lines=238) | warning |
| 24 | oss | precision.py | | |
| 25 | oss | recall.py | | |
| 26 | oss | auprc.py | [List](https://www.internalfb.com/code/fbsource/[0cd67a62f39734525b63e9fc054f9d169e48b793]/fbcode/torchrec/metrics/auprc.py?lines=186) | warning |
| 27 | oss | mse.py | | |
| 28 | oss | ctr.py | | |
| 29 | oss | accuracy.py | | |
| 30 | fb | log_normal_cnll.py | | |
| 31 | fb | coarse_grained_multiclass_ne.py | [2D tensor](https://www.internalfb.com/code/fbsource/[8670f823b23ca84165106e7a7bc236d066a03c7d]/fbcode/torchrec/fb/metrics/coarse_grained_multiclass_ne.py?lines=62) | warning |
| 32 | fb | regression_huber.py | | |
| 33 | fb | modified_poisson_nll.py | | |
| 34 | fb | res_ne.py | | |
| 35 | fb | dist_shift.py | | |
| 36 | fb | serving_ne.py | | |
| 37 | fb | unjoined_calibration.py | | |
| 38 | fb | unjoined_ne.py | | |
| 39 | fb | serving_calibration.py | | |
| 40 | fb | multiclass_ne.py | [2D tensor](https://www.internalfb.com/code/fbsource/[0cd67a62f39734525b63e9fc054f9d169e48b793]/fbcode/torchrec/fb/metrics/multiclass_ne.py?lines=148) | warning |
| 41 | fb | bucket_metric.py [2] | not allow fuse, [code](https://www.internalfb.com/code/fbsource/[0cc90263a7ae86e5327b153ccfe2f79b5956c69d]/fbcode/torchrec/fb/metrics/bucket_metric.py?lines=151-157) | both fuse modes raise exception |
| 42 | fb | bucket_weighted_average_metric.py [3] | shouldn't allow fuse, [code](https://www.internalfb.com/code/fbsource/[42cb13ac942ea4d2cb504d051b35419ccc6760f8]/fbcode/torchrec/fb/metrics/bucket_weighted_average_metric.py?lines=303) | both fuse modes should raise exception |
| 43 | MVAI | metrics.py | [List](https://www.internalfb.com/code/fbsource/[0cd67a62f39734525b63e9fc054f9d169e48b793]/fbcode/minimal_viable_ai/models/blue_reels_true_interest/metrics.py?lines=84) | |
| 44 | MVAI | ndcg_metrics.py | | |
| 45 | admarket | metrics.py | [2D or 3D?](https://www.internalfb.com/code/fbsource/[0cd67a62f39734525b63e9fc054f9d169e48b793]/fbcode/admarket/targeting/lookalike_nextgen_trainer/lal_lr_trainer/utils/metrics.py?lines=157) | |
| 46 | mrs/fm | metrics.py | | |

[1] For metrics where "type of state tensors" is left blank in the table above, all state tensors are 1D tensors of shape (n_tasks).
[2] There are 7 bucket metrics (bucket_calibration, bucket_ctr, bucket_hindsight_target_pr, bucket_mse, bucket_ne, bucket_precision, bucket_recall) that inherit `BucketMetricComputation` defined in bucket_metric.py. All of them have 2D tensors and most shapes are (n_tasks, num_buckets).
[3] There is 1 bucket metric (bucket_weighted_average_logloss) that inherits `BucketWeightedAverageMetricComputation` defined in bucket_weighted_average_metric.py. All the state tensors are 2D (n_tasks, num_buckets).

Reviewed By: iamzainhuda
Differential Revision: D73293593
fbshipit-source-id: 4ad9dfafeed1171d63dc301f4dc7608c52d105b8
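
The policy described in the summary (fall back with a warning when states are not 1D tensors, raise when a metric allows no fusion at all) can be sketched roughly as follows. This is only an illustration under stated assumptions: `ComputeMode`, `FuseNotSupported`, and `resolve_compute_mode` are hypothetical stand-ins, not torchrec APIs, and state tensors are modeled by their shapes, with `None` standing for a non-tensor state such as a list.

```python
import logging
from enum import Enum
from typing import Optional, Sequence, Tuple

logger = logging.getLogger(__name__)


class ComputeMode(Enum):
    # Hypothetical stand-in for torchrec's RecComputeMode.
    UNFUSED_TASKS = 0
    FUSED_TASKS = 1
    FUSED_TASKS_AND_STATES = 2


class FuseNotSupported(Exception):
    """Hypothetical stand-in for RecMetricException."""


def resolve_compute_mode(
    requested: ComputeMode,
    state_shapes: Sequence[Optional[Tuple[int, ...]]],
    allows_fuse: bool = True,
) -> ComputeMode:
    """Return the compute mode that would effectively be used.

    state_shapes holds one entry per state: a shape tuple for a tensor
    state, or None for a non-tensor state (e.g. a list).
    """
    if requested is ComputeMode.UNFUSED_TASKS:
        return requested
    if not allows_fuse:
        # Metrics such as the session-level ones reject every fused mode.
        raise FuseNotSupported("Fused computation is not supported for this metric")
    if requested is ComputeMode.FUSED_TASKS_AND_STATES:
        all_1d = all(shape is not None and len(shape) == 1 for shape in state_shapes)
        if not all_1d:
            logger.warning(
                "States are not 1D Tensors. Only FUSED_TASKS will take effect."
            )
            return ComputeMode.FUSED_TASKS
    return requested
```

In this sketch a metric with two `(n_tasks,)` states keeps the requested mode, while a metric with a 2D or list state is downgraded to task-only fusion.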
1 parent 72a124f commit 9eaec09

10 files changed (+155, -8 lines changed)

torchrec/metrics/auc.py

Lines changed: 8 additions & 0 deletions
@@ -7,6 +7,7 @@

 # pyre-strict

+import logging
 from functools import partial
 from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type

@@ -23,6 +24,8 @@
 )


+logger: logging.Logger = logging.getLogger(__name__)
+
 PREDICTIONS = "predictions"
 LABELS = "labels"
 WEIGHTS = "weights"
@@ -405,3 +408,8 @@ def __init__(
         )
         if kwargs.get("grouped_auc"):
             self._required_inputs.add(GROUPING_KEYS)
+        if self._compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION:
+            logging.warning(
+                f"compute_mode FUSED_TASKS_AND_STATES_COMPUTATION can't support {self._namespace} yet "
+                "because its states are not 1D Tensors. Only FUSED_TASKS_COMPUTATION will take effect."
+            )
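
A warning like the one added here could be checked in a unit test with `unittest`'s `assertLogs`. The sketch below is a minimal, self-contained illustration: `ComputeMode` and `ListStateMetric` are hypothetical stand-ins rather than torchrec classes. Because the warning is emitted through `logging.warning` (the root logger), `assertLogs` is called without a logger argument so that it captures root-logger records.

```python
import logging
import unittest
from enum import Enum


class ComputeMode(Enum):
    # Hypothetical stand-in for torchrec's RecComputeMode.
    FUSED_TASKS = 1
    FUSED_TASKS_AND_STATES = 2


class ListStateMetric:
    """Hypothetical metric whose states are lists rather than 1D tensors."""

    def __init__(self, compute_mode: ComputeMode) -> None:
        self._compute_mode = compute_mode
        if self._compute_mode == ComputeMode.FUSED_TASKS_AND_STATES:
            logging.warning(
                "compute_mode FUSED_TASKS_AND_STATES_COMPUTATION can't support this metric yet "
                "because its states are not 1D Tensors. Only FUSED_TASKS_COMPUTATION will take effect."
            )


class ComputeModeWarningTest(unittest.TestCase):
    def test_fused_tasks_and_states_warns(self) -> None:
        # With no logger argument, assertLogs captures the root logger.
        with self.assertLogs(level="WARNING") as captured:
            ListStateMetric(ComputeMode.FUSED_TASKS_AND_STATES)
        self.assertEqual(len(captured.records), 1)
        self.assertIn("not 1D Tensors", captured.output[0])
```

The test case can be run with any standard `unittest` runner.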

torchrec/metrics/auprc.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
# pyre-strict
99

10+
import logging
1011
from functools import partial
1112
from typing import Any, cast, Dict, List, Optional, Type
1213

@@ -23,6 +24,8 @@
2324
)
2425

2526

27+
logger: logging.Logger = logging.getLogger(__name__)
28+
2629
PREDICTIONS = "predictions"
2730
LABELS = "labels"
2831
WEIGHTS = "weights"
@@ -361,3 +364,8 @@ def __init__(
361364
)
362365
if kwargs.get("grouped_auprc"):
363366
self._required_inputs.add(GROUPING_KEYS)
367+
if self._compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION:
368+
logging.warning(
369+
f"compute_mode FUSED_TASKS_AND_STATES_COMPUTATION can't support {self._namespace} yet "
370+
"because its states are not 1D Tensors. Only FUSED_TASKS_COMPUTATION will take effect."
371+
)

torchrec/metrics/metrics_config.py

Lines changed: 6 additions & 2 deletions
@@ -81,8 +81,12 @@ class RecComputeMode(Enum):
     """This Enum lists the supported computation modes for RecMetrics.

     FUSED_TASKS_COMPUTATION indicates that RecMetrics will fuse the computation
-    for multiple tasks of the same metric. This can be used by modules where the
-    outputs of all the tasks are vectorized.
+        for multiple tasks of the same metric. This can be used by modules where the
+        outputs of all the tasks are vectorized.
+    FUSED_TASKS_AND_STATES_COMPUTATION fuse both the tasks (same as FUSED_TASKS_COMPUTATION)
+        and states (e.g. calibration_num and calibration_denom for caliration) of the
+        same metric. This currently only supports 1D state tensors (e.g. when all state
+        tensors are of the same (n_tasks) shape).
     """

     FUSED_TASKS_COMPUTATION = 1
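
One way to picture why the docstring restricts state fusion to equally shaped 1D states is to treat fusion as stacking the per-state vectors into a single `(n_states, n_tasks)` structure. The sketch below uses plain Python lists instead of tensors, and `fuse_states` and `calibration` are hypothetical helpers; whether torchrec fuses states exactly this way is an assumption, not something the commit shows.

```python
from typing import Dict, List


def fuse_states(states: Dict[str, List[float]]) -> List[List[float]]:
    """Stack equally sized 1D states into one (n_states, n_tasks) nested list."""
    lengths = {len(values) for values in states.values()}
    if len(lengths) != 1:
        # States of different sizes (or 2D states) cannot be stacked this way.
        raise ValueError("all states must share the same 1D length to be fused")
    return [list(values) for values in states.values()]


def calibration(fused: List[List[float]]) -> List[float]:
    """Compute per-task calibration from the fused [num, denom] rows."""
    num, denom = fused
    return [n / d for n, d in zip(num, denom)]


states = {
    "calibration_num": [2.0, 3.0, 4.0],    # one entry per task
    "calibration_denom": [4.0, 6.0, 8.0],  # one entry per task
}
fused = fuse_states(states)
print(calibration(fused))  # [0.5, 0.5, 0.5]
```

A state with a different length, or a 2D state such as `(n_tasks, num_buckets)`, does not fit this layout, which is consistent with the fallback behaviour described in the commit message.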

torchrec/metrics/multiclass_recall.py

Lines changed: 13 additions & 0 deletions
@@ -7,9 +7,11 @@

 # pyre-strict

+import logging
 from typing import Any, cast, Dict, List, Optional, Type

 import torch
+from torchrec.metrics.metrics_config import RecComputeMode
 from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix

 from torchrec.metrics.rec_metric import (
@@ -20,6 +22,9 @@
 )


+logger: logging.Logger = logging.getLogger(__name__)
+
+
 def compute_true_positives_at_k(
     predictions: torch.Tensor,
     labels: torch.Tensor,
@@ -154,3 +159,11 @@ def _compute(self) -> List[MetricComputationReport]:
 class MulticlassRecallMetric(RecMetric):
     _namespace: MetricNamespace = MetricNamespace.MULTICLASS_RECALL
     _computation_class: Type[RecMetricComputation] = MulticlassRecallMetricComputation
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        if self._compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION:
+            logging.warning(
+                f"compute_mode FUSED_TASKS_AND_STATES_COMPUTATION can't support {self._namespace} yet "
+                "because its states are not 1D Tensors. Only FUSED_TASKS_COMPUTATION will take effect."
+            )

torchrec/metrics/rauc.py

Lines changed: 8 additions & 0 deletions
@@ -7,6 +7,7 @@

 # pyre-strict

+import logging
 from functools import partial
 from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type

@@ -23,6 +24,8 @@
 )


+logger: logging.Logger = logging.getLogger(__name__)
+
 PREDICTIONS = "predictions"
 LABELS = "labels"
 WEIGHTS = "weights"
@@ -448,3 +451,8 @@ def __init__(
         )
         if kwargs.get("grouped_rauc"):
             self._required_inputs.add(GROUPING_KEYS)
+        if self._compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION:
+            logging.warning(
+                f"compute_mode FUSED_TASKS_AND_STATES_COMPUTATION can't support {self._namespace} yet "
+                "because its states are not 1D Tensors. Only FUSED_TASKS_COMPUTATION will take effect."
+            )

torchrec/metrics/segmented_ne.py

Lines changed: 8 additions & 0 deletions
@@ -7,6 +7,7 @@

 # pyre-strict

+import logging
 from typing import Any, Dict, List, Optional, Type

 import torch
@@ -21,6 +22,8 @@
 )


+logger: logging.Logger = logging.getLogger(__name__)
+
 PREDICTIONS = "predictions"
 LABELS = "labels"
 WEIGHTS = "weights"
@@ -346,3 +349,8 @@ def __init__(
         else:
             # pyre-ignore[6]
             self._required_inputs.add(kwargs["grouping_keys"])
+        if self._compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION:
+            logging.warning(
+                f"compute_mode FUSED_TASKS_AND_STATES_COMPUTATION can't support {self._namespace} yet "
+                "because its states are not 1D Tensors. Only FUSED_TASKS_COMPUTATION will take effect."
+            )

torchrec/metrics/tests/test_mae.py

Lines changed: 16 additions & 2 deletions
@@ -49,7 +49,7 @@ class MAEMetricTest(unittest.TestCase):
     clazz: Type[RecMetric] = MAEMetric
     task_name: str = "mae"

-    def test_unfused_mae(self) -> None:
+    def test_mae_unfused(self) -> None:
         rec_metric_value_test_launcher(
             target_clazz=MAEMetric,
             target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
@@ -63,7 +63,7 @@ def test_unfused_mae(self) -> None:
             entry_point=metric_test_helper,
         )

-    def test_fused_mae(self) -> None:
+    def test_mae_fused_tasks(self) -> None:
         rec_metric_value_test_launcher(
             target_clazz=MAEMetric,
             target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
@@ -77,6 +77,20 @@ def test_fused_mae(self) -> None:
             entry_point=metric_test_helper,
         )

+    def test_mae_fused_tasks_and_states(self) -> None:
+        rec_metric_value_test_launcher(
+            target_clazz=MAEMetric,
+            target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
+            test_clazz=TestMAEMetric,
+            metric_name="mae",
+            task_names=["t1", "t2", "t3"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=WORLD_SIZE,
+            entry_point=metric_test_helper,
+        )
+

 class MAEGPUSyncTest(unittest.TestCase):
     clazz: Type[RecMetric] = MAEMetric

torchrec/metrics/tests/test_precision_session.py

Lines changed: 36 additions & 1 deletion
@@ -12,8 +12,12 @@

 import torch
 from torch import no_grad
-from torchrec.metrics.metrics_config import RecTaskInfo, SessionMetricDef

+from torchrec.metrics.metrics_config import (
+    RecComputeMode,
+    RecTaskInfo,
+    SessionMetricDef,
+)
 from torchrec.metrics.precision_session import PrecisionSessionMetric
 from torchrec.metrics.rec_metric import RecMetricException

@@ -234,6 +238,37 @@ def test_error_messages(self) -> None:
                 tasks=[task_info2],
             )

+    def test_compute_mode_exception(self) -> None:
+        task_info = RecTaskInfo(
+            name="Task1",
+            label_name="label1",
+            prediction_name="prediction1",
+            weight_name="weight1",
+        )
+        with self.assertRaisesRegex(
+            RecMetricException,
+            "Fused computation is not supported for precision session-level metrics",
+        ):
+            PrecisionSessionMetric(
+                world_size=1,
+                my_rank=0,
+                batch_size=100,
+                tasks=[task_info],
+                compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+            )
+
+        with self.assertRaisesRegex(
+            RecMetricException,
+            "Fused computation is not supported for precision session-level metrics",
+        ):
+            PrecisionSessionMetric(
+                world_size=1,
+                my_rank=5,
+                batch_size=100,
+                tasks=[task_info],
+                compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
+            )
+
     def test_tasks_input_propagation(self) -> None:
         task_info1 = RecTaskInfo(
             name="Task1",
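
The test above repeats the same `assertRaisesRegex` block once per fused mode. As a possible alternative, the two cases could be folded into a loop with `subTest`. The sketch below is self-contained and uses hypothetical stand-ins (`ComputeMode`, `MetricError`, `SessionMetric`) rather than the real torchrec classes, so it only illustrates the testing pattern.

```python
import unittest
from enum import Enum


class ComputeMode(Enum):
    # Hypothetical stand-in for torchrec's RecComputeMode.
    UNFUSED_TASKS = 0
    FUSED_TASKS = 1
    FUSED_TASKS_AND_STATES = 2


class MetricError(Exception):
    """Hypothetical stand-in for RecMetricException."""


class SessionMetric:
    """Hypothetical session-level metric that rejects every fused mode."""

    def __init__(self, compute_mode: ComputeMode = ComputeMode.UNFUSED_TASKS) -> None:
        if compute_mode in (ComputeMode.FUSED_TASKS, ComputeMode.FUSED_TASKS_AND_STATES):
            raise MetricError(
                "Fused computation is not supported for session-level metrics"
            )
        self._compute_mode = compute_mode


class SessionMetricComputeModeTest(unittest.TestCase):
    def test_compute_mode_exception(self) -> None:
        fused_modes = (ComputeMode.FUSED_TASKS, ComputeMode.FUSED_TASKS_AND_STATES)
        for mode in fused_modes:
            # subTest reports each failing mode separately instead of stopping at the first.
            with self.subTest(mode=mode):
                with self.assertRaisesRegex(
                    MetricError, "Fused computation is not supported"
                ):
                    SessionMetric(compute_mode=mode)

    def test_unfused_is_accepted(self) -> None:
        SessionMetric(compute_mode=ComputeMode.UNFUSED_TASKS)
```

The loop keeps the constructor arguments in one place, at the cost of a slightly less explicit test body than the two spelled-out blocks.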

torchrec/metrics/tests/test_recall_session.py

Lines changed: 36 additions & 1 deletion
@@ -12,7 +12,11 @@

 import torch
 from torch import no_grad
-from torchrec.metrics.metrics_config import RecTaskInfo, SessionMetricDef
+from torchrec.metrics.metrics_config import (
+    RecComputeMode,
+    RecTaskInfo,
+    SessionMetricDef,
+)
 from torchrec.metrics.rec_metric import RecMetricException

 from torchrec.metrics.recall_session import RecallSessionMetric
@@ -243,6 +247,37 @@ def test_error_messages(self) -> None:
                 tasks=[task_info2],
             )

+    def test_compute_mode_exception(self) -> None:
+        task_info = RecTaskInfo(
+            name="Task1",
+            label_name="label1",
+            prediction_name="prediction1",
+            weight_name="weight1",
+        )
+        with self.assertRaisesRegex(
+            RecMetricException,
+            "Fused computation is not supported for recall session-level metrics",
+        ):
+            RecallSessionMetric(
+                world_size=1,
+                my_rank=0,
+                batch_size=100,
+                tasks=[task_info],
+                compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+            )
+
+        with self.assertRaisesRegex(
+            RecMetricException,
+            "Fused computation is not supported for recall session-level metrics",
+        ):
+            RecallSessionMetric(
+                world_size=1,
+                my_rank=5,
+                batch_size=100,
+                tasks=[task_info],
+                compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
+            )
+
     def test_tasks_input_propagation(self) -> None:
         task_info1 = RecTaskInfo(
             name="Task1",

torchrec/metrics/tests/test_serving_ne.py

Lines changed: 16 additions & 2 deletions
@@ -78,7 +78,7 @@ def test_ne_unfused(self) -> None:
             entry_point=metric_test_helper,
         )

-    def test_ne_fused(self) -> None:
+    def test_ne_fused_tasks(self) -> None:
         rec_metric_value_test_launcher(
             target_clazz=ServingNEMetric,
             target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
@@ -92,7 +92,21 @@ def test_ne_fused(self) -> None:
             entry_point=metric_test_helper,
         )

-    def test_ne_update_fused(self) -> None:
+    def test_ne_fused_tasks_and_states(self) -> None:
+        rec_metric_value_test_launcher(
+            target_clazz=ServingNEMetric,
+            target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
+            test_clazz=TestNEMetric,
+            metric_name=ServingNEMetricTest.task_name,
+            task_names=["t1", "t2", "t3"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=WORLD_SIZE,
+            entry_point=metric_test_helper,
+        )
+
+    def test_ne_update_unfused(self) -> None:
         rec_metric_value_test_launcher(
             target_clazz=ServingNEMetric,
             target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
