Skip to content

Commit

Permalink
fit_sklearn fit_statsmodels
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentarelbundock committed Dec 18, 2024
1 parent cc67765 commit bfc1955
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
13 changes: 11 additions & 2 deletions marginaleffects/formulaic.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,22 @@ def _sanity_engine(engine):


@validate_types
def fit(formula: str, data: pl.DataFrame, engine):
_sanity_engine(engine)
def fit_sklearn(formula: str, data: pl.DataFrame, engine):
y, X, data = design(formula, data)
print(engine)
out = engine.fit(X=X, y=y)
out.formula = formula
out.data = data
return out


@validate_types
def fit_statsmodels(formula:str, data: pl.DataFrame, engine):
y, X, data = design(formula, data)
mod = engine(endog=y, exog=X).fit()
mod.data = data
mod.formula = formula
return mod


__all__ = ["variables", "lwd", "design", "fit"]
16 changes: 12 additions & 4 deletions marginaleffects/model_statsmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,20 @@

class ModelStatsmodels(ModelAbstract):
def __init__(self, model):
self.formula = model.model.formula
self.data = ingest(model.model.data.frame)
if hasattr(model, "formula"):
self.formula = model.formula
self.data = ingest(model.data)
else:
self.formula = model.model.formula
self.data = ingest(model.model.data.frame)
super().__init__(model)
# after super()
self.formula_engine = "patsy"
self.design_info_patsy = model.model.data.design_info
if hasattr(model, "formula"):
self.formula_engine = "formulaic"
else:
self.formula_engine = "patsy"
self.design_info_patsy = model.model.data.design_info


def get_coef(self):
return np.array(self.model.params)
Expand Down

0 comments on commit bfc1955

Please sign in to comment.