Skip to content

Commit f2791cd

Browse files
ge0405facebook-github-bot
authored andcommitted
Fix backward compatibility for rec metric fuse states issue (#2894)
Summary: Pull Request resolved: #2894 `Metric` seems to be wrapped by packages. In some tests, `rec_metric` may import the old version `Metric` that doesn't have `fuse_tensor_states` attribute. This caused breakage: https://www.internalfb.com/diff/D72010614?dst_version_fbid=1727492137979220&transaction_fbid=1212653170466199 in D72010614. So this diff fixes that by examining the signature of `Metric` and only input `fuse_tensor_states` when it's there. Reviewed By: iamzainhuda Differential Revision: D73217574 fbshipit-source-id: 253fc6ac304baaa57ab8f6269b7b4f5845515126
1 parent 8819393 commit f2791cd

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

torchrec/metrics/rec_metric.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#!/usr/bin/env python3
1111

1212
import abc
13+
import inspect
1314
import itertools
1415
import math
1516
from collections import defaultdict, deque
@@ -141,9 +142,11 @@ def __init__(
141142
*args: Any,
142143
**kwargs: Any,
143144
) -> None:
145+
metric_init_signature = inspect.signature(Metric.__init__)
146+
if "fuse_state_tensors" in metric_init_signature.parameters:
147+
kwargs["fuse_state_tensors"] = fuse_state_tensors
144148
super().__init__(
145149
process_group=process_group,
146-
fuse_state_tensors=fuse_state_tensors,
147150
*args,
148151
**kwargs,
149152
)

0 commit comments

Comments
 (0)