From bfc1955d110a41dd22c8c4d170d5ea1f6e61e07e Mon Sep 17 00:00:00 2001 From: Vincent Arel-Bundock Date: Tue, 17 Dec 2024 23:05:53 -0500 Subject: [PATCH] fit_sklearn fit_statsmodels --- marginaleffects/formulaic.py | 13 +++++++++++-- marginaleffects/model_statsmodels.py | 16 ++++++++++++---- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/marginaleffects/formulaic.py b/marginaleffects/formulaic.py index 7c79ca4..87744e4 100644 --- a/marginaleffects/formulaic.py +++ b/marginaleffects/formulaic.py @@ -66,13 +66,22 @@ def _sanity_engine(engine): @validate_types -def fit(formula: str, data: pl.DataFrame, engine): - _sanity_engine(engine) +def fit_sklearn(formula: str, data: pl.DataFrame, engine): y, X, data = design(formula, data) + print(engine) out = engine.fit(X=X, y=y) out.formula = formula out.data = data return out +@validate_types +def fit_statsmodels(formula:str, data: pl.DataFrame, engine): + y, X, data = design(formula, data) + mod = engine(endog=y, exog=X).fit() + mod.data = data + mod.formula = formula + return mod + + __all__ = ["variables", "lwd", "design", "fit"] diff --git a/marginaleffects/model_statsmodels.py b/marginaleffects/model_statsmodels.py index fbf6885..c24f58f 100644 --- a/marginaleffects/model_statsmodels.py +++ b/marginaleffects/model_statsmodels.py @@ -8,12 +8,20 @@ class ModelStatsmodels(ModelAbstract): def __init__(self, model): - self.formula = model.model.formula - self.data = ingest(model.model.data.frame) + if hasattr(model, "formula"): + self.formula = model.formula + self.data = ingest(model.data) + else: + self.formula = model.model.formula + self.data = ingest(model.model.data.frame) super().__init__(model) # after super() - self.formula_engine = "patsy" - self.design_info_patsy = model.model.data.design_info + if hasattr(model, "formula"): + self.formula_engine = "formulaic" + else: + self.formula_engine = "patsy" + self.design_info_patsy = model.model.data.design_info + def get_coef(self): return np.array(self.model.params)