diff --git a/torchrec/metrics/metric_module.py b/torchrec/metrics/metric_module.py index 2a5148153..9631bf811 100644 --- a/torchrec/metrics/metric_module.py +++ b/torchrec/metrics/metric_module.py @@ -68,6 +68,7 @@ from torchrec.metrics.unweighted_ne import UnweightedNEMetric from torchrec.metrics.weighted_avg import WeightedAvgMetric from torchrec.metrics.xauc import XAUCMetric +from torchrec.utils.experimental import experimental logger: logging.Logger = logging.getLogger(__name__) @@ -394,6 +395,7 @@ def _get_metric_states( return state_aggregated + @experimental def get_pre_compute_states( self, pg: Optional[Union[dist.ProcessGroup, DeviceMesh]] = None ) -> Dict[str, Dict[str, Dict[str, Union[torch.Tensor, List[torch.Tensor]]]]]: @@ -442,6 +444,7 @@ def get_pre_compute_states( return aggregated_states + @experimental def load_pre_compute_states( self, source: Dict[ diff --git a/torchrec/metrics/tests/test_metric_module.py b/torchrec/metrics/tests/test_metric_module.py index 5e2efc3bc..5c7a7e520 100644 --- a/torchrec/metrics/tests/test_metric_module.py +++ b/torchrec/metrics/tests/test_metric_module.py @@ -643,7 +643,7 @@ def metric_module_gather_state( metric_module.update(test_batch) computed_value = metric_module.compute() - states = metric_module.get_pre_compute_states(pg=ctx.pg) # pyre-ignore[6] + states = metric_module.get_pre_compute_states(pg=ctx.pg) torch.distributed.barrier(ctx.pg) # Compare to computing metrics on metric module that loads from pre_compute_states