Commit f08e025

sarahtranfb authored and facebook-github-bot committed

Add enable_cross_tensor_attribution flag to attribute_future

Summary: reserved

Differential Revision: D73464680

1 parent 5248929 commit f08e025

File tree

1 file changed (+113, -77 lines)

captum/attr/_core/feature_ablation.py (+113, -77)
@@ -729,6 +729,7 @@ def attribute_future(
         feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
         perturbations_per_eval: int = 1,
         show_progress: bool = False,
+        enable_cross_tensor_attribution: bool = False,
         **kwargs: Any,
     ) -> Future[TensorOrTupleOfTensorsGeneric]:
         r"""
@@ -743,17 +744,18 @@ def attribute_future(
         formatted_additional_forward_args = _format_additional_forward_args(
             additional_forward_args
         )
-        num_examples = formatted_inputs[0].shape[0]
         formatted_feature_mask = _format_feature_mask(feature_mask, formatted_inputs)
 
         assert (
             isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1
         ), "Perturbations per evaluation must be an integer and at least 1."
         with torch.no_grad():
+            attr_progress = None
             if show_progress:
                 attr_progress = self._attribute_progress_setup(
                     formatted_inputs,
                     formatted_feature_mask,
+                    enable_cross_tensor_attribution,
                     **kwargs,
                     perturbations_per_eval=perturbations_per_eval,
                 )
@@ -788,101 +790,135 @@ def attribute_future(
                 )
             )
 
-            # There will be the same number of futures as modified_eval below,
-            # since we cannot add up the evaluation results ad hoc under async mode.
-            all_modified_eval_futures: List[
-                List[Future[Tuple[List[Tensor], List[Tensor]]]]
-            ] = [[] for _ in range(len(inputs))]
-            # Iterate through each feature tensor for ablation
-            for i in range(len(formatted_inputs)):
-                # Skip any empty input tensors
-                if torch.numel(formatted_inputs[i]) == 0:
-                    continue
-
-                for (
-                    current_inputs,
-                    current_add_args,
-                    current_target,
-                    current_mask,
-                ) in self._ith_input_ablation_generator(
-                    i,
+            if enable_cross_tensor_attribution:
+                raise NotImplementedError("Not supported yet")
+            else:
+                # pyre-fixme[7]: Expected `Future[Variable[TensorOrTupleOfTensorsGeneric <:
+                # [Tensor, typing.Tuple[Tensor, ...]]]]` but got
+                # `Future[Union[Tensor, typing.Tuple[Tensor, ...]]]`
+                return self._attribute_with_independent_feature_masks_future(
                     formatted_inputs,
                     formatted_additional_forward_args,
                     target,
                     baselines,
                     formatted_feature_mask,
                     perturbations_per_eval,
+                    attr_progress,
+                    processed_initial_eval_fut,
+                    is_inputs_tuple,
                     **kwargs,
-                ):
-                    # modified_eval has (n_feature_perturbed * n_outputs) elements
-                    # shape:
-                    # agg mode: (*initial_eval.shape)
-                    # non-agg mode:
-                    # (feature_perturbed * batch_size, *initial_eval.shape[1:])
-                    modified_eval: Union[Tensor, Future[Tensor]] = _run_forward(
-                        self.forward_func,
-                        current_inputs,
-                        current_target,
-                        current_add_args,
-                    )
+                )
 
-                    if show_progress:
-                        attr_progress.update()
+    def _attribute_with_independent_feature_masks_future(
+        self,
+        formatted_inputs: Tuple[Tensor, ...],
+        formatted_additional_forward_args: Optional[Tuple[object, ...]],
+        target: TargetType,
+        baselines: BaselineType,
+        formatted_feature_mask: Tuple[Tensor, ...],
+        perturbations_per_eval: int,
+        attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
+        processed_initial_eval_fut: Future[
+            Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]
+        ],
+        is_inputs_tuple: bool,
+        **kwargs: Any,
+    ) -> Future[Tensor | Tuple[Tensor, ...]]:
+        num_examples = formatted_inputs[0].shape[0]
+        # There will be the same number of futures as modified_eval below,
+        # since we cannot add up the evaluation results ad hoc under async mode.
+        all_modified_eval_futures: List[
+            List[Future[Tuple[List[Tensor], List[Tensor]]]]
+        ] = [[] for _ in range(len(formatted_inputs))]
+        # Iterate through each feature tensor for ablation
+        for i in range(len(formatted_inputs)):
+            # Skip any empty input tensors
+            if torch.numel(formatted_inputs[i]) == 0:
+                continue
 
-                    if not isinstance(modified_eval, torch.Future):
-                        raise AssertionError(
-                            "when using attribute_future, modified_eval should have "
-                            f"Future type rather than {type(modified_eval)}"
-                        )
-                    if processed_initial_eval_fut is None:
-                        raise AssertionError(
-                            "processed_initial_eval_fut should not be None"
-                        )
+            for (
+                current_inputs,
+                current_add_args,
+                current_target,
+                current_mask,
+            ) in self._ith_input_ablation_generator(
+                i,
+                formatted_inputs,
+                formatted_additional_forward_args,
+                target,
+                baselines,
+                formatted_feature_mask,
+                perturbations_per_eval,
+                **kwargs,
+            ):
+                # modified_eval has (n_feature_perturbed * n_outputs) elements
+                # shape:
+                # agg mode: (*initial_eval.shape)
+                # non-agg mode:
+                # (feature_perturbed * batch_size, *initial_eval.shape[1:])
+                modified_eval: Union[Tensor, Future[Tensor]] = _run_forward(
+                    self.forward_func,
+                    current_inputs,
+                    current_target,
+                    current_add_args,
+                )
+
+                if attr_progress is not None:
+                    attr_progress.update()
+
+                if not isinstance(modified_eval, torch.Future):
+                    raise AssertionError(
+                        "when using attribute_future, modified_eval should have "
+                        f"Future type rather than {type(modified_eval)}"
+                    )
+                if processed_initial_eval_fut is None:
+                    raise AssertionError(
+                        "processed_initial_eval_fut should not be None"
+                    )
 
-                    # Need to collect both initial eval and modified_eval
-                    eval_futs: Future[
-                        List[
-                            Future[
-                                Union[
-                                    Tuple[
-                                        List[Tensor],
-                                        List[Tensor],
-                                        Tensor,
-                                        Tensor,
-                                        int,
-                                        dtype,
-                                    ],
-                                    Tensor,
-                                ]
-                            ]
-                        ]
-                    ] = collect_all(
-                        [
-                            processed_initial_eval_fut,
-                            modified_eval,
-                        ]
-                    )
+                # Need to collect both initial eval and modified_eval
+                eval_futs: Future[
+                    List[
+                        Future[
+                            Union[
+                                Tuple[
+                                    List[Tensor],
+                                    List[Tensor],
+                                    Tensor,
+                                    Tensor,
+                                    int,
+                                    dtype,
+                                ],
+                                Tensor,
+                            ]
+                        ]
+                    ]
+                ] = collect_all(
+                    [
+                        processed_initial_eval_fut,
+                        modified_eval,
+                    ]
+                )
 
-                    ablated_out_fut: Future[Tuple[List[Tensor], List[Tensor]]] = (
-                        eval_futs.then(
-                            lambda eval_futs, current_inputs=current_inputs, current_mask=current_mask, i=i: self._eval_fut_to_ablated_out_fut(  # type: ignore # noqa: E501 line too long
-                                eval_futs=eval_futs,
-                                current_inputs=current_inputs,
-                                current_mask=current_mask,
-                                i=i,
-                                perturbations_per_eval=perturbations_per_eval,
-                                num_examples=num_examples,
-                                formatted_inputs=formatted_inputs,
-                            )
-                        )
-                    )
+                ablated_out_fut: Future[Tuple[List[Tensor], List[Tensor]]] = (
+                    eval_futs.then(
+                        lambda eval_futs, current_inputs=current_inputs, current_mask=current_mask, i=i: self._eval_fut_to_ablated_out_fut(  # type: ignore # noqa: E501 line too long
+                            eval_futs=eval_futs,
+                            current_inputs=current_inputs,
+                            current_mask=current_mask,
+                            i=i,
+                            perturbations_per_eval=perturbations_per_eval,
+                            num_examples=num_examples,
+                            formatted_inputs=formatted_inputs,
+                        )
+                    )
+                )
 
-                    all_modified_eval_futures[i].append(ablated_out_fut)
+                all_modified_eval_futures[i].append(ablated_out_fut)
 
-            if show_progress:
-                attr_progress.close()
+        if attr_progress is not None:
+            attr_progress.close()
 
-        return self._generate_async_result(all_modified_eval_futures, is_inputs_tuple)  # type: ignore # noqa: E501 line too long
+        return self._generate_async_result(all_modified_eval_futures, is_inputs_tuple)  # type: ignore # noqa: E501 line too long
 
     # pyre-fixme[3] return type must be annotated
     def _attribute_progress_setup(
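For orientation, here is a minimal call-site sketch of the new flag (not part of this diff). It assumes a `FeatureAblation` instance whose forward returns a `torch.futures.Future`, which `attribute_future` asserts on; the model, inputs, and target below are hypothetical stand-ins.

```python
import torch
from captum.attr import FeatureAblation

model = torch.nn.Linear(4, 2)  # illustrative stand-in model

def forward_fut(x: torch.Tensor) -> torch.futures.Future:
    # attribute_future expects the forward to return a Future
    # (see the isinstance(modified_eval, torch.Future) assertion above).
    fut: torch.futures.Future = torch.futures.Future()
    fut.set_result(model(x))
    return fut

ablation = FeatureAblation(forward_fut)
inputs = torch.rand(3, 4)

# Default path: per-tensor feature masks, same behavior as before this commit.
attrs = ablation.attribute_future(inputs, target=0).wait()  # expected (3, 4) attributions

# New flag: the cross-tensor path is reserved for now and raises
# NotImplementedError("Not supported yet") when enabled.
try:
    ablation.attribute_future(inputs, target=0, enable_cross_tensor_attribution=True)
except NotImplementedError as err:
    print(err)  # Not supported yet
```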

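The moved loop rests on two torch.futures primitives: `collect_all`, which bundles a list of futures into a single `Future[List[Future]]`, and `Future.then`, which chains post-processing once everything resolves. Below is a self-contained sketch of that pairing pattern; `async_eval` and the tensors are illustrative stand-ins for `processed_initial_eval_fut` and one perturbed `modified_eval`.

```python
import torch
from torch.futures import Future, collect_all

def async_eval(x: torch.Tensor) -> Future:
    # Stand-in for _run_forward returning a Future under async mode.
    fut: Future = Future()
    fut.set_result(x.sum())
    return fut

initial_eval = async_eval(torch.ones(3))    # plays processed_initial_eval_fut
modified_eval = async_eval(torch.zeros(3))  # plays one perturbed evaluation

# collect_all completes once both inputs complete; .then post-processes the
# pair, mirroring the eval_futs.then(... _eval_fut_to_ablated_out_fut) chain.
diff_fut = collect_all([initial_eval, modified_eval]).then(
    lambda futs: futs.value()[0].value() - futs.value()[1].value()
)
print(diff_fut.wait())  # tensor(3.)
```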