Skip to content

Commit 24ad46e

Browse files
authored
Merge pull request #264 from pymc-labs/kink
Add ability to analyse Regression Kink Analysis Designs
2 parents 7ebab10 + 4386f6b commit 24ad46e

15 files changed

+3404
-15
lines changed

.pre-commit-config.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ repos:
1010
exclude_types: [svg]
1111
- id: check-yaml
1212
- id: check-added-large-files
13+
args: ['--maxkb=1500']
1314
- repo: https://github.com/charliermarsh/ruff-pre-commit
1415
rev: v0.1.4
1516
hooks:

README.md

+19
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,25 @@ Regression discontinuity designs are used when treatment is applied to units acc
144144

145145
> The data, model fit, and counterfactual are plotted (top). Frequentist analysis shows the causal impact with the blue shaded region, but this is not shown in the Bayesian analysis to avoid a cluttered chart. Instead, the Bayesian analysis shows shaded Bayesian credible regions of the model fits. The Frequentist analysis visualises the point estimate of the causal impact, but the Bayesian analysis also plots the posterior distribution of the regression discontinuity effect (bottom).
146146
147+
### Regression kink designs
148+
149+
Regression discontinuity designs are used when treatment is applied to units according to a cutoff on a running variable, which is typically not time. By looking for the presence of a discontinuity at the precise point of the treatment cutoff then we can make causal claims about the potential impact of the treatment.
150+
151+
| Running variable | Outcome |
152+
|-----------|-----------|
153+
| $x_0$ | $y_0$ |
154+
| $x_1$ | $y_0$ |
155+
| $\ldots$ | $\ldots$ |
156+
| $x_{N-1}$ | $y_{N-1}$ |
157+
| $x_N$ | $y_N$ |
158+
159+
160+
| Frequentist | Bayesian |
161+
|--|--|
162+
| coming soon | ![](docs/source/_static/regression_kink_pymc.svg) |
163+
164+
> The data and model fit. The Bayesian analysis shows the posterior mean with credible intervals (shaded regions). We also report the Bayesian $R^2$ on the data along with the posterior mean and credible intervals of the change in gradient at the kink point.
165+
147166
### Interrupted time series
148167

149168
Interrupted time series analysis is appropriate when you have a time series of observations which undergo treatment at a particular point in time. This kind of analysis has no control group and looks for the presence of a change in the outcome measure at or soon after the treatment time. Multiple predictors can be included.

causalpy/pymc_experiments.py

+222-8
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
"""
1313

1414
import warnings
15-
from typing import Optional, Union
15+
from typing import Union
1616

1717
import arviz as az
1818
import matplotlib.pyplot as plt
@@ -328,7 +328,7 @@ def plot(self, counterfactual_label="Counterfactual", **kwargs):
328328
fontsize=LEGEND_FONT_SIZE,
329329
)
330330

331-
return (fig, ax)
331+
return fig, ax
332332

333333
def summary(self) -> None:
334334
"""
@@ -424,7 +424,7 @@ def plot(self, plot_predictors=False, **kwargs):
424424
ax[0].plot(
425425
self.datapost.index, self.post_X, "-", c=[0.8, 0.8, 0.8], zorder=1
426426
)
427-
return (fig, ax)
427+
return fig, ax
428428

429429

