Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
fe6375e
static panel model, data, and util
JulianDiefenbacher Jun 23, 2025
2031923
update plpr and dataset
JulianDiefenbacher Jun 26, 2025
5f6a374
add basic simulations
JulianDiefenbacher Jun 26, 2025
215a5e5
add plpr __str__, and checks
JulianDiefenbacher Jul 10, 2025
4189671
update example_sim
JulianDiefenbacher Jul 10, 2025
2c7e238
add model descriptions
JulianDiefenbacher Oct 9, 2025
7708f9b
fix typo
JulianDiefenbacher Oct 9, 2025
d5bf9e9
fix notation consistency
JulianDiefenbacher Oct 9, 2025
d8c3039
update description numbering
JulianDiefenbacher Oct 9, 2025
375325f
Merge branch 'main' into j-static-panel
JulianDiefenbacher Oct 27, 2025
753f68a
update from ClusterData to base Data class
JulianDiefenbacher Oct 27, 2025
058da4e
add static_panel flag in PanelData
JulianDiefenbacher Oct 27, 2025
85da2eb
add static_panel property
JulianDiefenbacher Nov 4, 2025
08cfd8c
add static_panel property and update tests for panel data handling
SvenKlaassen Nov 4, 2025
1ecab66
update plpr model, include data transformation
JulianDiefenbacher Nov 6, 2025
353e148
Merge branch 'main' into j-static-panel
JulianDiefenbacher Nov 6, 2025
5dc1a44
refactor: simplify string representation and add additional info meth…
SvenKlaassen Nov 7, 2025
fdc4330
correct score info string newline spacing
JulianDiefenbacher Nov 10, 2025
bccc81c
data transform update, add transform_col_names property
JulianDiefenbacher Nov 10, 2025
eb68557
add transformed data arrays for nuisance estimation
JulianDiefenbacher Nov 10, 2025
c404f06
add _initialize_fd_model because of n_obs and smpls issue
JulianDiefenbacher Nov 11, 2025
3ccfb34
clearer TODO description
JulianDiefenbacher Nov 11, 2025
ab9f83f
move data transformation before init
JulianDiefenbacher Nov 13, 2025
f27ebeb
update logic for cre_normal approach in estimation and tuning
JulianDiefenbacher Nov 13, 2025
011f8cb
add simulation replication
JulianDiefenbacher Nov 24, 2025
3f8eef6
update PLPR model
JulianDiefenbacher Nov 24, 2025
677c5a6
allow binary treatment for PLPR
JulianDiefenbacher Nov 24, 2025
7c7bd43
update plpr dgp function
JulianDiefenbacher Nov 27, 2025
ecce2b5
update make_plpr use
JulianDiefenbacher Nov 27, 2025
122e97d
Merge branch 'main' into j-static-panel
JulianDiefenbacher Nov 27, 2025
453cbf6
add basic plpr tests
JulianDiefenbacher Nov 30, 2025
86c947a
remove notebooks
JulianDiefenbacher Dec 1, 2025
7d01cd1
fix id var issue
JulianDiefenbacher Dec 1, 2025
7c24a99
complete plpr dataset tests
JulianDiefenbacher Dec 2, 2025
fb56a27
fix formatting
JulianDiefenbacher Dec 2, 2025
b15571c
add external pred test, complete model default test
JulianDiefenbacher Dec 3, 2025
7e72ad2
remove callable score option for plpr
JulianDiefenbacher Dec 4, 2025
630e736
Merge branch 'main' into j-static-panel
JulianDiefenbacher Dec 5, 2025
85eee84
Merge branch 'main' into j-static-panel
SvenKlaassen Dec 5, 2025
47b929e
add _nuisance_tuning_optuna to plpr with tests
SvenKlaassen Dec 5, 2025
e4fc658
refactor test_doubleml_plr_optuna_tune to use score parameter and imp…
SvenKlaassen Dec 5, 2025
dcf9e71
increase sample size in test_doubleml_lplr_optuna_tune to enable robu…
SvenKlaassen Dec 5, 2025
f5c6c50
refactor _nuisance_est method signature to reorder parameters for cla…
SvenKlaassen Dec 5, 2025
9e65836
Merge branch 'j-static-panel' of https://github.com/DoubleML/doubleml…
JulianDiefenbacher Dec 5, 2025
fcc7195
change dml import
JulianDiefenbacher Dec 5, 2025
19ef9b5
check plpr test dml import
JulianDiefenbacher Dec 5, 2025
85ac102
change dml import to avoid codacy fails
JulianDiefenbacher Dec 5, 2025
4e63915
adjust summary new line and additional info header spacing
JulianDiefenbacher Dec 8, 2025
bb1167c
add x_dim check to plpr data
JulianDiefenbacher Dec 8, 2025
0fe1c1b
add comparison of standard errors
JulianDiefenbacher Dec 8, 2025
d32f339
add test_plpr_binary_treatment.py
JulianDiefenbacher Dec 8, 2025
66ca454
add exceptions tests for non-panel and non-static panel data
JulianDiefenbacher Dec 8, 2025
eb75e1d
add plpr return types test and enhance return type checks for cluster…
JulianDiefenbacher Dec 8, 2025
ef6a3c9
add test for plpr data transformations
JulianDiefenbacher Dec 8, 2025
1df9cf2
improve documentation and improve data handling in plpr class
JulianDiefenbacher Dec 8, 2025
975b425
add validation for time variable data type in DoubleMLPanelData
JulianDiefenbacher Dec 9, 2025
6fd77f1
fix: correct test description for overlapping variables in DoubleMLPa…
JulianDiefenbacher Dec 9, 2025
e373473
add time_type parameter to make_plpr_CP2025 and enhance validation in…
JulianDiefenbacher Dec 9, 2025
7652ab6
update docstring
SvenKlaassen Dec 10, 2025
8094ed1
split variance for external predictions
SvenKlaassen Dec 10, 2025
5a09a67
fix: adjust indentation for math blocks in docstring of make_plpr_CP2025
SvenKlaassen Dec 10, 2025
4a0cc6a
fix: update mathematical notation in make_plpr_CP2025 to use g_0 inst…
SvenKlaassen Dec 10, 2025
6995bde
fix: correct mathematical notation in make_plpr_CP2025 to use '=' ins…
SvenKlaassen Dec 10, 2025
a3362dc
Merge branch 'main' into j-static-panel
SvenKlaassen Dec 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
data transform update, add transform_col_names property
  • Loading branch information
