Skip to content
Draft
Changes from 1 commit
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
fe6375e
static panel model, data, and util
JulianDiefenbacher Jun 23, 2025
2031923
update plpr and dataset
JulianDiefenbacher Jun 26, 2025
5f6a374
add basic simulations
JulianDiefenbacher Jun 26, 2025
215a5e5
add plpr __str__, and checks
JulianDiefenbacher Jul 10, 2025
4189671
update example_sim
JulianDiefenbacher Jul 10, 2025
2c7e238
add model descriptions
JulianDiefenbacher Oct 9, 2025
7708f9b
fix typo
JulianDiefenbacher Oct 9, 2025
d5bf9e9
fix notation consistency
JulianDiefenbacher Oct 9, 2025
d8c3039
update description numbering
JulianDiefenbacher Oct 9, 2025
375325f
Merge branch 'main' into j-static-panel
JulianDiefenbacher Oct 27, 2025
753f68a
update from ClusterData to base Data class
JulianDiefenbacher Oct 27, 2025
058da4e
add static_panel flag in PanelData
JulianDiefenbacher Oct 27, 2025
85da2eb
add static_panel property
JulianDiefenbacher Nov 4, 2025
08cfd8c
add static_panel property and update tests for panel data handling
SvenKlaassen Nov 4, 2025
1ecab66
update plpr model, include data transformation
JulianDiefenbacher Nov 6, 2025
353e148
Merge branch 'main' into j-static-panel
JulianDiefenbacher Nov 6, 2025
5dc1a44
refactor: simplify string representation and add additional info meth…
SvenKlaassen Nov 7, 2025
fdc4330
correct score info string newline spacing
JulianDiefenbacher Nov 10, 2025
bccc81c
data transform update, add transform_col_names property
JulianDiefenbacher Nov 10, 2025
eb68557
add transformed data arrays for nuisance estimation
JulianDiefenbacher Nov 10, 2025
c404f06
add _initialize_fd_model because of n_obs and smpls issue
JulianDiefenbacher Nov 11, 2025
3ccfb34
clearer TODO description
JulianDiefenbacher Nov 11, 2025
ab9f83f
move data transformation before init
JulianDiefenbacher Nov 13, 2025
f27ebeb
update logic for cre_normal approach in estimation and tuning
JulianDiefenbacher Nov 13, 2025
011f8cb
add simulation replication
JulianDiefenbacher Nov 24, 2025
3f8eef6
update PLPR model
JulianDiefenbacher Nov 24, 2025
677c5a6
allow binary treatment for PLPR
JulianDiefenbacher Nov 24, 2025
7c7bd43
update plpr dgp function
JulianDiefenbacher Nov 27, 2025
ecce2b5
update make_plpr use
JulianDiefenbacher Nov 27, 2025
122e97d
Merge branch 'main' into j-static-panel
JulianDiefenbacher Nov 27, 2025
453cbf6
add basic plpr tests
JulianDiefenbacher Nov 30, 2025
86c947a
remove notebooks
JulianDiefenbacher Dec 1, 2025
7d01cd1
fix id var issue
JulianDiefenbacher Dec 1, 2025
7c24a99
complete plpr dataset tests
JulianDiefenbacher Dec 2, 2025
fb56a27
fix formatting
JulianDiefenbacher Dec 2, 2025
b15571c
add external pred test, complete model default test
JulianDiefenbacher Dec 3, 2025
7e72ad2
remove callable score option for plpr
JulianDiefenbacher Dec 4, 2025
630e736
Merge branch 'main' into j-static-panel
JulianDiefenbacher Dec 5, 2025
85eee84
Merge branch 'main' into j-static-panel
SvenKlaassen Dec 5, 2025
47b929e
add _nuisance_tuning_optuna to plpr with tests
SvenKlaassen Dec 5, 2025
e4fc658
refactor test_doubleml_plr_optuna_tune to use score parameter and imp…
SvenKlaassen Dec 5, 2025
dcf9e71
increase sample size in test_doubleml_lplr_optuna_tune to enable robu…
SvenKlaassen Dec 5, 2025
f5c6c50
refactor _nuisance_est method signature to reorder parameters for cla…
SvenKlaassen Dec 5, 2025
9e65836
Merge branch 'j-static-panel' of https://github.com/DoubleML/doubleml…
JulianDiefenbacher Dec 5, 2025
fcc7195
change dml import
JulianDiefenbacher Dec 5, 2025
19ef9b5
check plpr test dml import
JulianDiefenbacher Dec 5, 2025
85ac102
change dml import to avoid codacy fails
JulianDiefenbacher Dec 5, 2025
4e63915
adjust summary new line and additional info header spacing
JulianDiefenbacher Dec 8, 2025
bb1167c
add x_dim check to plpr data
JulianDiefenbacher Dec 8, 2025
0fe1c1b
add comparison of standard errors
JulianDiefenbacher Dec 8, 2025
d32f339
add test_plpr_binary_treatment.py
JulianDiefenbacher Dec 8, 2025
66ca454
add exceptions tests for non-panel and non-static panel data
JulianDiefenbacher Dec 8, 2025
eb75e1d
add plpr return types test and enhance return type checks for cluster…
JulianDiefenbacher Dec 8, 2025
ef6a3c9
add test for plpr data transformations
JulianDiefenbacher Dec 8, 2025
1df9cf2
improve documentation and improve data handling in plpr class
JulianDiefenbacher Dec 8, 2025
975b425
add validation for time variable data type in DoubleMLPanelData
JulianDiefenbacher Dec 9, 2025
6fd77f1
fix: correct test description for overlapping variables in DoubleMLPa…
JulianDiefenbacher Dec 9, 2025
e373473
add time_type parameter to make_plpr_CP2025 and enhance validation in…
JulianDiefenbacher Dec 9, 2025
7652ab6
update docstring
SvenKlaassen Dec 10, 2025
8094ed1
split variance for external predictions
SvenKlaassen Dec 10, 2025
5a09a67
fix: adjust indentation for math blocks in docstring of make_plpr_CP2025
SvenKlaassen Dec 10, 2025
4a0cc6a
fix: update mathematical notation in make_plpr_CP2025 to use g_0 inst…
SvenKlaassen Dec 10, 2025
6995bde
fix: correct mathematical notation in make_plpr_CP2025 to use '=' ins…
SvenKlaassen Dec 10, 2025
a3362dc
Merge branch 'main' into j-static-panel
SvenKlaassen Dec 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor test_doubleml_plr_optuna_tune to use score parameter and imp…
…rove readability
  • Loading branch information
