Skip to content

Commit 7619498

Browse files
sarahtranfb authored and facebook-github-bot committed
Futures support with cross-tensor attribution 2/n
Summary: Copypasta, refactor in next diff Differential Revision: D73466780
1 parent 277bb33 commit 7619498

File tree

1 file changed

+308
-2
lines changed

1 file changed

+308
-2
lines changed

captum/attr/_core/feature_ablation.py

+308-2
Original file line numberDiff line numberDiff line change
@@ -791,7 +791,19 @@ def attribute_future(
791791
)
792792

793793
if enable_cross_tensor_attribution:
794-
raise NotImplementedError("Not supported yet")
794+
# pyre-fixme[7]: Expected`` Future[Variable[TensorOrTupleOfTensorsGeneric <:
795+
# [Tensor, typing.Tuple[Tensor, ...]]]]` but got `Future[Union[Tensor, typing.Tuple[Tensor, ...]]]`
796+
return self._attribute_with_cross_tensor_feature_masks_future(
797+
formatted_inputs=formatted_inputs,
798+
formatted_additional_forward_args=formatted_additional_forward_args,
799+
target=target,
800+
baselines=baselines,
801+
formatted_feature_mask=formatted_feature_mask,
802+
attr_progress=attr_progress,
803+
processed_initial_eval_fut=processed_initial_eval_fut,
804+
is_inputs_tuple=is_inputs_tuple,
805+
perturbations_per_eval=perturbations_per_eval,
806+
)
795807
else:
796808
# pyre-fixme[7]: Expected `Future[Variable[TensorOrTupleOfTensorsGeneric
797809
# <:[Tensor, typing.Tuple[Tensor, ...]]]]` but got
@@ -921,6 +933,213 @@ def _attribute_with_independent_feature_masks_future(
921933

922934
return self._generate_async_result(all_modified_eval_futures, is_inputs_tuple) # type: ignore # noqa: E501 line too long
923935

936+
def _attribute_with_cross_tensor_feature_masks_future(
    self,
    formatted_inputs: Tuple[Tensor, ...],
    formatted_additional_forward_args: Optional[Tuple[object, ...]],
    target: TargetType,
    baselines: BaselineType,
    formatted_feature_mask: Tuple[Tensor, ...],
    attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
    processed_initial_eval_fut: Future[
        Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]
    ],
    is_inputs_tuple: bool,
    perturbations_per_eval: int,
    **kwargs: Any,
) -> Future[Union[Tensor, Tuple[Tensor, ...]]]:
    """Asynchronous feature-ablation attribution where one feature id in the
    mask may span multiple input tensors.

    For each group of feature ids (up to ``perturbations_per_eval`` per
    group) this ablates the corresponding regions across all affected
    tensors, launches the forward pass via ``_run_forward`` (which must
    return a ``torch.Future``), and chains the per-group post-processing
    onto that future instead of blocking on it.

    Args:
        formatted_inputs: Tuple of input tensors. All tensors are assumed to
            share the 0th (example) dimension of ``formatted_inputs[0]`` —
            TODO confirm against callers.
        formatted_additional_forward_args: Extra args forwarded (expanded to
            the perturbation batch) to ``forward_func``, or None.
        target: Output target(s); expanded alongside the inputs.
        baselines: Scalar/tensor baselines used as replacement values.
        formatted_feature_mask: One mask tensor per input tensor; equal mask
            values across tensors denote one logical feature.
        attr_progress: Optional progress bar, updated once per forward call.
        processed_initial_eval_fut: Future resolving to the tuple
            (total_attrib, weights, initial_eval, flattened_initial_eval,
            n_outputs, attrib_type) from the unablated evaluation.
        is_inputs_tuple: Whether the caller passed a tuple (controls the
            shape of the final result).
        perturbations_per_eval: Number of feature groups ablated per forward
            pass batch.
        **kwargs: Unused here; accepted for interface compatibility.

    Returns:
        A future resolving to the attributions (tensor or tuple of tensors).

    Raises:
        AssertionError: if ``_run_forward`` does not return a future.
    """
    # Map each feature id to the indices of the input tensors whose mask
    # contains that id.
    feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
    for i, mask in enumerate(formatted_feature_mask):
        for feature_idx in torch.unique(mask):
            if feature_idx.item() not in feature_idx_to_tensor_idx:
                feature_idx_to_tensor_idx[feature_idx.item()] = []
            feature_idx_to_tensor_idx[feature_idx.item()].append(i)
    all_feature_idxs = list(feature_idx_to_tensor_idx.keys())

    additional_args_repeated: object
    if perturbations_per_eval > 1:
        # Repeat features and additional args for batch size.
        all_features_repeated = tuple(
            torch.cat([formatted_inputs[j]] * perturbations_per_eval, dim=0)
            for j in range(len(formatted_inputs))
        )
        additional_args_repeated = (
            _expand_additional_forward_args(
                formatted_additional_forward_args, perturbations_per_eval
            )
            if formatted_additional_forward_args is not None
            else None
        )
        target_repeated = _expand_target(target, perturbations_per_eval)
    else:
        all_features_repeated = formatted_inputs
        additional_args_repeated = formatted_additional_forward_args
        target_repeated = target
    num_examples = formatted_inputs[0].shape[0]

    current_additional_args: object
    if isinstance(baselines, tuple):
        # Prepend a broadcastable batch dimension of 1 to tensor baselines
        # so they align with the (expanded) inputs.
        reshaped = False
        reshaped_baselines: list[Union[Tensor, int, float]] = []
        for baseline in baselines:
            if isinstance(baseline, Tensor):
                reshaped = True
                reshaped_baselines.append(
                    baseline.reshape((1,) + tuple(baseline.shape))
                )
            else:
                reshaped_baselines.append(baseline)
        baselines = tuple(reshaped_baselines) if reshaped else baselines

    all_modified_eval_futures: List[Future[Tuple[List[Tensor], List[Tensor]]]] = []
    for i in range(0, len(all_feature_idxs), perturbations_per_eval):
        current_feature_idxs = all_feature_idxs[i : i + perturbations_per_eval]
        # The final group may be smaller than perturbations_per_eval.
        current_num_ablated_features = min(
            perturbations_per_eval, len(current_feature_idxs)
        )

        # Skip the group if every touched tensor is empty, or if any touched
        # tensor has fewer examples than the configured minimum.
        should_skip = False
        all_empty = True
        tensor_idx_list = []
        for feature_idx in current_feature_idxs:
            tensor_idx_list += feature_idx_to_tensor_idx[feature_idx]
        for tensor_idx in set(tensor_idx_list):
            if all_empty and torch.numel(formatted_inputs[tensor_idx]) != 0:
                all_empty = False
            if self._min_examples_per_batch_grouped is not None and (
                formatted_inputs[tensor_idx].shape[0]
                # pyre-ignore[58]: Type has been narrowed to int
                < self._min_examples_per_batch_grouped
            ):
                should_skip = True
                break
        if all_empty:
            logger.info(
                f"Skipping feature group {current_feature_idxs} since all "
                f"input tensors are empty"
            )
            continue

        if should_skip:
            logger.warning(
                f"Skipping feature group {current_feature_idxs} since it contains "
                f"at least one input tensor with 0th dim less than "
                f"{self._min_examples_per_batch_grouped}"
            )
            continue

        # Store appropriate inputs and additional args based on batch size.
        if current_num_ablated_features != perturbations_per_eval:
            current_additional_args = (
                _expand_additional_forward_args(
                    formatted_additional_forward_args, current_num_ablated_features
                )
                if formatted_additional_forward_args is not None
                else None
            )
            current_target = _expand_target(target, current_num_ablated_features)
            expanded_inputs = tuple(
                feature_repeated[0 : current_num_ablated_features * num_examples]
                for feature_repeated in all_features_repeated
            )
        else:
            current_additional_args = additional_args_repeated
            current_target = target_repeated
            expanded_inputs = all_features_repeated

        current_inputs, current_masks = (
            self._construct_ablated_input_across_tensors(
                expanded_inputs,
                formatted_feature_mask,
                baselines,
                current_feature_idxs,
                feature_idx_to_tensor_idx,
                current_num_ablated_features,
            )
        )

        # modified_eval has (n_feature_perturbed * n_outputs) elements
        # shape:
        # agg mode: (*initial_eval.shape)
        # non-agg mode:
        # (feature_perturbed * batch_size, *initial_eval.shape[1:])
        modified_eval = _run_forward(
            self.forward_func,
            current_inputs,
            current_target,
            current_additional_args,
        )

        if attr_progress is not None:
            attr_progress.update()

        if not isinstance(modified_eval, torch.Future):
            raise AssertionError(
                "when using attribute_future, modified_eval should have "
                f"Future type rather than {type(modified_eval)}"
            )

        # Need to collect both initial eval and modified_eval
        eval_futs: Future[
            List[
                Future[
                    Union[
                        Tuple[
                            List[Tensor],
                            List[Tensor],
                            Tensor,
                            Tensor,
                            int,
                            dtype,
                        ],
                        Tensor,
                    ]
                ]
            ]
        ] = collect_all(
            [
                processed_initial_eval_fut,
                modified_eval,
            ]
        )

        # Loop variables are bound as lambda defaults so each callback sees
        # its own group's values (avoids Python's late-binding closures).
        ablated_out_fut: Future[Tuple[List[Tensor], List[Tensor]]] = eval_futs.then(
            lambda eval_futs, current_inputs=current_inputs, current_mask=current_masks, i=i: self._eval_fut_to_ablated_out_fut_cross_tensor(
                eval_futs=eval_futs,
                current_inputs=current_inputs,
                current_mask=current_mask,
                perturbations_per_eval=perturbations_per_eval,
                num_examples=num_examples,
            )
        )

        all_modified_eval_futures.append(ablated_out_fut)

    if attr_progress is not None:
        attr_progress.close()

    return self._generate_async_result_cross_tensor(
        all_modified_eval_futures,
        is_inputs_tuple,
    )