JulianDiefenbacher committed Nov 10, 2025
commit bccc81c43eccded590743317216fbcb60cbb53a5
4 changes: 2 additions & 2 deletions doubleml/plm/datasets/dgp_static_panel_CP2025.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def alpha_i(x_it, d_it, a_i, num_n, num_t):

x_cols = [f'x{i + 1}' for i in np.arange(dim_x)]

data = pd.DataFrame(np.column_stack((id, time, d_it, y_it, x_it)),
columns=['id', 'time', 'd', 'y'] + x_cols).astype({'id': 'int64', 'time': 'int64'})
data = pd.DataFrame(np.column_stack((id, time, y_it, d_it, x_it)),
columns=['id', 'time', 'y', 'd'] + x_cols).astype({'id': 'int64', 'time': 'int64'})

return data
93 changes: 51 additions & 42 deletions doubleml/plm/plpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class DoubleMLPLPR(LinearScoreMixin, DoubleML):
or a callable object / function with signature ``psi_a, psi_b = score(y, d, l_hat, m_hat, g_hat, smpls)``.
Default is ``'partialling out'``.

static_panel_approach : str
approach : str
A str (``'cre_general'``, ``'cre_normal'``, ``'fd_exact'``, ``'wg_approx'``) specifying the type of
static panel approach in Clarke and Polselli (2025).
Default is ``'fd_exact'``.
Expand All @@ -74,17 +74,17 @@ class DoubleMLPLPR(LinearScoreMixin, DoubleML):
"""

def __init__(
self, obj_dml_data, ml_l, ml_m, ml_g=None, n_folds=5, n_rep=1, score="partialling out", static_panel_approach="fd_exact", draw_sample_splitting=True):
self, obj_dml_data, ml_l, ml_m, ml_g=None, n_folds=5, n_rep=1, score="partialling out", approach="fd_exact", draw_sample_splitting=True):
super().__init__(obj_dml_data, n_folds, n_rep, score, draw_sample_splitting)

self._check_data(self._dml_data)
# TODO: assert cluster?
valid_scores = ["IV-type", "partialling out"]
_check_score(self.score, valid_scores, allow_callable=True)

valid_static_panel_approach = ["cre_general", "cre_normal", "fd_exact", "wg_approx"]
self._check_static_panel_approach(static_panel_approach, valid_static_panel_approach)
self._static_panel_approach = static_panel_approach
valid_approach = ["cre_general", "cre_normal", "fd_exact", "wg_approx"]
self._check_approach(approach, valid_approach)
self._approach = approach

_ = self._check_learner(ml_l, "ml_l", regressor=True, classifier=False)
ml_m_is_classifier = self._check_learner(ml_m, "ml_m", regressor=True, classifier=True)
Expand Down Expand Up @@ -127,15 +127,15 @@ def __init__(
# Get transformed data depending on approach
# TODO: get y, x, d cols, set additional properties for y_data, d_data, x_data to be used in
# nuisance
self._data_transform = self._transform_data(self._static_panel_approach)
self._data_transform, self._transform_col_names, = self._transform_data(self._approach)


def _format_score_info_str(self):
score_static_panel_approach_info = (
score_approach_info = (
f"Score function: {str(self.score)}\n"
f"Static panel model approach: {str(self.static_panel_approach)}"
f"Static panel model approach: {str(self.approach)}"
)
return score_static_panel_approach_info
return score_approach_info

def _format_additional_info_str(self):
"""
Expand All @@ -145,19 +145,26 @@ def _format_additional_info_str(self):
return ""

