@@ -791,7 +791,19 @@ def attribute_future(
791
791
)
792
792
793
793
if enable_cross_tensor_attribution :
794
- raise NotImplementedError ("Not supported yet" )
794
+ # pyre-fixme[7]: Expected `Future[Variable[TensorOrTupleOfTensorsGeneric <:
795
+ # [Tensor, typing.Tuple[Tensor, ...]]]]` but got `Future[Union[Tensor, typing.Tuple[Tensor, ...]]]`
796
+ return self ._attribute_with_cross_tensor_feature_masks_future (
797
+ formatted_inputs = formatted_inputs ,
798
+ formatted_additional_forward_args = formatted_additional_forward_args ,
799
+ target = target ,
800
+ baselines = baselines ,
801
+ formatted_feature_mask = formatted_feature_mask ,
802
+ attr_progress = attr_progress ,
803
+ processed_initial_eval_fut = processed_initial_eval_fut ,
804
+ is_inputs_tuple = is_inputs_tuple ,
805
+ perturbations_per_eval = perturbations_per_eval ,
806
+ )
795
807
else :
796
808
# pyre-fixme[7]: Expected `Future[Variable[TensorOrTupleOfTensorsGeneric
797
809
# <:[Tensor, typing.Tuple[Tensor, ...]]]]` but got
@@ -921,6 +933,213 @@ def _attribute_with_independent_feature_masks_future(
921
933
922
934
return self ._generate_async_result (all_modified_eval_futures , is_inputs_tuple ) # type: ignore # noqa: E501 line too long
923
935
936
    def _attribute_with_cross_tensor_feature_masks_future(
        self,
        formatted_inputs: Tuple[Tensor, ...],
        formatted_additional_forward_args: Optional[Tuple[object, ...]],
        target: TargetType,
        baselines: BaselineType,
        formatted_feature_mask: Tuple[Tensor, ...],
        attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
        processed_initial_eval_fut: Future[
            Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]
        ],
        is_inputs_tuple: bool,
        perturbations_per_eval: int,
        **kwargs: Any,
    ) -> Future[Union[Tensor, Tuple[Tensor, ...]]]:
        """Asynchronously compute ablation attributions when one feature id may
        span several input tensors (cross-tensor feature masks).

        Feature ids are collected from ``formatted_feature_mask``; for each
        group of up to ``perturbations_per_eval`` feature ids, the inputs are
        ablated across every tensor containing those ids, the forward function
        is invoked (it must return a ``torch.Future``), and the result is
        chained with ``processed_initial_eval_fut`` for post-processing. The
        per-group futures are finally combined into a single future of the
        attribution tensor(s).

        Groups are skipped (with a log message) when every involved input
        tensor is empty, or when any involved tensor's 0th dim is smaller than
        ``self._min_examples_per_batch_grouped`` (when that is set).
        """
        # Map each feature id appearing in any mask to the indices of the
        # input tensors whose mask contains it.
        feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
        for i, mask in enumerate(formatted_feature_mask):
            for feature_idx in torch.unique(mask):
                if feature_idx.item() not in feature_idx_to_tensor_idx:
                    feature_idx_to_tensor_idx[feature_idx.item()] = []
                feature_idx_to_tensor_idx[feature_idx.item()].append(i)
        all_feature_idxs = list(feature_idx_to_tensor_idx.keys())

        additional_args_repeated: object
        if perturbations_per_eval > 1:
            # Repeat features and additional args for batch size.
            all_features_repeated = tuple(
                torch.cat([formatted_inputs[j]] * perturbations_per_eval, dim=0)
                for j in range(len(formatted_inputs))
            )
            additional_args_repeated = (
                _expand_additional_forward_args(
                    formatted_additional_forward_args, perturbations_per_eval
                )
                if formatted_additional_forward_args is not None
                else None
            )
            target_repeated = _expand_target(target, perturbations_per_eval)
        else:
            all_features_repeated = formatted_inputs
            additional_args_repeated = formatted_additional_forward_args
            target_repeated = target
        # assumes all input tensors share the same 0th (batch) dim -- TODO confirm
        num_examples = formatted_inputs[0].shape[0]

        current_additional_args: object
        if isinstance(baselines, tuple):
            # Prepend a length-1 leading dim to tensor baselines so they
            # broadcast against the (possibly repeated) inputs.
            reshaped = False
            reshaped_baselines: list[Union[Tensor, int, float]] = []
            for baseline in baselines:
                if isinstance(baseline, Tensor):
                    reshaped = True
                    reshaped_baselines.append(
                        baseline.reshape((1,) + tuple(baseline.shape))
                    )
                else:
                    reshaped_baselines.append(baseline)
            baselines = tuple(reshaped_baselines) if reshaped else baselines

        all_modified_eval_futures: List[Future[Tuple[List[Tensor], List[Tensor]]]] = []
        for i in range(0, len(all_feature_idxs), perturbations_per_eval):
            current_feature_idxs = all_feature_idxs[i : i + perturbations_per_eval]
            # The last group may be smaller than perturbations_per_eval.
            current_num_ablated_features = min(
                perturbations_per_eval, len(current_feature_idxs)
            )

            # Decide whether this whole group must be skipped.
            should_skip = False
            all_empty = True
            tensor_idx_list = []
            for feature_idx in current_feature_idxs:
                tensor_idx_list += feature_idx_to_tensor_idx[feature_idx]
            for tensor_idx in set(tensor_idx_list):
                if all_empty and torch.numel(formatted_inputs[tensor_idx]) != 0:
                    all_empty = False
                if self._min_examples_per_batch_grouped is not None and (
                    formatted_inputs[tensor_idx].shape[0]
                    # pyre-ignore[58]: Type has been narrowed to int
                    < self._min_examples_per_batch_grouped
                ):
                    should_skip = True
                    break
            if all_empty:
                logger.info(
                    f"Skipping feature group {current_feature_idxs} since all "
                    f"input tensors are empty"
                )
                continue

            if should_skip:
                logger.warning(
                    f"Skipping feature group {current_feature_idxs} since it contains "
                    f"at least one input tensor with 0th dim less than "
                    f"{self._min_examples_per_batch_grouped}"
                )
                continue

            # Store appropriate inputs and additional args based on batch size.
            if current_num_ablated_features != perturbations_per_eval:
                current_additional_args = (
                    _expand_additional_forward_args(
                        formatted_additional_forward_args, current_num_ablated_features
                    )
                    if formatted_additional_forward_args is not None
                    else None
                )
                current_target = _expand_target(target, current_num_ablated_features)
                # Trim the pre-repeated inputs down to this group's size.
                expanded_inputs = tuple(
                    feature_repeated[0 : current_num_ablated_features * num_examples]
                    for feature_repeated in all_features_repeated
                )
            else:
                current_additional_args = additional_args_repeated
                current_target = target_repeated
                expanded_inputs = all_features_repeated

            current_inputs, current_masks = (
                self._construct_ablated_input_across_tensors(
                    expanded_inputs,
                    formatted_feature_mask,
                    baselines,
                    current_feature_idxs,
                    feature_idx_to_tensor_idx,
                    current_num_ablated_features,
                )
            )

            # modified_eval has (n_feature_perturbed * n_outputs) elements
            # shape:
            # agg mode: (*initial_eval.shape)
            # non-agg mode:
            # (feature_perturbed * batch_size, *initial_eval.shape[1:])
            modified_eval = _run_forward(
                self.forward_func,
                current_inputs,
                current_target,
                current_additional_args,
            )

            if attr_progress is not None:
                attr_progress.update()

            if not isinstance(modified_eval, torch.Future):
                raise AssertionError(
                    "when using attribute_future, modified_eval should have "
                    f"Future type rather than {type(modified_eval)}"
                )

            # Need to collect both initial eval and modified_eval
            eval_futs: Future[
                List[
                    Future[
                        Union[
                            Tuple[
                                List[Tensor],
                                List[Tensor],
                                Tensor,
                                Tensor,
                                int,
                                dtype,
                            ],
                            Tensor,
                        ]
                    ]
                ]
            ] = collect_all(
                [
                    processed_initial_eval_fut,
                    modified_eval,
                ]
            )

            # Default-argument binding freezes the current loop values for the
            # callback (avoids the late-binding closure pitfall).
            ablated_out_fut: Future[Tuple[List[Tensor], List[Tensor]]] = eval_futs.then(
                lambda eval_futs, current_inputs=current_inputs, current_mask=current_masks, i=i: self._eval_fut_to_ablated_out_fut_cross_tensor(
                    eval_futs=eval_futs,
                    current_inputs=current_inputs,
                    current_mask=current_mask,
                    perturbations_per_eval=perturbations_per_eval,
                    num_examples=num_examples,
                )
            )

            all_modified_eval_futures.append(ablated_out_fut)

        if attr_progress is not None:
            attr_progress.close()

        return self._generate_async_result_cross_tensor(
            all_modified_eval_futures,
            is_inputs_tuple,
        )