SvenKlaassen committed Dec 5, 2025
commit e4fc6580f596ee2d5724c895bd9fcf1cb92a8258
62 changes: 25 additions & 37 deletions doubleml/plm/tests/test_plr_tune_ml_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,29 @@
)


@pytest.fixture(scope="module", params=["partialling out", "IV-type"])
def score(request):
    """Module-scoped fixture cycling through the supported PLR score variants."""
    score_name = request.param
    return score_name


@pytest.mark.ci
@pytest.mark.parametrize("sampler_name,optuna_sampler", _SAMPLER_CASES, ids=[case[0] for case in _SAMPLER_CASES])
def test_doubleml_plr_optuna_tune(sampler_name, optuna_sampler, score):
    """Tune a DoubleMLPLR model via optuna and verify the tuning results.

    Runs for every optuna sampler in ``_SAMPLER_CASES`` and for both score
    variants supplied by the ``score`` fixture ("partialling out" and
    "IV-type"); the IV-type score additionally requires an ``ml_g`` learner.
    Checks that the tuning results expose ``best_params`` for each tuned
    learner and that tuning reduces each learner's RMSE.
    """
    np.random.seed(3141)
    alpha = 0.5
    dml_data = make_plr_CCDDHNR2018(n_obs=500, dim_x=5, alpha=alpha)

    ml_l = DecisionTreeRegressor(random_state=123)
    ml_m = DecisionTreeRegressor(random_state=456)
    # ml_g is only used by the IV-type score; pass None otherwise.
    ml_g = DecisionTreeRegressor(random_state=789) if score == "IV-type" else None

    dml_plr = dml.DoubleMLPLR(dml_data, ml_l, ml_m, ml_g=ml_g, n_folds=2, score=score)
    dml_plr.fit()
    untuned_score = dml_plr.evaluate_learners()

    # parameter spaces to tune: ml_g joins only for the IV-type score
    optuna_params = {"ml_l": _small_tree_params, "ml_m": _small_tree_params}
    if score == "IV-type":
        optuna_params["ml_g"] = _small_tree_params

    # NOTE(review): the settings dict below was collapsed in the diff view —
    # confirm the exact keys against the full file.
    tune_res = dml_plr.tune_ml_models(
        ml_param_space=optuna_params,
        optuna_settings=_basic_optuna_settings({"sampler": optuna_sampler}),
        return_tune_res=True,
    )
    dml_plr.fit()
    tuned_score = dml_plr.evaluate_learners()

    # ensure results contain optuna objects and best params for every tuned
    # learner; compare against the tuning request (optuna_params) rather than
    # against values read back from tune_res itself, which would be circular
    assert isinstance(tune_res[0], dict)
    assert set(tune_res[0].keys()) == set(optuna_params.keys())

    for params_name in optuna_params:
        learner_res = tune_res[0][params_name]
        assert hasattr(learner_res, "best_params")
        _assert_tree_params(learner_res.best_params)

    # ensure tuning improved RMSE
    assert tuned_score["ml_l"] < untuned_score["ml_l"]
    assert tuned_score["ml_m"] < untuned_score["ml_m"]
    if score == "IV-type":
        assert tuned_score["ml_g"] < untuned_score["ml_g"]