@property
def static_panel_approach(self):
def approach(self):
"""
The static panel approach.
"""
return self._static_panel_approach
return self._approach

@property
def data_transform(self):
"""
The transformed static panel data.
"""
return self._data_transform


@property
def transform_col_names(self):
"""
The column names of the transformed static panel data.
"""
return self._transform_col_names

def _initialize_ml_nuisance_params(self):
self._params = {learner: {key: [None] * self.n_rep for key in self._dml_data.d_cols} for learner in self._learner}

Expand All @@ -177,16 +184,16 @@ def _check_data(self, obj_dml_data):
)
return

def _check_static_panel_approach(self, static_panel_approach, valid_static_panel_approach):
if isinstance(static_panel_approach, str):
if static_panel_approach not in valid_static_panel_approach:
raise ValueError("Invalid static_panel_approach " + static_panel_approach + ". " + "Valid approach " + " or ".join(valid_static_panel_approach) + ".")
def _check_approach(self, approach, valid_approach):
if isinstance(approach, str):
if approach not in valid_approach:
raise ValueError("Invalid approach " + approach + ". " + "Valid approach " + " or ".join(valid_approach) + ".")
else:
raise TypeError(f"static_panel_approach should be a string. {str(static_panel_approach)} was passed.")
raise TypeError(f"approach should be a string. {str(approach)} was passed.")
return

# TODO: preprocess and transform data based on static_panel_approach (cre, fd, wg)
def _transform_data(self, static_panel_approach):
# TODO: preprocess and transform data based on approach (cre, fd, wg)
def _transform_data(self, approach):
df = self._dml_data.data.copy()

y_col = self._dml_data.y_col
Expand All @@ -195,34 +202,36 @@ def _transform_data(self, static_panel_approach):
t_col = self._dml_data.t_col
id_col = self._dml_data.id_col

