diff --git a/.gitignore b/.gitignore index 24f7c5cc..93d4dfaf 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ __pycache__/ +.idea/ # Logs monte-cover/logs/ diff --git a/doc/_quarto-dev.yml b/doc/_quarto-dev.yml index 5c3587ab..b73319b3 100644 --- a/doc/_quarto-dev.yml +++ b/doc/_quarto-dev.yml @@ -21,6 +21,7 @@ website: - plm/plr_gate.qmd - plm/plr_cate.qmd - plm/pliv.qmd + - plm/lplr.qmd # DID - did/did_pa.qmd - did/did_cs.qmd diff --git a/doc/_website.yml b/doc/_website.yml index 4bf06b85..98c2a044 100644 --- a/doc/_website.yml +++ b/doc/_website.yml @@ -25,6 +25,7 @@ website: - plm/plr_gate.qmd - plm/plr_cate.qmd - plm/pliv.qmd + - plm/lplr.qmd - text: "DID" menu: - did/did_pa_multi.qmd diff --git a/doc/plm/lplr.qmd b/doc/plm/lplr.qmd new file mode 100644 index 00000000..200e5782 --- /dev/null +++ b/doc/plm/lplr.qmd @@ -0,0 +1,113 @@ +--- +title: "Logistic Partial Linear Regression Models" + +jupyter: python3 +--- + +```{python} +#| echo: false + +import numpy as np +import pandas as pd +from itables import init_notebook_mode +import os +import sys + +doc_dir = os.path.abspath(os.path.join(os.getcwd(), "..")) +if doc_dir not in sys.path: + sys.path.append(doc_dir) + +from utils.style_tables import generate_and_show_styled_table + +init_notebook_mode(all_interactive=True) +``` + +## ATE Coverage + +The simulations are based on the the [make_lplr_LZZ2020](https://docs.doubleml.org/stable/api/generated/doubleml.datasets.make_lplr_LZZ2020.html)-DGP with $500$ observations. + +::: {.callout-note title="Metadata" collapse="true"} + +```{python} +#| echo: false +metadata_file = '../../results/plm/lplr_ate_metadata.csv' +metadata_df = pd.read_csv(metadata_file) +print(metadata_df.T.to_string(header=False)) +``` + +::: + +```{python} +#| echo: false + +# set up data and rename columns +df_coverage = pd.read_csv("../../results/plm/lplr_ate_coverage.csv", index_col=None) + +if "repetition" in df_coverage.columns and df_coverage["repetition"].nunique() == 1: + n_rep_coverage = df_coverage["repetition"].unique()[0] +elif "n_rep" in df_coverage.columns and df_coverage["n_rep"].nunique() == 1: + n_rep_coverage = df_coverage["n_rep"].unique()[0] +else: + n_rep_coverage = "N/A" # Fallback if n_rep cannot be determined + +display_columns_coverage = ["Learner m", "Learner M", "Learner t", "Bias", "CI Length", "Coverage"] +``` + +### Nuisance space + +```{python} +# | echo: false + +generate_and_show_styled_table( + main_df=df_coverage, + filters={"level": 0.95, "Score": "nuisance_space"}, + display_cols=display_columns_coverage, + n_rep=n_rep_coverage, + level_col="level", +# rename_map={"Learner g": "Learner l"}, + coverage_highlight_cols=["Coverage"] +) +``` + +```{python} +#| echo: false + +generate_and_show_styled_table( + main_df=df_coverage, + filters={"level": 0.9, "Score": "nuisance_space"}, + display_cols=display_columns_coverage, + n_rep=n_rep_coverage, + level_col="level", +# rename_map={"Learner g": "Learner l"}, + coverage_highlight_cols=["Coverage"] +) +``` + +### Instrument + + +```{python} +#| echo: false + +generate_and_show_styled_table( + main_df=df_coverage, + filters={"level": 0.95, "Score": "instrument"}, + display_cols=display_columns_coverage, + n_rep=n_rep_coverage, + level_col="level", + coverage_highlight_cols=["Coverage"] +) +``` + +```{python} +#| echo: false + +generate_and_show_styled_table( + main_df=df_coverage, + filters={"level": 0.9, "Score": "instrument"}, + display_cols=display_columns_coverage, + n_rep=n_rep_coverage, + level_col="level", + coverage_highlight_cols=["Coverage"] +) +``` \ No newline at end of file diff --git a/monte-cover/src/montecover/plm/__init__.py b/monte-cover/src/montecover/plm/__init__.py index 167b36d8..5d995c92 100644 --- a/monte-cover/src/montecover/plm/__init__.py +++ b/monte-cover/src/montecover/plm/__init__.py @@ -5,6 +5,7 @@ from montecover.plm.plr_ate_sensitivity import PLRATESensitivityCoverageSimulation from montecover.plm.plr_cate import PLRCATECoverageSimulation from montecover.plm.plr_gate import PLRGATECoverageSimulation +from montecover.plm.lplr_ate import LPLRATECoverageSimulation __all__ = [ "PLRATECoverageSimulation", @@ -12,4 +13,5 @@ "PLRGATECoverageSimulation", "PLRCATECoverageSimulation", "PLRATESensitivityCoverageSimulation", + "LPLRATECoverageSimulation", ] diff --git a/monte-cover/src/montecover/plm/lplr_ate.py b/monte-cover/src/montecover/plm/lplr_ate.py new file mode 100644 index 00000000..da962e32 --- /dev/null +++ b/monte-cover/src/montecover/plm/lplr_ate.py @@ -0,0 +1,126 @@ +import warnings +from typing import Any, Dict, Optional + +import doubleml as dml +from doubleml.plm.datasets import make_lplr_LZZ2020 + +from montecover.base import BaseSimulation +from montecover.utils import create_learner_from_config + + +class LPLRATECoverageSimulation(BaseSimulation): + """Simulation class for coverage properties of DoubleMLPLR for ATE estimation.""" + + def __init__( + self, + config_file: str, + suppress_warnings: bool = True, + log_level: str = "INFO", + log_file: Optional[str] = None, + use_failed_scores: bool = False, + ): + super().__init__( + config_file=config_file, + suppress_warnings=suppress_warnings, + log_level=log_level, + log_file=log_file, + ) + + # Calculate oracle values + self._calculate_oracle_values() + + self._use_failed_scores = use_failed_scores + + def _process_config_parameters(self): + """Process simulation-specific parameters from config""" + # Process ML models in parameter grid + assert "learners" in self.dml_parameters, "No learners specified in the config file" + + required_learners = ["ml_m", "ml_M", "ml_t"] + for learner in self.dml_parameters["learners"]: + for ml in required_learners: + assert ml in learner, f"No {ml} specified in the config file" + + def _calculate_oracle_values(self): + """Calculate oracle values for the simulation.""" + self.logger.info("Calculating oracle values") + + self.oracle_values = dict() + self.oracle_values["theta"] = self.dgp_parameters["theta"] + + def run_single_rep(self, dml_data, dml_params) -> Dict[str, Any]: + """Run a single repetition with the given parameters.""" + # Extract parameters + learner_config = dml_params["learners"] + learner_m_name, ml_m = create_learner_from_config(learner_config["ml_m"]) + learner_M_name, ml_M = create_learner_from_config(learner_config["ml_M"]) + learner_t_name, ml_t = create_learner_from_config(learner_config["ml_t"]) + score = dml_params["score"] + + # Model + dml_model = dml.DoubleMLLPLR( + obj_dml_data=dml_data, + ml_m=ml_m, + ml_M=ml_M, + ml_t=ml_t, + score=score, + error_on_convergence_failure= not self._use_failed_scores,) + + try: + dml_model.fit() + except RuntimeError as e: + self.logger.info(f"Exception during fit: {e}") + return None + + result = { + "coverage": [], + } + for level in self.confidence_parameters["level"]: + level_result = dict() + level_result["coverage"] = self._compute_coverage( + thetas=dml_model.coef, + oracle_thetas=self.oracle_values["theta"], + confint=dml_model.confint(level=level), + joint_confint=None, + ) + + # add parameters to the result + for res in level_result.values(): + res.update( + { + "Learner m": learner_m_name, + "Learner M": learner_M_name, + "Learner t": learner_t_name, + "Score": score, + "level": level, + } + ) + for key, res in level_result.items(): + result[key].append(res) + + return result + + def summarize_results(self): + """Summarize the simulation results.""" + self.logger.info("Summarizing simulation results") + + # Group by parameter combinations + groupby_cols = ["Learner m", "Learner M", "Learner t", "Score", "level"] + aggregation_dict = { + "Coverage": "mean", + "CI Length": "mean", + "Bias": "mean", + "repetition": "count", + } + + # Aggregate results (possibly multiple result dfs) + result_summary = dict() + for result_name, result_df in self.results.items(): + result_summary[result_name] = result_df.groupby(groupby_cols).agg(aggregation_dict).reset_index() + self.logger.debug(f"Summarized {result_name} results") + + return result_summary + + def _generate_dml_data(self, dgp_params) -> dml.DoubleMLData: + """Generate data for the simulation.""" + return make_lplr_LZZ2020(**dgp_params) diff --git a/results/plm/logistic_ate_config.yml b/results/plm/logistic_ate_config.yml new file mode 100644 index 00000000..b203b920 --- /dev/null +++ b/results/plm/logistic_ate_config.yml @@ -0,0 +1,91 @@ +simulation_parameters: + repetitions: 1000 + max_runtime: 86400 + random_seed: 42 + n_jobs: -2 +dgp_parameters: + theta: + - 0.5 + n_obs: + - 500 + dim_x: + - 20 +learner_definitions: + lasso: &id001 + name: LassoCV + logistic: &id002 + name: Logistic + rf: &id003 + name: RF Regr. + params: + n_estimators: 100 + max_features: sqrt + rf-class: &id004 + name: RF Clas. + params: + n_estimators: 100 + max_features: sqrt + lgbm: &id005 + name: LGBM Regr. + params: + n_estimators: 500 + learning_rate: 0.01 + lgbm-class: &id006 + name: LGBM Clas. + params: + n_estimators: 500 + learning_rate: 0.01 +dml_parameters: + learners: + - ml_m: *id001 + ml_M: *id002 + ml_t: *id001 + - ml_m: *id003 + ml_M: *id004 + ml_t: *id003 + - ml_m: *id005 + ml_M: *id006 + ml_t: *id005 + - ml_m: *id003 + ml_M: *id006 + ml_t: *id005 + - ml_m: *id005 + ml_M: *id004 + ml_t: *id005 + - ml_m: *id005 + ml_M: *id006 + ml_t: *id003 + - ml_m: *id005 + ml_M: *id004 + ml_t: *id003 + - ml_m: *id003 + ml_M: *id006 + ml_t: *id003 + - ml_m: *id003 + ml_M: *id004 + ml_t: *id005 + - ml_m: *id001 + ml_M: *id006 + ml_t: *id005 + - ml_m: *id005 + ml_M: *id002 + ml_t: *id005 + - ml_m: *id005 + ml_M: *id006 + ml_t: *id001 + - ml_m: *id001 + ml_M: *id004 + ml_t: *id003 + - ml_m: *id003 + ml_M: *id002 + ml_t: *id003 + - ml_m: *id003 + ml_M: *id004 + ml_t: *id001 + score: + - nuisance_space + - instrument +confidence_parameters: + level: + - 0.95 + - 0.9 diff --git a/results/plm/logistic_ate_coverage.csv b/results/plm/logistic_ate_coverage.csv new file mode 100644 index 00000000..920c3cf8 --- /dev/null +++ b/results/plm/logistic_ate_coverage.csv @@ -0,0 +1,61 @@ +Learner m,Learner M,Learner t,Score,level,Coverage,CI Length,Bias,repetition +LGBM Regr.,LGBM Clas.,LGBM Regr.,instrument,0.9,0.8867735470941884,0.6783720219284418,0.17182702238154213,998 +LGBM Regr.,LGBM Clas.,LGBM Regr.,instrument,0.95,0.9458917835671342,0.8083301208774294,0.17182702238154213,998 +LGBM Regr.,LGBM Clas.,LGBM Regr.,nuisance_space,0.9,0.886,0.5883608609896965,0.1546569991698314,1000 +LGBM Regr.,LGBM Clas.,LGBM Regr.,nuisance_space,0.95,0.942,0.7010752072754521,0.1546569991698314,1000 +LGBM Regr.,LGBM Clas.,LassoCV,instrument,0.9,0.8856569709127382,0.687914636116578,0.17843968090261725,997 +LGBM Regr.,LGBM Clas.,LassoCV,instrument,0.95,0.9398194583751254,0.819700847014181,0.17843968090261725,997 +LGBM Regr.,LGBM Clas.,LassoCV,nuisance_space,0.9,0.853,0.613277414594929,0.17455974016950299,1000 +LGBM Regr.,LGBM Clas.,LassoCV,nuisance_space,0.95,0.922,0.7307651121307722,0.17455974016950299,1000 +LGBM Regr.,LGBM Clas.,RF Regr.,instrument,0.9,0.833,0.6645257584558233,0.1981803920481237,1000 +LGBM Regr.,LGBM Clas.,RF Regr.,instrument,0.95,0.913,0.7918312803227949,0.1981803920481237,1000 +LGBM Regr.,LGBM Clas.,RF Regr.,nuisance_space,0.9,0.749,0.6389887792744618,0.2310882489727634,1000 +LGBM Regr.,LGBM Clas.,RF Regr.,nuisance_space,0.95,0.847,0.7614020927955242,0.2310882489727634,1000 +LGBM Regr.,Logistic,LGBM Regr.,instrument,0.9,0.8808808808808809,0.6011544597174262,0.15730144394486342,999 +LGBM Regr.,Logistic,LGBM Regr.,instrument,0.95,0.9269269269269269,0.7163197204212697,0.15730144394486342,999 +LGBM Regr.,Logistic,LGBM Regr.,nuisance_space,0.9,0.802,0.533982278217265,0.1735015501567642,1000 +LGBM Regr.,Logistic,LGBM Regr.,nuisance_space,0.95,0.893,0.6362791293643562,0.1735015501567642,1000 +LGBM Regr.,RF Clas.,LGBM Regr.,instrument,0.9,0.8808808808808809,0.6117037321129385,0.14924058625395906,999 +LGBM Regr.,RF Clas.,LGBM Regr.,instrument,0.95,0.938938938938939,0.7288899537961552,0.14924058625395906,999 +LGBM Regr.,RF Clas.,LGBM Regr.,nuisance_space,0.9,0.887,0.5255256282131954,0.12946206156000842,1000 +LGBM Regr.,RF Clas.,LGBM Regr.,nuisance_space,0.95,0.948,0.6262024093655342,0.12946206156000842,1000 +LGBM Regr.,RF Clas.,RF Regr.,instrument,0.9,0.893,0.6133564813843166,0.15711608477124128,1000 +LGBM Regr.,RF Clas.,RF Regr.,instrument,0.95,0.943,0.7308593260213176,0.15711608477124128,1000 +LGBM Regr.,RF Clas.,RF Regr.,nuisance_space,0.9,0.86,0.5540472193413977,0.15675464483344737,1000 +LGBM Regr.,RF Clas.,RF Regr.,nuisance_space,0.95,0.935,0.6601879813806316,0.15675464483344737,1000 +LassoCV,LGBM Clas.,LGBM Regr.,instrument,0.9,0.8062563067608476,0.6448097763855765,0.19653637418785105,991 +LassoCV,LGBM Clas.,LGBM Regr.,instrument,0.95,0.8890010090817356,0.7683382386658661,0.19653637418785105,991 +LassoCV,LGBM Clas.,LGBM Regr.,nuisance_space,0.9,0.72165991902834,0.5619651019188039,0.19918381058581103,988 +LassoCV,LGBM Clas.,LGBM Regr.,nuisance_space,0.95,0.840080971659919,0.6696227203940329,0.19918381058581103,988 +LassoCV,Logistic,LassoCV,instrument,0.9,0.9126506024096386,0.6493687054509357,0.15965331285568357,996 +LassoCV,Logistic,LassoCV,instrument,0.95,0.9618473895582329,0.7737705377043753,0.15965331285568357,996 +LassoCV,Logistic,LassoCV,nuisance_space,0.9,0.8682092555331992,0.5768393638614188,0.1458288654760023,994 +LassoCV,Logistic,LassoCV,nuisance_space,0.95,0.9356136820925554,0.6873464966781094,0.1458288654760023,994 +LassoCV,RF Clas.,RF Regr.,instrument,0.9,0.8667334669338678,0.5890487369844828,0.14213629243588016,998 +LassoCV,RF Clas.,RF Regr.,instrument,0.95,0.93687374749499,0.7018948620784813,0.14213629243588016,998 +LassoCV,RF Clas.,RF Regr.,nuisance_space,0.9,0.8908908908908909,0.5583249926493753,0.13040987029805642,999 +LassoCV,RF Clas.,RF Regr.,nuisance_space,0.95,0.9369369369369369,0.6652852626707622,0.13040987029805642,999 +RF Regr.,LGBM Clas.,LGBM Regr.,instrument,0.9,0.883,0.4286586066458282,0.10700456800013383,1000 +RF Regr.,LGBM Clas.,LGBM Regr.,instrument,0.95,0.939,0.510778233955119,0.10700456800013383,1000 +RF Regr.,LGBM Clas.,LGBM Regr.,nuisance_space,0.9,0.798,0.3832967523848996,0.11829755780901112,1000 +RF Regr.,LGBM Clas.,LGBM Regr.,nuisance_space,0.95,0.871,0.45672625074725515,0.11829755780901112,1000 +RF Regr.,LGBM Clas.,RF Regr.,instrument,0.9,0.866,0.42225079909506574,0.11434483968291848,1000 +RF Regr.,LGBM Clas.,RF Regr.,instrument,0.95,0.919,0.5031428603184782,0.11434483968291848,1000 +RF Regr.,LGBM Clas.,RF Regr.,nuisance_space,0.9,0.881,0.41648308996281536,0.10985709399222088,1000 +RF Regr.,LGBM Clas.,RF Regr.,nuisance_space,0.95,0.938,0.49627021099133717,0.10985709399222088,1000 +RF Regr.,Logistic,RF Regr.,instrument,0.9,0.856,0.38502789712056834,0.10721182765222284,1000 +RF Regr.,Logistic,RF Regr.,instrument,0.95,0.92,0.45878903692977124,0.10721182765222284,1000 +RF Regr.,Logistic,RF Regr.,nuisance_space,0.9,0.824,0.3771933481281758,0.11331805384094351,1000 +RF Regr.,Logistic,RF Regr.,nuisance_space,0.95,0.9,0.4494535960074909,0.11331805384094351,1000 +RF Regr.,RF Clas.,LGBM Regr.,instrument,0.9,0.828,0.38946263148586363,0.11262093701887263,1000 +RF Regr.,RF Clas.,LGBM Regr.,instrument,0.95,0.884,0.46407334885550183,0.11262093701887263,1000 +RF Regr.,RF Clas.,LGBM Regr.,nuisance_space,0.9,0.804,0.36190660207697933,0.10722868220974552,1000 +RF Regr.,RF Clas.,LGBM Regr.,nuisance_space,0.95,0.867,0.4312383145926426,0.10722868220974552,1000 +RF Regr.,RF Clas.,LassoCV,instrument,0.9,0.859,0.39360445751539874,0.10201463510531926,1000 +RF Regr.,RF Clas.,LassoCV,instrument,0.95,0.922,0.4690086389719632,0.10201463510531926,1000 +RF Regr.,RF Clas.,LassoCV,nuisance_space,0.9,0.847,0.37185525976227807,0.097545400580116,1000 +RF Regr.,RF Clas.,LassoCV,nuisance_space,0.95,0.905,0.44309287139830933,0.097545400580116,1000 +RF Regr.,RF Clas.,RF Regr.,instrument,0.9,0.885,0.3931395611851874,0.09840536307939636,1000 +RF Regr.,RF Clas.,RF Regr.,instrument,0.95,0.94,0.4684546808270991,0.09840536307939636,1000 +RF Regr.,RF Clas.,RF Regr.,nuisance_space,0.9,0.877,0.3834497709276788,0.09720459767352349,1000 +RF Regr.,RF Clas.,RF Regr.,nuisance_space,0.95,0.934,0.4569085835870289,0.09720459767352349,1000 diff --git a/results/plm/logistic_ate_metadata.csv b/results/plm/logistic_ate_metadata.csv new file mode 100644 index 00000000..eead6aa7 --- /dev/null +++ b/results/plm/logistic_ate_metadata.csv @@ -0,0 +1,3 @@ +DoubleML Version,Script,Date,Total Runtime (minutes),Python Version,Config File +0.10.dev0,LogisticATECoverageSimulation,2025-09-03 22:35,447.33407898743945,3.12.9,scripts/plm/logistic_ate_config.yml +0.10.dev0,LogisticATECoverageSimulation,2025-09-03 14:16,0.4242911458015442,3.12.11,scripts/plm/logistic_ate_config.yml diff --git a/scripts/plm/lplr_ate.py b/scripts/plm/lplr_ate.py new file mode 100644 index 00000000..a98b2d46 --- /dev/null +++ b/scripts/plm/lplr_ate.py @@ -0,0 +1,14 @@ +from montecover.plm import LPLRATECoverageSimulation + +# Create and run simulation with config file +sim = LPLRATECoverageSimulation( + config_file="scripts/plm/lplr_ate_config.yml", + log_level="INFO", + log_file="logs/plm/plr_ate_sim.log", +) +print("Calling file") +sim.run_simulation() +sim.save_results(output_path="results/plm/", file_prefix="lplr_ate") + +# Save config file for reproducibility +sim.save_config("results/plm/lplr_ate_config.yml") \ No newline at end of file diff --git a/scripts/plm/lplr_ate_config.yml b/scripts/plm/lplr_ate_config.yml new file mode 100644 index 00000000..da804ed9 --- /dev/null +++ b/scripts/plm/lplr_ate_config.yml @@ -0,0 +1,98 @@ +# Simulation parameters for LPLR ATE Coverage + +simulation_parameters: + repetitions: 1000 + max_runtime: 86400 # 24 hours in seconds + random_seed: 42 + n_jobs: -2 + +dgp_parameters: + theta: [0.5] # Treatment effect + n_obs: [500] # Sample size + dim_x: [20] # Number of covariates + balanced_r0: [False] # Whether to use balanced r0 function + +# Define reusable learner configurations +learner_definitions: + lasso: &lasso + name: "LassoCV" + + logistic: &logistic + name: "Logistic" + + rf: &rf + name: "RF Regr." + params: + n_estimators: 100 + max_features: "sqrt" + + rf-class: &rf-class + name: "RF Clas." + params: + n_estimators: 100 + max_features: "sqrt" + + lgbm: &lgbm + name: "LGBM Regr." + params: + n_estimators: 500 + learning_rate: 0.01 + + lgbm-class: &lgbm-class + name: "LGBM Clas." + params: + n_estimators: 500 + learning_rate: 0.01 + +dml_parameters: + learners: + - ml_m: *lasso + ml_M: *logistic + ml_t: *lasso + - ml_m: *rf + ml_M: *rf-class + ml_t: *rf + - ml_m: *lgbm + ml_M: *lgbm-class + ml_t: *lgbm +# - ml_m: *rf +# ml_M: *lgbm-class +# ml_t: *lgbm +# - ml_m: *lgbm +# ml_M: *rf-class +# ml_t: *lgbm +# - ml_m: *lgbm +# ml_M: *lgbm-class +# ml_t: *rf +# - ml_m: *lgbm +# ml_M: *rf-class +# ml_t: *rf +# - ml_m: *rf +# ml_M: *lgbm-class +# ml_t: *rf +# - ml_m: *rf +# ml_M: *rf-class +# ml_t: *lgbm +# - ml_m: *lasso +# ml_M: *lgbm-class +# ml_t: *lgbm +# - ml_m: *lgbm +# ml_M: *logistic +# ml_t: *lgbm +# - ml_m: *lgbm +# ml_M: *lgbm-class +# ml_t: *lasso +# - ml_m: *lasso +# ml_M: *rf-class +# ml_t: *rf +# - ml_m: *rf +# ml_M: *logistic +# ml_t: *rf +# - ml_m: *rf +# ml_M: *rf-class +# ml_t: *lasso + + score: ["nuisance_space", "instrument"] + +confidence_parameters: + level: [0.95, 0.90] # Confidence levels