
Commit 8af6e8b

committed
First commit.
0 parents  commit 8af6e8b

272 files changed, +296342 −0 lines changed

README.md

# Time Can Invalidate Algorithmic Recourse

This is the source code for the paper "Time Can Invalidate Algorithmic Recourse". The implementation builds on the code of Dominguez-Olmedo et al. (2022), available [here](https://github.com/RicardoDominguez/AdversariallyRobustRecourse).

## Structure

The repository is organized into the following directories:

- `analytics/`: scripts to analyze the results and generate the figures of the paper.
- `data/`: the data for the experiments and the corresponding pre-trained models.
- `learned_scms/`: the pre-trained approximate SCMs for the experiments with realistic data (see Section 4.2 of the paper).
- `results/`: all the raw results from the experiments.
- `src/`: the core code, e.g., the SCMs and the implementations of TSAR, CAR, SAR, and IMF.
- `testing/`: unit tests checking some basic implementation details.

## Install

We report here the instructions to set up a `conda` environment with Python 3.10 and to install the packages needed to run the experiments. In principle, it should not matter which environment manager is used.

```bash
# Create a suitable environment with Python 3.10
conda create --name temporal-recourse python=3.10
conda activate temporal-recourse

# Install the required libraries into the environment
conda install pytorch torchvision torchaudio cpuonly -c pytorch
pip install numpy scipy matplotlib scikit-learn pandas seaborn tqdm cvxpy mpi4py
```

Once the environment is set up, we suggest running the unit tests to ensure everything is in order:

```bash
python -m unittest testing/test_temporal_causal_models.py
```
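
To run every test in `testing/` at once, standard `unittest` discovery should also work; this is a hypothetical convenience command, assuming the test files follow the default `test_*.py` naming convention and that the command is issued from the repository root:

```bash
# Discover and run every test module under testing/ (assumes test_*.py naming)
python -m unittest discover -s testing -t .
```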

## Reproduce the experiments

We prepared several bash scripts that automatically run all the experiments presented in the main paper.
Beware that they might take some time, since everything runs exclusively on CPU.
Before running any of these scripts, make sure you are in the parent directory and have run:

```bash
cd temporal-recourse-submission
conda activate temporal-recourse
export PYTHONPATH=.
```

The following will run the experiments detailed in Section 4.1:

```bash
bash run_proposition_1.sh
bash run_task_1.sh
```

The following will instead run the experiments detailed in Section 4.2 and Appendix D, respectively:

```bash
bash run_task_2.sh
bash run_task_2_ground_truth.sh
```

Lastly, the following regenerates the plot in Figure 2:

```bash
bash run_alpha.sh
```

### Re-train the generative models

We also provide scripts to retrain the generative models used in Section 4.2 of the main paper. For more details, please have a look at Appendix B. In practice, this step is not needed, since we provide the pre-trained models in `learned_scms/`.

```bash
bash run_train_scms_syn.sh
bash run_train_scms.sh
```

## Re-generate the plots

Given the raw result files in `results/`, the following scripts and bash commands generate all the figures of the main paper.

```bash
# Generate Figure 2
python analytics/plot_alphas.py results/task_0_alpha/*.csv

# Generate Figure 3
python analytics/plot_test_1.py results/task_1_uncertainty/dnn_*.csv

# Generate Figure 4
python analytics/plot_trends_syn.py results/task_2_syn_data/*.csv

# Generate Figure 5
python analytics/plot_trends.py results/appendix/task_3_real_linear/*.csv
bash plots/analysis_actions.sh

# Generate plots in Appendix C
bash plots/cost_actions.sh
bash plots/sparsity_actions.sh

# Generate plots in Appendix D
python analytics/plot_trends.py results/task_3_real_data_ground_truth/*.csv
```

analytics/plot_actions.py

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import pickle

import argparse
import os

from src.utils import PALETTE

sns.set(font_scale=1.5,
        style="ticks",
        rc={
            "text.usetex": True,
            'text.latex.preamble': r'\usepackage{amsfonts} \usepackage{amsmath} \usepackage{bm}',
            "font.family": "serif",
        })

# Map internal method identifiers to the names used in the paper
technique_names = {
    "CFR": r"\texttt{CAR}",
    "SPR": r"\texttt{SAR}",
    "IMF": r"\texttt{IMF}",
    "robust_time": r"\texttt{T-SAR}"
}

# Indices of the actionable feature sets for each SCM
groups = {
    "Loan": [[5], [6], [5, 6]],
    "COMPAS": [[3]],
    "Adult": [[4], [5], [4, 5]],
    "Linear ANM": [[0], [1], [2], [0, 1], [0, 2], [1, 2], [0, 1, 2]],
    "Non-Linear ANM": [[0], [1], [2], [0, 1], [0, 2], [1, 2], [0, 1, 2]]
}

rename_groups = {
    "Loan": {
        "[5]": r"$\{income\}$",
        "[6]": r"$\{savings\}$",
        "[5, 6]": r"$\{income, savings\}$",
    }
}


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("file", nargs="+", type=str,
                        help="result file(s): pickled DataFrames, possibly with a .csv extension")
    args = parser.parse_args()

    # Where to accumulate all the data
    all_data = []
    all_dfs = []

    for f in args.file:

        with open(f, "rb") as handle:
            data = pickle.load(handle)

        # Result files are named <experiment>_<classifier>_<scm>_<trend>_*_<alpha>_<runs>_<mc_samples>_...
        experiment, classifier, scm, trend, _, alpha, runs, mc_samples = os.path.basename(f).split("_")[0:8]

        if scm == "non-linear":
            data["scm"] = "Non-Linear ANM"
        elif scm == "linear":
            data["scm"] = "Linear ANM"
        elif scm == "adult":
            data["scm"] = "Adult"
        elif scm == "compas":
            data["scm"] = "COMPAS"
        else:
            data["scm"] = "Loan"

        data["classifier"] = "DNN" if classifier == "dnn" else "Logistic"

        if "alpha" not in data.columns:
            data["alpha"] = float(alpha)

        data.rename(columns={"alpha": r"$\alpha$"}, inplace=True)

        # Skip IMF, since it acts on all features anyway
        data = data[data.type != "IMF (robust)"]

        data["type"] = data["type"].apply(lambda x: technique_names.get(x.replace(" (robust)", "")) if x != "robust_time" else r"\texttt{T-SAR}")
        data["type"] = data["type"].apply(lambda x: x + r" ($\epsilon = 0.5$)" if (x != r"\texttt{T-SAR}" and experiment == "only-robust-2") else x)
        data["type"] = data["type"].apply(lambda x: x + r" ($\epsilon = 0.05$)" if (x != r"\texttt{T-SAR}" and experiment == "only-robust") else x)

        all_dfs.append(data)

    data_original = pd.concat(all_dfs)

    for t in [0, 10, 30, 50]:

        for scm_original in data_original.scm.unique():

            data = data_original[(data_original.timestep == t) & (data_original.scm == scm_original)]

            # Get the groups of actionable features for this SCM
            actionable_groups = groups.get(scm_original)

            # Rank methods by mean validity (the top-3 filter is currently disabled)
            mean_of_means = data.groupby(["run_id", "type"])["validity"].mean()
            top_3_methods = mean_of_means.groupby("type").mean().sort_values(ascending=False)  # .head(3)
            top_3_methods = top_3_methods.index
            filtered_df = data[data.type.isin(top_3_methods)]

            # Consider only the valid elements
            filtered_df = filtered_df[filtered_df.validity]

            # Step 1: Group by 'id' and check if all 'correct' values for each 'id' are True
            # valid_ids = filtered_df.groupby(['run_id', 'user_id'])['validity'].all()
            # valid_ids = valid_ids[valid_ids].index
            # filtered_df = filtered_df.set_index(['run_id', 'user_id'])
            # filtered_df = filtered_df[filtered_df.index.isin(valid_ids)]

            # Assert that all actions are correct, namely that
            # no non-actionable feature is ever modified.
            all_actionable = [item for sublist in actionable_groups for item in sublist]

            def _assert(x):
                complement_set = np.array([i for i in range(len(x)) if i not in all_actionable])
                if len(complement_set) > 0:
                    assert (x[complement_set] == 0).all()
                return x

            filtered_df["actions"].apply(_assert)

            # For each intervention set, flag the actions touching exactly that set
            for interv_set in actionable_groups:

                def _func(x):
                    complement_set = np.array([i for i in range(len(x)) if i not in interv_set])
                    if len(complement_set) > 0:
                        return (x[interv_set] != 0).all() and (x[complement_set] == 0).all()
                    else:
                        return (x[interv_set] != 0).all()

                filtered_df[f"{interv_set}"] = filtered_df["actions"].apply(_func)

            for interv_set in actionable_groups:
                given_data = filtered_df.groupby(["type", "run_id"])[f"{interv_set}"].mean()
                for (method, run_id), value in zip(given_data.index, given_data.tolist()):
                    all_data.append([t, run_id, method, value, f"{interv_set}"])

    all_data = pd.DataFrame(all_data, columns=["t", "run_id", "method", "cost", r"$\mathcal{I}$"])

    all_data[r"$\mathcal{I}$"] = all_data[r"$\mathcal{I}$"].apply(
        lambda x: rename_groups.get("Loan", {}).get(str(x)) if str(x) in rename_groups.get("Loan", {}) else x
    )

    g = sns.catplot(
        data=all_data, x="t", y="cost", hue="method",
        col=r"$\mathcal{I}$",
        kind="bar", capsize=.4,
        height=2.5,
        aspect=1.3,
        col_wrap=3,
        legend=False,
        palette=PALETTE
    )

    axs = g.axes
    fig = g.figure

    for idx_ax, ax in enumerate(axs):
        ax.set_ylabel(r"\% fraction")
        ax.set_xlabel("")  # r"Time (t)"
        ax.grid(axis='y')
        ax.set_ylim((0.0, 1.05))

    plt.tight_layout()
    plt.savefig(f"{scm}_{trend}_actions.pdf", format="pdf", bbox_inches='tight')
```
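
This script is presumably invoked via `plots/analysis_actions.sh` from the README above. The following direct invocation is only an illustrative sketch, assuming the repository root is on `PYTHONPATH` (the script imports `src.utils`) and that the chosen result files follow the `<experiment>_<classifier>_<scm>_<trend>_..._<alpha>_<runs>_<mc_samples>` naming scheme parsed above:

```bash
# Illustrative only: the exact result directory for this script is an assumption
export PYTHONPATH=.
python analytics/plot_actions.py results/appendix/task_3_real_linear/*.csv
```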
