Skip to content

Commit 37c7f1a

Browse files
committed
fix up skl DiD code + rerun DiD notebooks + update integration tests
1 parent e882ad2 commit 37c7f1a

File tree

5 files changed

+130
-139
lines changed

5 files changed

+130
-139
lines changed

causalpy/skl_experiments.py

+50-8
Original file line numberDiff line numberDiff line change
@@ -176,13 +176,21 @@ def __init__(
176176
data: pd.DataFrame,
177177
formula: str,
178178
time_variable_name: str,
179+
group_variable_name: str,
180+
treated: str,
181+
untreated: str,
179182
model=None,
180183
**kwargs,
181184
):
182185
super().__init__(model=model, **kwargs)
183186
self.data = data
184187
self.formula = formula
185188
self.time_variable_name = time_variable_name
189+
self.group_variable_name = group_variable_name
190+
self.treated = treated # level of the group_variable_name that was treated
191+
self.untreated = (
192+
untreated # level of the group_variable_name that was untreated
193+
)
186194
y, X = dmatrices(formula, self.data)
187195
self._y_design_info = y.design_info
188196
self._x_design_info = X.design_info
@@ -194,32 +202,66 @@ def __init__(
194202
self.model.fit(X=self.X, y=self.y)
195203

196204
# predicted outcome for control group
197-
self.x_pred_control = pd.DataFrame(
198-
{"group": [0, 0], "t": [0.0, 1.0], "post_treatment": [0, 0]}
205+
self.x_pred_control = (
206+
self.data
207+
# just the untreated group
208+
.query(f"{self.group_variable_name} == @self.untreated")
209+
# drop the outcome variable
210+
.drop(self.outcome_variable_name, axis=1)
211+
# We may have multiple units per time point; keep only one row per time point
212+
.groupby(self.time_variable_name)
213+
.first()
214+
.reset_index()
199215
)
200216
assert not self.x_pred_control.empty
201217
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_control)
202218
self.y_pred_control = self.model.predict(np.asarray(new_x))
203219

204220
# predicted outcome for treatment group
205-
self.x_pred_treatment = pd.DataFrame(
206-
{"group": [1, 1], "t": [0.0, 1.0], "post_treatment": [0, 1]}
221+
self.x_pred_treatment = (
222+
self.data
223+
# just the treated group
224+
.query(f"{self.group_variable_name} == @self.treated")
225+
# drop the outcome variable
226+
.drop(self.outcome_variable_name, axis=1)
227+
# We may have multiple units per time point; keep only one row per time point
228+
.groupby(self.time_variable_name)
229+
.first()
230+
.reset_index()
207231
)
208232
assert not self.x_pred_treatment.empty
209233
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_treatment)
210234
self.y_pred_treatment = self.model.predict(np.asarray(new_x))
211235

212-
# predicted outcome for counterfactual
213-
self.x_pred_counterfactual = pd.DataFrame(
214-
{"group": [1], "t": [1.0], "post_treatment": [0]}
236+
# predicted outcome for counterfactual. This is given by removing the influence
237+
# of the interaction term between the group and the post_treatment variable
238+
self.x_pred_counterfactual = (
239+
self.data
240+
# just the treated group
241+
.query(f"{self.group_variable_name} == @self.treated")
242+
# just the treatment period(s)
243+
.query("post_treatment == True")
244+
# drop the outcome variable
245+
.drop(self.outcome_variable_name, axis=1)
246+
# We may have multiple units per time point; keep only one row per time point
247+
.groupby(self.time_variable_name)
248+
.first()
249+
.reset_index()
215250
)
216251
assert not self.x_pred_counterfactual.empty
217252
(new_x,) = build_design_matrices(
218-
[self._x_design_info], self.x_pred_counterfactual
253+
[self._x_design_info], self.x_pred_counterfactual, return_type="dataframe"
219254
)
255+
# INTERVENTION: set the interaction term between the group and the
256+
# post_treatment variable to zero. This is the counterfactual.
257+
for i, label in enumerate(self.labels):
258+
if "post_treatment" in label and self.group_variable_name in label:
259+
new_x.iloc[:, i] = 0
220260
self.y_pred_counterfactual = self.model.predict(np.asarray(new_x))
221261

222262
# calculate causal impact
263+
# This is the coefficient on the interaction term
264+
# TODO: THIS IS NOT YET CORRECT
223265
self.causal_impact = self.y_pred_treatment[1] - self.y_pred_counterfactual[0]
224266

225267
def plot(self):

causalpy/tests/test_integration_pymc_examples.py

+21-10
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def test_did():
1111
df = cp.load_data("did")
1212
result = cp.pymc_experiments.DifferenceInDifferences(
1313
df,
14-
formula="y ~ 1 + group + t + group:post_treatment",
14+
formula="y ~ 1 + group*post_treatment",
1515
time_variable_name="t",
1616
group_variable_name="group",
1717
treated=1,
@@ -37,6 +37,10 @@ def test_did_banks_simple():
3737
.groupby("year")
3838
.median()
3939
)
40+
# SET TREATMENT TIME TO ZERO =========
41+
df.index = df.index - treatment_time
42+
treatment_time = 0
43+
# ====================================
4044
df.reset_index(level=0, inplace=True)
4145
df_long = pd.melt(
4246
df,
@@ -45,16 +49,18 @@ def test_did_banks_simple():
4549
var_name="district",
4650
value_name="bib",
4751
).sort_values("year")
48-
df_long["district"] = df_long["district"].astype("category")
4952
df_long["unit"] = df_long["district"]
5053
df_long["post_treatment"] = df_long.year >= treatment_time
54+
df_long = df_long.replace({"district": {"Sixth District": 1, "Eighth District": 0}})
55+
5156
result = cp.pymc_experiments.DifferenceInDifferences(
52-
df_long[df_long.year.isin([1930, 1931])],
53-
formula="bib ~ 1 + district + year + district:post_treatment",
57+
# df_long[df_long.year.isin([1930, 1931])],
58+
df_long[df_long.year.isin([-0.5, 0.5])],
59+
formula="bib ~ 1 + district * post_treatment",
5460
time_variable_name="year",
5561
group_variable_name="district",
56-
treated="Sixth District",
57-
untreated="Eighth District",
62+
treated=1,
63+
untreated=0,
5864
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
5965
)
6066
assert isinstance(df, pd.DataFrame)
@@ -73,6 +79,10 @@ def test_did_banks_multi():
7379
.groupby("year")
7480
.median()
7581
)
82+
# SET TREATMENT TIME TO ZERO =========
83+
df.index = df.index - treatment_time
84+
treatment_time = 0
85+
# ====================================
7686
df.reset_index(level=0, inplace=True)
7787
df_long = pd.melt(
7888
df,
@@ -81,16 +91,17 @@ def test_did_banks_multi():
8191
var_name="district",
8292
value_name="bib",
8393
).sort_values("year")
84-
df_long["district"] = df_long["district"].astype("category")
8594
df_long["unit"] = df_long["district"]
8695
df_long["post_treatment"] = df_long.year >= treatment_time
96+
df_long = df_long.replace({"district": {"Sixth District": 1, "Eighth District": 0}})
97+
8798
result = cp.pymc_experiments.DifferenceInDifferences(
8899
df_long,
89-
formula="bib ~ 1 + district + year + district:post_treatment",
100+
formula="bib ~ 1 + year + district + post_treatment + district:post_treatment",
90101
time_variable_name="year",
91102
group_variable_name="district",
92-
treated="Sixth District",
93-
untreated="Eighth District",
103+
treated=1,
104+
untreated=0,
94105
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
95106
)
96107
assert isinstance(df, pd.DataFrame)

causalpy/tests/test_integration_skl_examples.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@ def test_did():
1212
data = cp.load_data("did")
1313
result = cp.skl_experiments.DifferenceInDifferences(
1414
data,
15-
formula="y ~ 1 + group + t + group:post_treatment",
15+
formula="y ~ 1 + group*post_treatment",
1616
time_variable_name="t",
17+
group_variable_name="group",
18+
treated=1,
19+
untreated=0,
1720
model=LinearRegression(),
1821
)
1922
assert isinstance(data, pd.DataFrame)

docs/notebooks/did_pymc_banks.ipynb

+27-114
Large diffs are not rendered by default.

docs/notebooks/did_skl.ipynb

+28-6
Original file line numberDiff line numberDiff line change
@@ -25,33 +25,55 @@
2525
"metadata": {},
2626
"outputs": [],
2727
"source": [
28-
"data = cp.load_data(\"did\")"
28+
"%load_ext autoreload\n",
29+
"%autoreload 2"
2930
]
3031
},
3132
{
3233
"cell_type": "code",
3334
"execution_count": 3,
3435
"metadata": {},
3536
"outputs": [],
37+
"source": [
38+
"data = cp.load_data(\"did\")"
39+
]
40+
},
41+
{
42+
"cell_type": "code",
43+
"execution_count": 4,
44+
"metadata": {},
45+
"outputs": [
46+
{
47+
"name": "stderr",
48+
"output_type": "stream",
49+
"text": [
50+
"/Users/benjamv/git/CausalPy/causalpy/skl_experiments.py:259: FutureWarning: In a future version, `df.iloc[:, i] = newvals` will attempt to set the values inplace instead of always setting a new array. To retain the old behavior, use either `df[df.columns[i]] = newvals` or, if columns are non-unique, `df.isetitem(i, newvals)`\n",
51+
" new_x.iloc[:, i] = 0\n"
52+
]
53+
}
54+
],
3655
"source": [
3756
"result = cp.skl_experiments.DifferenceInDifferences(\n",
3857
" data,\n",
39-
" formula=\"y ~ 1 + group + t + group:post_treatment\",\n",
58+
" formula=\"y ~ 1 + group*post_treatment\",\n",
4059
" time_variable_name=\"t\",\n",
60+
" group_variable_name=\"group\",\n",
61+
" treated=1,\n",
62+
" untreated=0,\n",
4163
" model=LinearRegression(),\n",
4264
")"
4365
]
4466
},
4567
{
4668
"cell_type": "code",
47-
"execution_count": 4,
69+
"execution_count": 5,
4870
"metadata": {},
4971
"outputs": [
5072
{
5173
"name": "stderr",
5274
"output_type": "stream",
5375
"text": [
54-
"/Users/benjamv/opt/mambaforge/envs/CausalPy/lib/python3.10/site-packages/numpy/core/_methods.py:164: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n",
76+
"/Users/benjamv/mambaforge/envs/CausalPy/lib/python3.10/site-packages/numpy/core/_methods.py:164: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n",
5577
" arr = asanyarray(a)\n"
5678
]
5779
},
@@ -94,12 +116,12 @@
94116
"name": "python",
95117
"nbconvert_exporter": "python",
96118
"pygments_lexer": "ipython3",
97-
"version": "3.10.6"
119+
"version": "3.10.8"
98120
},
99121
"orig_nbformat": 4,
100122
"vscode": {
101123
"interpreter": {
102-
"hash": "02f5385db19eab57520277c5168790c7855381ee953bdbb5c89c321e1f17586e"
124+
"hash": "46d31859cc45aa26a1223a391e7cf3023d69984b498bed11e66c690302b7e251"
103125
}
104126
}
105127
},

0 commit comments

Comments
 (0)