430430
class DifferenceInDifferences(ExperimentalDesign):
@@ -794,7 +794,7 @@ def __init__(
794794
model=None,
795795
running_variable_name: str = "x",
796796
epsilon: float = 0.001,
797-
bandwidth: Optional[float] = None,
797+
bandwidth: float = np.inf,
798798
**kwargs,
799799
):
800800
super().__init__(model=model, **kwargs)
@@ -807,7 +807,7 @@ def __init__(
807807
self.bandwidth = bandwidth
808808
self._input_validation()
809809

810-
if self.bandwidth is not None:
810+
if self.bandwidth is not np.inf:
811811
fmin = self.treatment_threshold - self.bandwidth
812812
fmax = self.treatment_threshold + self.bandwidth
813813
filtered_data = self.data.query(f"{fmin} <= x <= {fmax}")
@@ -836,7 +836,7 @@ def __init__(
836836
self.score = self.model.score(X=self.X, y=self.y)
837837

838838
# get the model predictions of the observed data
839-
if self.bandwidth is not None:
839+
if self.bandwidth is not np.inf:
840840
xi = np.linspace(fmin, fmax, 200)
841841
else:
842842
xi = np.linspace(
@@ -903,7 +903,7 @@ def plot(self):
903903
self.data,
904904
x=self.running_variable_name,
905905
y=self.outcome_variable_name,
906-
c="k", # hue="treated",
906+
c="k",
907907
ax=ax,
908908
)
909909

@@ -939,7 +939,7 @@ def plot(self):
939939
labels=labels,
940940
fontsize=LEGEND_FONT_SIZE,
941941
)
942-
return (fig, ax)
942+
return fig, ax
943943

944944
def summary(self) -> None:
945945
"""
@@ -957,6 +957,220 @@ def summary(self) -> None:
957957
self.print_coefficients()
958958

959959

960+
class RegressionKink(ExperimentalDesign):
961+
"""
962+
A class to analyse sharp regression kink experiments.
963+
964+
:param data:
965+
A pandas dataframe
966+
:param formula:
967+
A statistical model formula
968+
:param kink_point:
969+
A scalar threshold value at which there is a change in the first derivative of
970+
the assignment function
971+
:param model:
972+
A PyMC model
973+
:param running_variable_name:
974+
The name of the predictor variable that the kink_point is based upon
975+
:param epsilon:
976+
A small scalar value which determines how far above and below the kink point to
977+
evaluate the causal impact.
978+
:param bandwidth:
979+
Data outside of the bandwidth (relative to the discontinuity) is not used to fit
980+
the model.
981+
"""
982+
983+
def __init__(
984+
self,
985+
data: pd.DataFrame,
986+
formula: str,
987+
kink_point: float,
988+
model=None,
989+
running_variable_name: str = "x",
990+
epsilon: float = 0.001,
991+
bandwidth: float = np.inf,
992+
**kwargs,
993+
):
994+
super().__init__(model=model, **kwargs)
995+
self.expt_type = "Regression Kink"
996+
self.data = data
997+
self.formula = formula
998+
self.running_variable_name = running_variable_name
999+
self.kink_point = kink_point
1000+
self.epsilon = epsilon
1001+
self.bandwidth = bandwidth
1002+
self._input_validation()
1003+
1004+
if self.bandwidth is not np.inf:
1005+
fmin = self.kink_point - self.bandwidth
1006+
fmax = self.kink_point + self.bandwidth
1007+
filtered_data = self.data.query(f"{fmin} <= x <= {fmax}")
1008+
if len(filtered_data) <= 10:
1009+
warnings.warn(
1010+
f"Choice of bandwidth parameter has lead to only {len(filtered_data)} remaining datapoints. Consider increasing the bandwidth parameter.", # noqa: E501
1011+
UserWarning,
1012+
)
1013+
y, X = dmatrices(formula, filtered_data)
1014+
else:
1015+
y, X = dmatrices(formula, self.data)
1016+
1017+
self._y_design_info = y.design_info
1018+
self._x_design_info = X.design_info
1019+
self.labels = X.design_info.column_names
1020+
self.y, self.X = np.asarray(y), np.asarray(X)
1021+
self.outcome_variable_name = y.design_info.column_names[0]
1022+
1023+
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.X.shape[0])}
1024+
self.model.fit(X=self.X, y=self.y, coords=COORDS)
1025+
1026+
# score the goodness of fit to all data
1027+
self.score = self.model.score(X=self.X, y=self.y)
1028+
1029+
# get the model predictions of the observed data
1030+
if self.bandwidth is not np.inf:
1031+
xi = np.linspace(fmin, fmax, 200)
1032+
else:
1033+
xi = np.linspace(
1034+
np.min(self.data[self.running_variable_name]),
1035+
np.max(self.data[self.running_variable_name]),
1036+
200,
1037+
)
1038+
self.x_pred = pd.DataFrame(
1039+
{self.running_variable_name: xi, "treated": self._is_treated(xi)}
1040+
)
1041+
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred)
1042+
self.pred = self.model.predict(X=np.asarray(new_x))
1043+
1044+
# evaluate gradient change around kink point
1045+
mu_kink_left, mu_kink, mu_kink_right = self._probe_kink_point()
1046+
self.gradient_change = self._eval_gradient_change(
1047+
mu_kink_left, mu_kink, mu_kink_right, epsilon
1048+
)
1049+
1050+
@staticmethod
1051+
def _eval_gradient_change(mu_kink_left, mu_kink, mu_kink_right, epsilon):
1052+
"""Evaluate the gradient change at the kink point.
1053+
It works by evaluating the model below the kink point, at the kink point,
1054+
and above the kink point.
1055+
This is a static method for ease of testing.
1056+
"""
1057+
gradient_left = (mu_kink - mu_kink_left) / epsilon
1058+
gradient_right = (mu_kink_right - mu_kink) / epsilon
1059+
gradient_change = gradient_right - gradient_left
1060+
return gradient_change
1061+
1062+
def _probe_kink_point(self):
1063+
# Create a dataframe to evaluate predicted outcome at the kink point and either
1064+
# side
1065+
x_predict = pd.DataFrame(
1066+
{
1067+
self.running_variable_name: np.array(
1068+
[
1069+
self.kink_point - self.epsilon,
1070+
self.kink_point,
1071+
self.kink_point + self.epsilon,
1072+
]
1073+
),
1074+
"treated": np.array([0, 1, 1]),
1075+
}
1076+
)
1077+
(new_x,) = build_design_matrices([self._x_design_info], x_predict)
1078+
predicted = self.model.predict(X=np.asarray(new_x))
1079+
# extract predicted mu values
1080+
mu_kink_left = predicted["posterior_predictive"].sel(obs_ind=0)["mu"]
1081+
mu_kink = predicted["posterior_predictive"].sel(obs_ind=1)["mu"]
1082+
mu_kink_right = predicted["posterior_predictive"].sel(obs_ind=2)["mu"]
1083+
return mu_kink_left, mu_kink, mu_kink_right
1084+
1085+
def _input_validation(self):
1086+
"""Validate the input data and model formula for correctness"""
1087+
if "treated" not in self.formula:
1088+
raise FormulaException(
1089+
"A predictor called `treated` should be in the formula"
1090+
)
1091+
1092+
if _is_variable_dummy_coded(self.data["treated"]) is False:
1093+
raise DataException(
1094+
"""The treated variable should be dummy coded. Consisting of 0's and 1's only.""" # noqa: E501
1095+
)
1096+
1097+
if self.bandwidth <= 0:
1098+
raise ValueError("The bandwidth must be greater than zero.")
1099+
1100+
if self.epsilon <= 0:
1101+
raise ValueError("Epsilon must be greater than zero.")
1102+
1103+
def _is_treated(self, x):
1104+
"""Returns ``True`` if `x` is greater than or equal to the treatment threshold.""" # noqa: E501
1105+
return np.greater_equal(x, self.kink_point)
1106+
1107+
def plot(self):
1108+
"""
1109+
Plot the results
1110+
"""
1111+
fig, ax = plt.subplots()
1112+
# Plot raw data
1113+
sns.scatterplot(
1114+
self.data,
1115+
x=self.running_variable_name,
1116+
y=self.outcome_variable_name,
1117+
c="k", # hue="treated",
1118+
ax=ax,
1119+
)
1120+
1121+
# Plot model fit to data
1122+
h_line, h_patch = plot_xY(
1123+
self.x_pred[self.running_variable_name],
1124+
self.pred["posterior_predictive"].mu,
1125+
ax=ax,
1126+
plot_hdi_kwargs={"color": "C1"},
1127+
)
1128+
handles = [(h_line, h_patch)]
1129+
labels = ["Posterior mean"]
1130+
1131+
# create strings to compose title
1132+
title_info = f"{self.score.r2:.3f} (std = {self.score.r2_std:.3f})"
1133+
r2 = f"Bayesian $R^2$ on all data = {title_info}"
1134+
percentiles = self.gradient_change.quantile([0.03, 1 - 0.03]).values
1135+
ci = r"$CI_{94\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
1136+
grad_change = f"""
1137+
Change in gradient = {self.gradient_change.mean():.2f},
1138+
"""
1139+
ax.set(title=r2 + "\n" + grad_change + ci)
1140+
# Intervention line
1141+
ax.axvline(
1142+
x=self.kink_point,
1143+
ls="-",
1144+
lw=3,
1145+
color="r",
1146+
label="treatment threshold",
1147+
)
1148+
ax.legend(
1149+
handles=(h_tuple for h_tuple in handles),
1150+
labels=labels,
1151+
fontsize=LEGEND_FONT_SIZE,
1152+
)
1153+
return fig, ax
1154+
1155+
def summary(self) -> None:
1156+
"""
1157+
Print text output summarising the results
1158+
"""
1159+
1160+
print(
1161+
f"""
1162+
{self.expt_type:=^80}
1163+
Formula: {self.formula}
1164+
Running variable: {self.running_variable_name}
1165+
Kink point on running variable: {self.kink_point}
1166+
1167+
Results:
1168+
Change in slope at kink point = {self.gradient_change.mean():.2f}
1169+
"""
1170+
)
1171+
self.print_coefficients()
1172+
1173+
9601174
class PrePostNEGD(ExperimentalDesign):
9611175
"""
9621176
A class to analyse data from pretest/posttest designs

0 commit comments

Comments
 (0)