1126+
def _fut_tuple_to_accumulate_fut_list_cross_tensor(
    self,
    total_attrib: List[Tensor],
    weights: List[Tensor],
    fut_tuple: Future[Tuple[List[Tensor], List[Tensor]]],
) -> None:
    """Copy the resolved (attributions, weights) payload of ``fut_tuple``
    into the caller-owned accumulator lists, in place.

    The ablated-output processing already accumulates the running totals,
    so the accumulators are simply overwritten with the latest value.

    Raises:
        FeatureAblationFutureError: re-raised with added context if reading
            the future's value fails with the same error type.
    """
    try:
        new_attribs, new_weights = fut_tuple.value()
        total_attrib[:] = new_attribs
        weights[:] = new_weights
    except FeatureAblationFutureError as e:
        raise FeatureAblationFutureError(
            "_fut_tuple_to_accumulate_fut_list_cross_tensor failed"
        ) from e
1142+
9241143
# pyre-fixme[3] return type must be annotated
9251144
def _attribute_progress_setup(
9261145
self,
@@ -950,7 +1169,6 @@ def _attribute_progress_setup(
9501169

9511170
def _eval_fut_to_ablated_out_fut(
9521171
self,
953-
# pyre-ignore Invalid type parameters [24]
9541172
eval_futs: Future[List[Future[List[object]]]],
9551173
current_inputs: Tuple[Tensor, ...],
9561174
current_mask: Tensor,
@@ -1012,6 +1230,94 @@ def _eval_fut_to_ablated_out_fut(
10121230
) from e
10131231
return result
10141232

1233+
def _generate_async_result_cross_tensor(
    self,
    futs: List[Future[Tuple[List[Tensor], List[Tensor]]]],
    is_inputs_tuple: bool,
) -> Future[Union[Tensor, Tuple[Tensor, ...]]]:
    """Chain all per-group futures into one future of final attributions.

    Each future in ``futs`` gets a continuation that writes its resolved
    (attributions, weights) into the shared accumulators; once every
    continuation has run, the accumulated state is turned into the final
    result via ``_generate_result``.
    """
    total_attrib: List[Tensor] = []
    weights: List[Tensor] = []

    # One continuation per group future; all of them mutate the shared
    # accumulator lists above.
    accumulate_fut_list: List[Future[None]] = [
        fut_tuple.then(
            lambda fut_tuple: self._fut_tuple_to_accumulate_fut_list_cross_tensor(
                total_attrib, weights, fut_tuple
            )
        )
        for fut_tuple in futs
    ]

    # Resolve to the final attribution once every accumulation is done.
    return collect_all(accumulate_fut_list).then(
        lambda _: self._generate_result(
            total_attrib,
            weights,
            is_inputs_tuple,
        )
    )
1260+
1261+
def _eval_fut_to_ablated_out_fut_cross_tensor(
    self,
    eval_futs: Future[List[Future[List[object]]]],
    current_inputs: Tuple[Tensor, ...],
    current_mask: Tuple[Optional[Tensor], ...],
    perturbations_per_eval: int,
    num_examples: int,
) -> Tuple[List[Tensor], List[Tensor]]:
    """Unpack the collected (initial-eval, modified-eval) futures for one
    ablated feature group and fold the group's result into the running
    attribution via ``_process_ablated_out_full``.

    ``eval_futs`` is the ``collect_all`` of two futures: index 0 resolves to
    the 6-tuple produced from the unablated evaluation, index 1 to the
    modified (ablated) forward output tensor.

    Returns:
        The updated ``(total_attrib, weights)`` lists.

    Raises:
        AssertionError: if the initial-eval payload is not a 6-tuple or the
            modified eval is not a Tensor.
        FeatureAblationFutureError: re-raised with added context.
    """
    try:
        # Index 1 is this group's forward output; index 0 the shared
        # initial-eval tuple. Casts only narrow types for the checker.
        modified_eval = cast(Tensor, eval_futs.value()[1].value())
        initial_eval_tuple = cast(
            Tuple[
                List[Tensor],
                List[Tensor],
                Tensor,
                Tensor,
                int,
                dtype,
            ],
            eval_futs.value()[0].value(),
        )
        if len(initial_eval_tuple) != 6:
            raise AssertionError(
                "eval_fut_to_ablated_out_fut_cross_tensor: "
                "initial_eval_tuple should have 6 elements: "
                "total_attrib, weights, initial_eval, "
                "flattened_initial_eval, n_outputs, attrib_type "
            )
        if not isinstance(modified_eval, Tensor):
            raise AssertionError(
                "_eval_fut_to_ablated_out_fut_cross_tensor: "
                "modified eval should be a Tensor"
            )
        (
            total_attrib,
            weights,
            initial_eval,
            flattened_initial_eval,
            n_outputs,
            attrib_type,
        ) = initial_eval_tuple
        # Fold this group's modified eval into the running totals.
        total_attrib, weights = self._process_ablated_out_full(
            modified_eval=modified_eval,
            inputs=current_inputs,
            current_mask=current_mask,
            perturbations_per_eval=perturbations_per_eval,
            num_examples=num_examples,
            initial_eval=initial_eval,
            flattened_initial_eval=flattened_initial_eval,
            n_outputs=n_outputs,
            total_attrib=total_attrib,
            weights=weights,
            attrib_type=attrib_type,
        )
    except FeatureAblationFutureError as e:
        raise FeatureAblationFutureError(
            "_eval_fut_to_ablated_out_fut_cross_tensor func failed"
        ) from e
    return total_attrib, weights
1320+
10151321
def _ith_input_ablation_generator(
10161322
self,
10171323
i: int,

0 commit comments

Comments
 (0)