@@ -176,13 +176,21 @@ def __init__(
176
176
data : pd .DataFrame ,
177
177
formula : str ,
178
178
time_variable_name : str ,
179
+ group_variable_name : str ,
180
+ treated : str ,
181
+ untreated : str ,
179
182
model = None ,
180
183
** kwargs ,
181
184
):
182
185
super ().__init__ (model = model , ** kwargs )
183
186
self .data = data
184
187
self .formula = formula
185
188
self .time_variable_name = time_variable_name
189
+ self .group_variable_name = group_variable_name
190
+ self .treated = treated # level of the group_variable_name that was treated
191
+ self .untreated = (
192
+ untreated # level of the group_variable_name that was untreated
193
+ )
186
194
y , X = dmatrices (formula , self .data )
187
195
self ._y_design_info = y .design_info
188
196
self ._x_design_info = X .design_info
@@ -194,32 +202,66 @@ def __init__(
194
202
self .model .fit (X = self .X , y = self .y )
195
203
196
204
# predicted outcome for control group
197
- self .x_pred_control = pd .DataFrame (
198
- {"group" : [0 , 0 ], "t" : [0.0 , 1.0 ], "post_treatment" : [0 , 0 ]}
205
+ self .x_pred_control = (
206
+ self .data
207
+ # just the untreated group
208
+ .query (f"{ self .group_variable_name } == @self.untreated" )
209
+ # drop the outcome variable
210
+ .drop (self .outcome_variable_name , axis = 1 )
211
+ # We may have multiple units per time point, we only want one time point
212
+ .groupby (self .time_variable_name )
213
+ .first ()
214
+ .reset_index ()
199
215
)
200
216
assert not self .x_pred_control .empty
201
217
(new_x ,) = build_design_matrices ([self ._x_design_info ], self .x_pred_control )
202
218
self .y_pred_control = self .model .predict (np .asarray (new_x ))
203
219
204
220
# predicted outcome for treatment group
205
- self .x_pred_treatment = pd .DataFrame (
206
- {"group" : [1 , 1 ], "t" : [0.0 , 1.0 ], "post_treatment" : [0 , 1 ]}
221
+ self .x_pred_treatment = (
222
+ self .data
223
+ # just the treated group
224
+ .query (f"{ self .group_variable_name } == @self.treated" )
225
+ # drop the outcome variable
226
+ .drop (self .outcome_variable_name , axis = 1 )
227
+ # We may have multiple units per time point, we only want one time point
228
+ .groupby (self .time_variable_name )
229
+ .first ()
230
+ .reset_index ()
207
231
)
208
232
assert not self .x_pred_treatment .empty
209
233
(new_x ,) = build_design_matrices ([self ._x_design_info ], self .x_pred_treatment )
210
234
self .y_pred_treatment = self .model .predict (np .asarray (new_x ))
211
235
212
- # predicted outcome for counterfactual
213
- self .x_pred_counterfactual = pd .DataFrame (
214
- {"group" : [1 ], "t" : [1.0 ], "post_treatment" : [0 ]}
236
+ # predicted outcome for counterfactual. This is given by removing the influence
237
+ # of the interaction term between the group and the post_treatment variable
238
+ self .x_pred_counterfactual = (
239
+ self .data
240
+ # just the treated group
241
+ .query (f"{ self .group_variable_name } == @self.treated" )
242
+ # just the treatment period(s)
243
+ .query ("post_treatment == True" )
244
+ # drop the outcome variable
245
+ .drop (self .outcome_variable_name , axis = 1 )
246
+ # We may have multiple units per time point, we only want one time point
247
+ .groupby (self .time_variable_name )
248
+ .first ()
249
+ .reset_index ()
215
250
)
216
251
assert not self .x_pred_counterfactual .empty
217
252
(new_x ,) = build_design_matrices (
218
- [self ._x_design_info ], self .x_pred_counterfactual
253
+ [self ._x_design_info ], self .x_pred_counterfactual , return_type = "dataframe"
219
254
)
255
+ # INTERVENTION: set the interaction term between the group and the
256
+ # post_treatment variable to zero. This is the counterfactual.
257
+ for i , label in enumerate (self .labels ):
258
+ if "post_treatment" in label and self .group_variable_name in label :
259
+ new_x .iloc [:, i ] = 0
220
260
self .y_pred_counterfactual = self .model .predict (np .asarray (new_x ))
221
261
222
262
# calculate causal impact
263
+ # This is the coefficient on the interaction term
264
+ # TODO: THIS IS NOT YET CORRECT
223
265
self .causal_impact = self .y_pred_treatment [1 ] - self .y_pred_counterfactual [0 ]
224
266
225
267
def plot (self ):
0 commit comments