Skip to content

Commit 73eb32f

Browse files
authored
Merge pull request #147 from pymc-labs/DiD-fixes
FIX: Corrects error discovered in Difference in Differences code + updates skl code
2 parents 55247c9 + 37c7f1a commit 73eb32f

7 files changed

+184
-115
lines changed

causalpy/pymc_experiments.py

+14-9
Original file line number | Diff line number | Diff line change
@@ -340,7 +340,8 @@ def __init__(
340340
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_treatment)
341341
self.y_pred_treatment = self.model.predict(np.asarray(new_x))
342342

343-
# predicted outcome for counterfactual
343+
# predicted outcome for counterfactual. This is given by removing the influence
344+
# of the interaction term between the group and the post_treatment variable
344345
self.x_pred_counterfactual = (
345346
self.data
346347
# just the treated group
@@ -349,24 +350,28 @@ def __init__(
349350
.query("post_treatment == True")
350351
# drop the outcome variable
351352
.drop(self.outcome_variable_name, axis=1)
352-
# DO AN INTERVENTION. Set the post_treatment variable to False
353-
.assign(post_treatment=False)
354353
# We may have multiple units per time point, we only want one time point
355354
.groupby(self.time_variable_name)
356355
.first()
357356
.reset_index()
358357
)
359358
assert not self.x_pred_counterfactual.empty
360359
(new_x,) = build_design_matrices(
361-
[self._x_design_info], self.x_pred_counterfactual
360+
[self._x_design_info], self.x_pred_counterfactual, return_type="dataframe"
362361
)
362+
# INTERVENTION: set the interaction term between the group and the
363+
# post_treatment variable to zero. This is the counterfactual.
364+
for i, label in enumerate(self.labels):
365+
if "post_treatment" in label and self.group_variable_name in label:
366+
new_x.iloc[:, i] = 0
363367
self.y_pred_counterfactual = self.model.predict(np.asarray(new_x))
364368

365-
# calculate causal impact
366-
self.causal_impact = (
367-
self.y_pred_treatment["posterior_predictive"].mu.isel({"obs_ind": 1})
368-
- self.y_pred_counterfactual["posterior_predictive"].mu.squeeze()
369-
)
369+
# calculate causal impact.
370+
# This is the coefficient on the interaction term
371+
coeff_names = self.idata.posterior.coords["coeffs"].data
372+
for i, label in enumerate(coeff_names):
373+
if "post_treatment" in label and self.group_variable_name in label:
374+
self.causal_impact = self.idata.posterior["beta"].isel({"coeffs": i})
370375

371376
def plot(self):
372377
"""Plot the results.

causalpy/skl_experiments.py

+50-8
Original file line number | Diff line number | Diff line change
@@ -176,13 +176,21 @@ def __init__(
176176
data: pd.DataFrame,
177177
formula: str,
178178
time_variable_name: str,
179+
group_variable_name: str,
180+
treated: str,
181+
untreated: str,
179182
model=None,
180183
**kwargs,
181184
):
182185
super().__init__(model=model, **kwargs)
183186
self.data = data
184187
self.formula = formula
185188
self.time_variable_name = time_variable_name
189+
self.group_variable_name = group_variable_name
190+
self.treated = treated # level of the group_variable_name that was treated
191+
self.untreated = (
192+
untreated # level of the group_variable_name that was untreated
193+
)
186194
y, X = dmatrices(formula, self.data)
187195
self._y_design_info = y.design_info
188196
self._x_design_info = X.design_info
@@ -194,32 +202,66 @@ def __init__(
194202
self.model.fit(X=self.X, y=self.y)
195203

196204
# predicted outcome for control group
197-
self.x_pred_control = pd.DataFrame(
198-
{"group": [0, 0], "t": [0.0, 1.0], "post_treatment": [0, 0]}
205+
self.x_pred_control = (
206+
self.data
207+
# just the untreated group
208+
.query(f"{self.group_variable_name} == @self.untreated")
209+
# drop the outcome variable
210+
.drop(self.outcome_variable_name, axis=1)
211+
# We may have multiple units per time point, we only want one time point
212+
.groupby(self.time_variable_name)
213+
.first()
214+
.reset_index()
199215
)
200216
assert not self.x_pred_control.empty
201217
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_control)
202218
self.y_pred_control = self.model.predict(np.asarray(new_x))
203219

204220
# predicted outcome for treatment group
205-
self.x_pred_treatment = pd.DataFrame(
206-
{"group": [1, 1], "t": [0.0, 1.0], "post_treatment": [0, 1]}
221+
self.x_pred_treatment = (
222+
self.data
223+
# just the treated group
224+
.query(f"{self.group_variable_name} == @self.treated")
225+
# drop the outcome variable
226+
.drop(self.outcome_variable_name, axis=1)
227+
# We may have multiple units per time point, we only want one time point
228+
.groupby(self.time_variable_name)
229+
.first()
230+
.reset_index()
207231
)
208232
assert not self.x_pred_treatment.empty
209233
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_treatment)
210234
self.y_pred_treatment = self.model.predict(np.asarray(new_x))
211235

212-
# predicted outcome for counterfactual
213-
self.x_pred_counterfactual = pd.DataFrame(
214-
{"group": [1], "t": [1.0], "post_treatment": [0]}
236+
# predicted outcome for counterfactual. This is given by removing the influence
237+
# of the interaction term between the group and the post_treatment variable
238+
self.x_pred_counterfactual = (
239+
self.data
240+
# just the treated group
241+
.query(f"{self.group_variable_name} == @self.treated")
242+
# just the treatment period(s)
243+
.query("post_treatment == True")
244+
# drop the outcome variable
245+
.drop(self.outcome_variable_name, axis=1)
246+
# We may have multiple units per time point, we only want one time point
247+
.groupby(self.time_variable_name)
248+
.first()
249+
.reset_index()
215250
)
216251
assert not self.x_pred_counterfactual.empty
217252
(new_x,) = build_design_matrices(
218-
[self._x_design_info], self.x_pred_counterfactual
253+
[self._x_design_info], self.x_pred_counterfactual, return_type="dataframe"
219254
)
255+
# INTERVENTION: set the interaction term between the group and the
256+
# post_treatment variable to zero. This is the counterfactual.
257+
for i, label in enumerate(self.labels):
258+
if "post_treatment" in label and self.group_variable_name in label:
259+
new_x.iloc[:, i] = 0
220260
self.y_pred_counterfactual = self.model.predict(np.asarray(new_x))
221261

222262
# calculate causal impact
263+
# This is the coefficient on the interaction term
264+
# TODO: THIS IS NOT YET CORRECT
223265
self.causal_impact = self.y_pred_treatment[1] - self.y_pred_counterfactual[0]
224266

225267
def plot(self):

causalpy/tests/test_integration_pymc_examples.py

+21-10
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@ def test_did():
1111
df = cp.load_data("did")
1212
result = cp.pymc_experiments.DifferenceInDifferences(
1313
df,
14-
formula="y ~ 1 + group + t + group:post_treatment",
14+
formula="y ~ 1 + group*post_treatment",
1515
time_variable_name="t",
1616
group_variable_name="group",
1717
treated=1,
@@ -37,6 +37,10 @@ def test_did_banks_simple():
3737
.groupby("year")
3838
.median()
3939
)
40+
# SET TREATMENT TIME TO ZERO =========
41+
df.index = df.index - treatment_time
42+
treatment_time = 0
43+
# ====================================
4044
df.reset_index(level=0, inplace=True)
4145
df_long = pd.melt(
4246
df,
@@ -45,16 +49,18 @@ def test_did_banks_simple():
4549
var_name="district",
4650
value_name="bib",
4751
).sort_values("year")
48-
df_long["district"] = df_long["district"].astype("category")
4952
df_long["unit"] = df_long["district"]
5053
df_long["post_treatment"] = df_long.year >= treatment_time
54+
df_long = df_long.replace({"district": {"Sixth District": 1, "Eighth District": 0}})
55+
5156
result = cp.pymc_experiments.DifferenceInDifferences(
52-
df_long[df_long.year.isin([1930, 1931])],
53-
formula="bib ~ 1 + district + year + district:post_treatment",
57+
# df_long[df_long.year.isin([1930, 1931])],
58+
df_long[df_long.year.isin([-0.5, 0.5])],
59+
formula="bib ~ 1 + district * post_treatment",
5460
time_variable_name="year",
5561
group_variable_name="district",
56-
treated="Sixth District",
57-
untreated="Eighth District",
62+
treated=1,
63+
untreated=0,
5864
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
5965
)
6066
assert isinstance(df, pd.DataFrame)
@@ -73,6 +79,10 @@ def test_did_banks_multi():
7379
.groupby("year")
7480
.median()
7581
)
82+
# SET TREATMENT TIME TO ZERO =========
83+
df.index = df.index - treatment_time
84+
treatment_time = 0
85+
# ====================================
7686
df.reset_index(level=0, inplace=True)
7787
df_long = pd.melt(
7888
df,
@@ -81,16 +91,17 @@ def test_did_banks_multi():
8191
var_name="district",
8292
value_name="bib",
8393
).sort_values("year")
84-
df_long["district"] = df_long["district"].astype("category")
8594
df_long["unit"] = df_long["district"]
8695
df_long["post_treatment"] = df_long.year >= treatment_time
96+
df_long = df_long.replace({"district": {"Sixth District": 1, "Eighth District": 0}})
97+
8798
result = cp.pymc_experiments.DifferenceInDifferences(
8899
df_long,
89-
formula="bib ~ 1 + district + year + district:post_treatment",
100+
formula="bib ~ 1 + year + district + post_treatment + district:post_treatment",
90101
time_variable_name="year",
91102
group_variable_name="district",
92-
treated="Sixth District",
93-
untreated="Eighth District",
103+
treated=1,
104+
untreated=0,
94105
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
95106
)
96107
assert isinstance(df, pd.DataFrame)

causalpy/tests/test_integration_skl_examples.py

+4-1
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,11 @@ def test_did():
1212
data = cp.load_data("did")
1313
result = cp.skl_experiments.DifferenceInDifferences(
1414
data,
15-
formula="y ~ 1 + group + t + group:post_treatment",
15+
formula="y ~ 1 + group*post_treatment",
1616
time_variable_name="t",
17+
group_variable_name="group",
18+
treated=1,
19+
untreated=0,
1720
model=LinearRegression(),
1821
)
1922
assert isinstance(data, pd.DataFrame)

docs/notebooks/did_pymc.ipynb

+12-17
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)