|
| 1 | +import numpy as np |
| 2 | +import pandas as pd |
| 3 | +import matplotlib.pyplot as plt |
| 4 | +import seaborn as sns |
| 5 | + |
| 6 | +import pickle |
| 7 | + |
| 8 | +import argparse |
| 9 | +import os |
| 10 | + |
| 11 | +from src.utils import PALETTE |
| 12 | + |
# Global seaborn/matplotlib styling applied once at import time:
# "ticks" style, enlarged fonts, serif family, and text rendered through LaTeX
# with the amsfonts, amsmath and bm packages loaded in the preamble.
_rc_overrides = {
    "text.usetex": True,
    'text.latex.preamble': r'\usepackage{amsfonts} \usepackage{amsmath} \usepackage{bm}',
    "font.family": "serif",
}
sns.set(font_scale=1.5, style="ticks", rc=_rc_overrides)
| 20 | + |
# Raw technique identifier (as stored in the results) -> LaTeX label shown in
# the plots. Note that the raw and displayed acronyms differ for CFR and SPR.
technique_names = dict(
    CFR=r"\texttt{CAR}",
    SPR=r"\texttt{SAR}",
    IMF=r"\texttt{IMF}",
    robust_time=r"\texttt{T-SAR}",
)
| 27 | + |
# Every non-empty subset of the feature indices {0, 1, 2}, ordered by size.
# Both ANM settings use this same collection of intervention sets.
_ANM_SUBSETS = ([0], [1], [2], [0, 1], [0, 2], [1, 2], [0, 1, 2])

# SCM display name -> list of intervention sets, each a list of feature
# indices (the groups of actionable features examined for that SCM).
groups = {
    "Loan": [[5], [6], [5, 6]],
    "COMPAS": [[3]],
    "Adult": [[4], [5], [4, 5]],
    # Fresh list objects per entry, so the two settings never share state.
    "Linear ANM": [list(subset) for subset in _ANM_SUBSETS],
    "Non-Linear ANM": [list(subset) for subset in _ANM_SUBSETS],
}
| 35 | + |
# Human-readable (LaTeX) names for the Loan intervention sets. Keys are the
# string form of the index lists found in `groups`, e.g. str([5, 6]).
_loan_set_labels = {
    "[5]": r"$\{income\}$",
    "[6]": r"$\{savings\}$",
    "[5, 6]": r"$\{income, savings\}$",
}

