Skip to content

Commit

Permalink
fit_statsmodels in the statsmodels file
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentarelbundock committed Dec 18, 2024
1 parent bfc1955 commit 4835fc2
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 57 deletions.
4 changes: 4 additions & 0 deletions marginaleffects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from .plot_slopes import plot_slopes
from .predictions import avg_predictions, predictions
from .slopes import avg_slopes, slopes
from .model_statsmodels import fit_statsmodels
from .model_sklearn import fit_sklearn

__all__ = [
"avg_comparisons",
Expand All @@ -19,4 +21,6 @@
"predictions",
"avg_slopes",
"slopes",
"fit_statsmodels",
"fit_sklearn",
]
58 changes: 5 additions & 53 deletions marginaleffects/formulaic.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,23 @@
import formulaic
import inspect
import polars as pl
import numpy as np
from .utils import validate_types, ingest
from .utils import validate_types


@validate_types
def variables(formula: str):
if "~" not in formula:
raise ValueError(
"formula must contain '~' to separate dependent and independent variables"
)
tok = formulaic.parser.DefaultFormulaParser().get_tokens(formula)
tok = [t for t in tok if t.kind.value == "name"]
tok = [str(t) for t in tok]
return tok


@validate_types
def lwd(
formula: str | None = None, vars: list[str] | None = None, data: pl.DataFrame = None
):
if formula is not None:
vars = variables(formula)
elif formula is None and vars is None:
raise ValueError("formula or vars must be provided")
def listwise_deletion(formula: str, data: pl.DataFrame):
vars = variables(formula)
return data.drop_nulls(subset=vars)


def model_matrices(formula: str, data: pl.DataFrame, formula_engine="formulaic"):
def model_matrices(formula: str, data: pl.DataFrame, formula_engine: str = "formulaic"):
if formula_engine == "formulaic":
endog, exog = formulaic.model_matrix(formula, data.to_pandas())
endog = endog.to_numpy()
Expand All @@ -47,41 +36,4 @@ def model_matrices(formula: str, data: pl.DataFrame, formula_engine="formulaic")
return None, exog


@validate_types
def design(formula: str, data: pl.DataFrame):
vars = variables(formula)
data = data.drop_nulls(subset=vars)
y, X = formulaic.model_matrix(formula, ingest(data).to_pandas())
y = np.ravel(data[vars[0]]) # avoid matrix if LHS is a character
return y, X, data


def _sanity_engine(engine):
if not hasattr(engine, "fit"):
raise AttributeError("engine must have a 'fit' method")
sig = inspect.signature(engine.fit)
param_names = list(sig.parameters.keys())
if "X" not in param_names or "y" not in param_names:
raise ValueError("engine.fit must accept parameters named 'X' and 'y'")


@validate_types
def fit_sklearn(formula: str, data: pl.DataFrame, engine):
y, X, data = design(formula, data)
print(engine)
out = engine.fit(X=X, y=y)
out.formula = formula
out.data = data
return out


@validate_types
def fit_statsmodels(formula:str, data: pl.DataFrame, engine):
y, X, data = design(formula, data)
mod = engine(endog=y, exog=X).fit()
mod.data = data
mod.formula = formula
return mod


__all__ = ["variables", "lwd", "design", "fit"]
__all__ = ["listwise_deletion", "model_matrices"]
24 changes: 23 additions & 1 deletion marginaleffects/model_sklearn.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numpy as np
import warnings
import polars as pl
from .utils import ingest
from .utils import validate_types, ingest
from .formulaic import listwise_deletion, model_matrices, variables
from .model_abstract import ModelAbstract


Expand Down Expand Up @@ -52,6 +53,10 @@ def get_predict(self, params, newdata: pl.DataFrame):

if p.ndim == 1:
p = pl.DataFrame({"rowid": range(newdata.shape[0]), "estimate": p})
elif p.ndim == 2 and p.shape[1] == 1:
p = pl.DataFrame(
{"rowid": range(newdata.shape[0]), "estimate": np.ravel(p)}
)
elif p.ndim == 2:
colnames = {f"column_{i}": v for i, v in enumerate(self.model.classes_)}
p = (
Expand All @@ -70,3 +75,20 @@ def get_predict(self, params, newdata: pl.DataFrame):
p = p.with_columns(pl.col("rowid").cast(pl.Int32))

return p


@validate_types
def fit_sklearn(
formula: str, data: pl.DataFrame, engine, kwargs_engine={}, kwargs_fit={}
):
d = listwise_deletion(formula, data=data)
y, X = model_matrices(formula, d)
# formulaic returns a matrix when the response is character or categorical
if y.ndim == 2:
y = d[variables(formula)[0]]
y = np.ravel(y)
out = engine(**kwargs_engine).fit(X=X, y=y, **kwargs_fit)
out.data = d
out.formula = formula
out.formula_engine = "formulaic"
return ModelSklearn(out)
18 changes: 16 additions & 2 deletions marginaleffects/model_statsmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import polars as pl
import patsy
from .model_abstract import ModelAbstract
from .utils import ingest
from .formulaic import listwise_deletion, model_matrices
from .utils import validate_types, ingest


class ModelStatsmodels(ModelAbstract):
Expand All @@ -22,7 +23,6 @@ def __init__(self, model):
self.formula_engine = "patsy"
self.design_info_patsy = model.model.data.design_info


def get_coef(self):
return np.array(self.model.params)

Expand Down Expand Up @@ -96,3 +96,17 @@ def get_predict(self, params, newdata: pl.DataFrame):

def get_df(self):
return self.model.df_resid


@validate_types
def fit_statsmodels(
formula: str, data: pl.DataFrame, engine, kwargs_engine={}, kwargs_fit={}
):
d = listwise_deletion(formula, data=data)
y, X = model_matrices(formula, d)
mod = engine(endog=y, exog=X, **kwargs_engine)
mod = mod.fit(**kwargs_fit)
mod.data = d
mod.formula = formula
mod.formula_engine = "formulaic"
return ModelStatsmodels(mod)
6 changes: 5 additions & 1 deletion marginaleffects/sanitize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ def sanitize_model(model):
if model is None:
return model

if isinstance(model, ModelAbstract):
if (
isinstance(model, ModelAbstract)
or isinstance(model, ModelStatsmodels)
or isinstance(model, ModelSklearn)
):
return model

if is_statsmodels(model):
Expand Down

0 comments on commit 4835fc2

Please sign in to comment.