if static_panel_approach in ["cre_general", "cre_normal"]:
# uses regular y_col, d_cols, x_cols + m_x_cols
df_id_means = df[[id_col] + d_cols + x_cols].groupby(id_col).transform("mean")
df_means = df_id_means.add_prefix("mean_")
if approach in ["cre_general", "cre_normal"]:
df_id_means = df[[id_col] + x_cols].groupby(id_col).transform("mean")
df_means = df_id_means.add_suffix("_mean")
data = pd.concat([df, df_means], axis=1)
# {"y_col": y_col, "d_cols": d_cols, "x_cols": x_cols + [f"m_{x}"" for x in x_cols]}
elif static_panel_approach == "fd_exact":
col_names = {"y_col": y_col, "d_cols": d_cols, "x_cols": x_cols + [f"{x}_mean" for x in x_cols]}
elif approach == "fd_exact":
# TODO: potential issues with unbalanced panels/missing periods, right now the
# last available is used as the lag and for diff. Maybe reindex to a complete time grid per id.
# uses y_col_diff, d_cols_diff, x_cols + x_cols_lag
# last available is used for the lag and first difference. Maybe reindex to a complete time grid per id.
df = df.sort_values([id_col, t_col])
shifted = df[[id_col] + x_cols].groupby(id_col).shift(1).add_suffix("_lag")
first_diff = df[[id_col] + d_cols + [y_col]].groupby(id_col).diff().add_suffix("_diff")
df_fd = pd.concat([df, shifted, first_diff], axis=1)
first_diff = df[[id_col] + [y_col] + d_cols].groupby(id_col).diff().add_suffix("_diff")
df_fd = pd.concat([df, shifted], axis=1)
# replace original y and d columns for first-difference transformations, rename
df_fd[[y_col] + d_cols] = first_diff
cols_rename_dict = {y_col: f"{y_col}_diff"} | {col: f"{col}_diff" for col in d_cols}
df_fd = df_fd.rename(columns=cols_rename_dict)
# drop rows for first period
data = df_fd.dropna(subset=[x_cols[0] + "_lag"]).reset_index(drop=True)
# {"y_col": f"{y_col}_diff", "d_cols": [f"{d}_diff" for d in d_cols], "x_cols": x_cols + [f"{x}_lag" for x in x_cols]}
elif static_panel_approach == "wg_approx":
# uses y_col, d_cols, x_cols
df_demean = df.drop(t_col, axis=1).groupby(id_col).transform(lambda x: x - x.mean())
# add grand means
grand_means = df.drop([id_col, t_col], axis=1).mean()
within_means = df_demean + grand_means
col_names = {"y_col": f"{y_col}_diff", "d_cols": [f"{d}_diff" for d in d_cols], "x_cols": x_cols + [f"{x}_lag" for x in x_cols]}
elif approach == "wg_approx":
cols_to_demean = [y_col] + d_cols + x_cols
# compute group and grand means for within means
group_means = df.groupby(id_col)[cols_to_demean].transform('mean')
grand_means = df[cols_to_demean].mean()
within_means = df[cols_to_demean] - group_means + grand_means
within_means = within_means.add_suffix("_demean")
data = pd.concat([df[[id_col, t_col]], within_means], axis=1)
# {"y_col": y_col, "d_cols": d_cols, "x_cols": x_cols}
else:
raise ValueError(f"Invalid static_panel_approach.")
col_names = {"y_col": f"{y_col}_demean", "d_cols": [f"{d}_demean" for d in d_cols], "x_cols": [f"{x}_demean" for x in x_cols]}

return data
return data, col_names

def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=False):
x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False)
Expand Down Expand Up @@ -258,7 +267,7 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
else:
# TODO: update this section
# cre using m_d + x for m_hat, otherwise only x
if self._static_panel_approach == "cre_normal":
if self._approach == "cre_normal":
help_data = pd.DataFrame({"id": self._dml_data.cluster_vars[:, 0], "d": d})
m_d = help_data.groupby(["id"]).transform("mean").values
x = np.column_stack((x, m_d))
Expand All @@ -275,7 +284,7 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
)

# general cre adjustment
if self._static_panel_approach == "cre_general":
if self._approach == "cre_general":
help_data = pd.DataFrame({"id": self._dml_data.cluster_vars[:, 0], "m_hat": m_hat["preds"], "d": d})
group_means = help_data.groupby(["id"])[["m_hat", "d"]].transform("mean")
m_hat_star = m_hat["preds"] + group_means["d"] - group_means["m_hat"]
Expand Down
Loading