diff --git a/marginaleffects/__init__.py b/marginaleffects/__init__.py index f47dff7..ecb4655 100644 --- a/marginaleffects/__init__.py +++ b/marginaleffects/__init__.py @@ -6,6 +6,8 @@ from .plot_slopes import plot_slopes from .predictions import avg_predictions, predictions from .slopes import avg_slopes, slopes +from .model_statsmodels import fit_statsmodels +from .model_sklearn import fit_sklearn __all__ = [ "avg_comparisons", @@ -19,4 +21,6 @@ "predictions", "avg_slopes", "slopes", + "fit_statsmodels", + "fit_sklearn", ] diff --git a/marginaleffects/formulaic.py b/marginaleffects/formulaic.py index 87744e4..0a7b81c 100644 --- a/marginaleffects/formulaic.py +++ b/marginaleffects/formulaic.py @@ -1,16 +1,10 @@ import formulaic -import inspect import polars as pl -import numpy as np -from .utils import validate_types, ingest +from .utils import validate_types @validate_types def variables(formula: str): - if "~" not in formula: - raise ValueError( - "formula must contain '~' to separate dependent and independent variables" - ) tok = formulaic.parser.DefaultFormulaParser().get_tokens(formula) tok = [t for t in tok if t.kind.value == "name"] tok = [str(t) for t in tok] @@ -18,17 +12,12 @@ def variables(formula: str): @validate_types -def lwd( - formula: str | None = None, vars: list[str] | None = None, data: pl.DataFrame = None -): - if formula is not None: - vars = variables(formula) - elif formula is None and vars is None: - raise ValueError("formula or vars must be provided") +def listwise_deletion(formula: str, data: pl.DataFrame): + vars = variables(formula) return data.drop_nulls(subset=vars) -def model_matrices(formula: str, data: pl.DataFrame, formula_engine="formulaic"): +def model_matrices(formula: str, data: pl.DataFrame, formula_engine: str = "formulaic"): if formula_engine == "formulaic": endog, exog = formulaic.model_matrix(formula, data.to_pandas()) endog = endog.to_numpy() @@ -47,41 +36,4 @@ def model_matrices(formula: str, data: pl.DataFrame, formula_engine="formulaic") return None, exog -@validate_types -def design(formula: str, data: pl.DataFrame): - vars = variables(formula) - data = data.drop_nulls(subset=vars) - y, X = formulaic.model_matrix(formula, ingest(data).to_pandas()) - y = np.ravel(data[vars[0]]) # avoid matrix if LHS is a character - return y, X, data - - -def _sanity_engine(engine): - if not hasattr(engine, "fit"): - raise AttributeError("engine must have a 'fit' method") - sig = inspect.signature(engine.fit) - param_names = list(sig.parameters.keys()) - if "X" not in param_names or "y" not in param_names: - raise ValueError("engine.fit must accept parameters named 'X' and 'y'") - - -@validate_types -def fit_sklearn(formula: str, data: pl.DataFrame, engine): - y, X, data = design(formula, data) - print(engine) - out = engine.fit(X=X, y=y) - out.formula = formula - out.data = data - return out - - -@validate_types -def fit_statsmodels(formula:str, data: pl.DataFrame, engine): - y, X, data = design(formula, data) - mod = engine(endog=y, exog=X).fit() - mod.data = data - mod.formula = formula - return mod - - -__all__ = ["variables", "lwd", "design", "fit"] +__all__ = ["listwise_deletion", "model_matrices"] diff --git a/marginaleffects/model_sklearn.py b/marginaleffects/model_sklearn.py index 5dc24e9..b832da4 100644 --- a/marginaleffects/model_sklearn.py +++ b/marginaleffects/model_sklearn.py @@ -1,7 +1,8 @@ import numpy as np import warnings import polars as pl -from .utils import ingest +from .utils import validate_types, ingest +from .formulaic import listwise_deletion, model_matrices, variables from .model_abstract import ModelAbstract @@ -52,6 +53,10 @@ def get_predict(self, params, newdata: pl.DataFrame): if p.ndim == 1: p = pl.DataFrame({"rowid": range(newdata.shape[0]), "estimate": p}) + elif p.ndim == 2 and p.shape[1] == 1: + p = pl.DataFrame( + {"rowid": range(newdata.shape[0]), "estimate": np.ravel(p)} + ) elif p.ndim == 2: colnames = {f"column_{i}": v for i, v in enumerate(self.model.classes_)} p = ( @@ -70,3 +75,20 @@ def get_predict(self, params, newdata: pl.DataFrame): p = p.with_columns(pl.col("rowid").cast(pl.Int32)) return p + + +@validate_types +def fit_sklearn( + formula: str, data: pl.DataFrame, engine, kwargs_engine={}, kwargs_fit={} +): + d = listwise_deletion(formula, data=data) + y, X = model_matrices(formula, d) + # formulaic returns a matrix when the response is character or categorical + if y.ndim == 2: + y = d[variables(formula)[0]] + y = np.ravel(y) + out = engine(**kwargs_engine).fit(X=X, y=y, **kwargs_fit) + out.data = d + out.formula = formula + out.formula_engine = "formulaic" + return ModelSklearn(out) diff --git a/marginaleffects/model_statsmodels.py b/marginaleffects/model_statsmodels.py index c24f58f..b320426 100644 --- a/marginaleffects/model_statsmodels.py +++ b/marginaleffects/model_statsmodels.py @@ -3,7 +3,8 @@ import polars as pl import patsy from .model_abstract import ModelAbstract -from .utils import ingest +from .formulaic import listwise_deletion, model_matrices +from .utils import validate_types, ingest class ModelStatsmodels(ModelAbstract): @@ -22,7 +23,6 @@ def __init__(self, model): self.formula_engine = "patsy" self.design_info_patsy = model.model.data.design_info - def get_coef(self): return np.array(self.model.params) @@ -96,3 +96,17 @@ def get_predict(self, params, newdata: pl.DataFrame): def get_df(self): return self.model.df_resid + + +@validate_types +def fit_statsmodels( + formula: str, data: pl.DataFrame, engine, kwargs_engine={}, kwargs_fit={} +): + d = listwise_deletion(formula, data=data) + y, X = model_matrices(formula, d) + mod = engine(endog=y, exog=X, **kwargs_engine) + mod = mod.fit(**kwargs_fit) + mod.data = d + mod.formula = formula + mod.formula_engine = "formulaic" + return ModelStatsmodels(mod) diff --git a/marginaleffects/sanitize_model.py b/marginaleffects/sanitize_model.py index bb328de..35fcba2 100644 --- a/marginaleffects/sanitize_model.py +++ b/marginaleffects/sanitize_model.py @@ -31,7 +31,11 @@ def sanitize_model(model): if model is None: return model - if isinstance(model, ModelAbstract): + if ( + isinstance(model, ModelAbstract) + or isinstance(model, ModelStatsmodels) + or isinstance(model, ModelSklearn) + ): return model if is_statsmodels(model):