1125
+
1126
+ def _fut_tuple_to_accumulate_fut_list_cross_tensor (
1127
+ self ,
1128
+ total_attrib : List [Tensor ],
1129
+ weights : List [Tensor ],
1130
+ fut_tuple : Future [Tuple [List [Tensor ], List [Tensor ]]],
1131
+ ) -> None :
1132
+ try :
1133
+ # process_ablated_out_* already accumlates the total attribution.
1134
+ # Just get the latest value
1135
+ attribs , this_weights = fut_tuple .value ()
1136
+ total_attrib [:] = attribs
1137
+ weights [:] = this_weights
1138
+ except FeatureAblationFutureError as e :
1139
+ raise FeatureAblationFutureError (
1140
+ "_fut_tuple_to_accumulate_fut_list_cross_tensor failed"
1141
+ ) from e
1142
+
924
1143
# pyre-fixme[3] return type must be annotated
925
1144
def _attribute_progress_setup (
926
1145
self ,
@@ -950,7 +1169,6 @@ def _attribute_progress_setup(
950
1169
951
1170
def _eval_fut_to_ablated_out_fut (
952
1171
self ,
953
- # pyre-ignore Invalid type parameters [24]
954
1172
eval_futs : Future [List [Future [List [object ]]]],
955
1173
current_inputs : Tuple [Tensor , ...],
956
1174
current_mask : Tensor ,
@@ -1012,6 +1230,94 @@ def _eval_fut_to_ablated_out_fut(
1012
1230
) from e
1013
1231
return result
1014
1232
1233
+ def _generate_async_result_cross_tensor (
1234
+ self ,
1235
+ futs : List [Future [Tuple [List [Tensor ], List [Tensor ]]]],
1236
+ is_inputs_tuple : bool ,
1237
+ ) -> Future [Union [Tensor , Tuple [Tensor , ...]]]:
1238
+ accumulate_fut_list : List [Future [None ]] = []
1239
+ total_attrib : List [Tensor ] = []
1240
+ weights : List [Tensor ] = []
1241
+
1242
+ for fut_tuple in futs :
1243
+ accumulate_fut_list .append (
1244
+ fut_tuple .then (
1245
+ lambda fut_tuple : self ._fut_tuple_to_accumulate_fut_list_cross_tensor (
1246
+ total_attrib , weights , fut_tuple
1247
+ )
1248
+ )
1249
+ )
1250
+
1251
+ result_fut = collect_all (accumulate_fut_list ).then (
1252
+ lambda x : self ._generate_result (
1253
+ total_attrib ,
1254
+ weights ,
1255
+ is_inputs_tuple ,
1256
+ )
1257
+ )
1258
+
1259
+ return result_fut
1260
+
1261
    def _eval_fut_to_ablated_out_fut_cross_tensor(
        self,
        eval_futs: Future[List[Future[List[object]]]],
        current_inputs: Tuple[Tensor, ...],
        current_mask: Tuple[Optional[Tensor], ...],
        perturbations_per_eval: int,
        num_examples: int,
    ) -> Tuple[List[Tensor], List[Tensor]]:
        """Turn a completed (initial eval, modified eval) future pair into
        updated attribution accumulators.

        ``eval_futs`` is the ``collect_all`` of ``processed_initial_eval_fut``
        (index 0: the 6-tuple of total_attrib, weights, initial_eval,
        flattened_initial_eval, n_outputs, attrib_type) and the forward pass's
        modified-eval future (index 1: a Tensor). Both results are delegated
        to ``self._process_ablated_out_full`` and the updated
        (total_attrib, weights) lists are returned.

        Raises:
            AssertionError: if the initial-eval payload does not have 6
                elements or the modified eval is not a Tensor.
            FeatureAblationFutureError: wrapping any upstream future failure.
        """
        try:
            # Casts only narrow the static type; runtime shape is checked below.
            modified_eval = cast(Tensor, eval_futs.value()[1].value())
            initial_eval_tuple = cast(
                Tuple[
                    List[Tensor],
                    List[Tensor],
                    Tensor,
                    Tensor,
                    int,
                    dtype,
                ],
                eval_futs.value()[0].value(),
            )
            if len(initial_eval_tuple) != 6:
                raise AssertionError(
                    "eval_fut_to_ablated_out_fut_cross_tensor: "
                    "initial_eval_tuple should have 6 elements: "
                    "total_attrib, weights, initial_eval, "
                    "flattened_initial_eval, n_outputs, attrib_type "
                )
            if not isinstance(modified_eval, Tensor):
                raise AssertionError(
                    "_eval_fut_to_ablated_out_fut_cross_tensor: "
                    "modified eval should be a Tensor"
                )
            (
                total_attrib,
                weights,
                initial_eval,
                flattened_initial_eval,
                n_outputs,
                attrib_type,
            ) = initial_eval_tuple
            # Delegate the actual attribution arithmetic to the shared helper.
            total_attrib, weights = self._process_ablated_out_full(
                modified_eval=modified_eval,
                inputs=current_inputs,
                current_mask=current_mask,
                perturbations_per_eval=perturbations_per_eval,
                num_examples=num_examples,
                initial_eval=initial_eval,
                flattened_initial_eval=flattened_initial_eval,
                n_outputs=n_outputs,
                total_attrib=total_attrib,
                weights=weights,
                attrib_type=attrib_type,
            )
        except FeatureAblationFutureError as e:
            raise FeatureAblationFutureError(
                "_eval_fut_to_ablated_out_fut_cross_tensor func failed"
            ) from e
        return total_attrib, weights
1320
+
1015
1321
def _ith_input_ablation_generator (
1016
1322
self ,
1017
1323
i : int ,
0 commit comments