# SCM display name -> {stringified intervention set -> label used in plots}.
rename_groups = {"Loan": _loan_set_labels}
| 43 | + |
| 44 | + |
| 45 | +if __name__ == "__main__": |
| 46 | + |
| 47 | + parser = argparse.ArgumentParser() |
| 48 | + parser.add_argument("file", nargs="+", type=str, help="CSV containing the results") |
| 49 | + args = parser.parse_args() |
| 50 | + |
| 51 | + # Where to save all the data |
| 52 | + all_data = [] |
| 53 | + all_dfs = [] |
| 54 | + |
| 55 | + for f in args.file: |
| 56 | + |
| 57 | + data = pickle.load(open(f, "rb")) |
| 58 | + |
| 59 | + experiment, classifier, scm, trend, _, alpha, runs, mc_samples = os.path.basename(f).split("_")[0:8] |
| 60 | + |
| 61 | + if scm == "non-linear": |
| 62 | + data["scm"] = "Non-Linear ANM" |
| 63 | + elif scm == "linear": |
| 64 | + data["scm"] = "Linear ANM" |
| 65 | + elif scm == "adult": |
| 66 | + data["scm"] = "Adult" |
| 67 | + elif scm == "compas": |
| 68 | + data["scm"] = "COMPAS" |
| 69 | + else: |
| 70 | + data["scm"] = "Loan" |
| 71 | + |
| 72 | + data["classifier"] = "DNN" if classifier == "dnn" else "Logistic" |
| 73 | + |
| 74 | + if "alpha" not in data.columns: |
| 75 | + data["alpha"] = float(alpha) |
| 76 | + |
| 77 | + data.rename( |
| 78 | + columns={"alpha": r"$\alpha$"}, |
| 79 | + inplace=True |
| 80 | + ) |
| 81 | + |
| 82 | + # Skip IMF, since it acts on all features anyway |
| 83 | + data = data[data.type != "IMF (robust)"] |
| 84 | + |
| 85 | + data["type"] = data["type"].apply(lambda x: technique_names.get(x.replace(" (robust)", "")) if x != "robust_time" else r"\texttt{T-SAR}") |
| 86 | + data["type"] = data["type"].apply(lambda x: x+r" ($\epsilon = 0.5$)" if (x != r"\texttt{T-SAR}" and experiment == "only-robust-2") else x) |
| 87 | + data["type"] = data["type"].apply(lambda x: x+r" ($\epsilon = 0.05$)" if (x != r"\texttt{T-SAR}" and experiment == "only-robust") else x) |
| 88 | + |
| 89 | + all_dfs.append(data) |
| 90 | + |
| 91 | + data_original = pd.concat(all_dfs) |
| 92 | + |
| 93 | + for t in [0, 10, 30, 50]: |
| 94 | + |
| 95 | + for scm_original in data_original.scm.unique(): |
| 96 | + |
| 97 | + data = data_original[(data_original.timestep == t) & (data_original.scm == scm_original)] |
| 98 | + |
| 99 | + # Get groups of actionable features |
| 100 | + actionable_groups = groups.get(scm_original) |
| 101 | + |
| 102 | + # Pick only the top-3 methods |
| 103 | + mean_of_means = data.groupby(["run_id", "type"])["validity"].mean() |
| 104 | + top_3_methods = mean_of_means.groupby("type").mean().sort_values(ascending=False)#.head(3) |
| 105 | + top_3_methods = top_3_methods.index |
| 106 | + filtered_df = data[data.type.isin(top_3_methods)] |
| 107 | + |
| 108 | + # Consider only the valid elements |
| 109 | + filtered_df = filtered_df[filtered_df.validity] |
| 110 | + |
| 111 | + # Step 1: Group by 'id' and check if all 'correct' values for each 'id' are True |
| 112 | + #valid_ids = filtered_df.groupby(['run_id', 'user_id'])['validity'].all() |
| 113 | + #valid_ids = valid_ids[valid_ids].index |
| 114 | + #filtered_df = filtered_df.set_index(['run_id', 'user_id']) |
| 115 | + #filtered_df = filtered_df[filtered_df.index.isin(valid_ids)] |
| 116 | + |
| 117 | + # Assert that all actions are correct |
| 118 | + # Namely, we do not have to modify non-actionable features. |
| 119 | + all_actionable = [item for sublist in actionable_groups for item in sublist] |
| 120 | + def _assert(x): |
| 121 | + complement_set = np.array([i for i in range(len(x)) if i not in all_actionable]) |
| 122 | + if len(complement_set) > 0: |
| 123 | + assert (x[complement_set] == 0).all() |
| 124 | + return x |
| 125 | + filtered_df["actions"].apply(lambda x : _assert(x)) |
| 126 | + |
| 127 | + |
| 128 | + # For each intervention set, compute the counts |
| 129 | + for interv_set in actionable_groups: |
| 130 | + |
| 131 | + def _func(x): |
| 132 | + complement_set = np.array([i for i in range(len(x)) if i not in interv_set]) |
| 133 | + if len(complement_set) > 0: |
| 134 | + return (x[interv_set] != 0).all() and (x[complement_set] == 0).all() |
| 135 | + else: |
| 136 | + return (x[interv_set] != 0).all() |
| 137 | + |
| 138 | + filtered_df[f"{interv_set}"] = filtered_df["actions"].apply(lambda x : _func(x)) |
| 139 | + |
| 140 | + for interv_set in actionable_groups: |
| 141 | + given_data = filtered_df.groupby(["type", "run_id"])[f"{interv_set}"].mean() |
| 142 | + for (method, run_id), value in zip(given_data.index, given_data.tolist()): |
| 143 | + all_data.append( |
| 144 | + [t, run_id, method, value, f"{interv_set}"] |
| 145 | + ) |
| 146 | + |
| 147 | + all_data = pd.DataFrame(all_data, columns=["t", "run_id", "method", "cost", r"$\mathcal{I}$"]) |
| 148 | + |
| 149 | + all_data[r"$\mathcal{I}$"] = all_data[r"$\mathcal{I}$"].apply( |
| 150 | + lambda x: rename_groups.get("Loan", {}).get(str(x)) if str(x) in rename_groups.get("Loan", {}) else x |
| 151 | + ) |
| 152 | + |
| 153 | + g = sns.catplot( |
| 154 | + data=all_data, x="t", y="cost", hue="method", |
| 155 | + col=r"$\mathcal{I}$", |
| 156 | + kind="bar", capsize=.4, |
| 157 | + height=2.5, |
| 158 | + aspect=1.3, |
| 159 | + col_wrap=3, |
| 160 | + legend=False, |
| 161 | + palette=PALETTE |
| 162 | + ) |
| 163 | + |
| 164 | + axs = g.axes |
| 165 | + fig = g.figure |
| 166 | + |
| 167 | + for idx_ax, ax in enumerate(axs): |
| 168 | + ax.set_ylabel( |
| 169 | + r"\% fraction" |
| 170 | + ) |
| 171 | + ax.set_xlabel( |
| 172 | + ""#r"Time (t)" |
| 173 | + ) |
| 174 | + ax.grid(axis='y') |
| 175 | + ax.set_ylim((0.0, 1.05)) |
| 176 | + |
| 177 | + plt.tight_layout() |
| 178 | + plt.savefig(f"{scm}_{trend}_actions.pdf", format="pdf", bbox_inches='tight') |
0 commit comments