diff --git a/marginaleffects/sanitize_model.py b/marginaleffects/sanitize_model.py index e70b689..097bd34 100644 --- a/marginaleffects/sanitize_model.py +++ b/marginaleffects/sanitize_model.py @@ -2,10 +2,20 @@ def sanitize_model(model): - # TODO: other than statsmodels if model is None: return model - if not isinstance(model, ModelAbstract): - model = ModelStatsmodels(model) - return model + if isinstance(model, ModelAbstract): + return model + + try: + import statsmodels.base.wrapper as smw + + if isinstance(model, smw.ResultsWrapper): + return ModelStatsmodels(model) + except ImportError: + pass + + raise ValueError( + "Unknown model type. Try installing the 'statsmodels' package or file an issue at https://github.com/vincentarelbundock/pymarginaleffects." + ) diff --git a/poetry.lock b/poetry.lock index c7c33e6..0b1470e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -801,14 +801,14 @@ files = [ [[package]] name = "importlib-metadata" -version = "7.0.0" +version = "7.0.1" description = "Read metadata from Python packages" category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_metadata-7.0.0-py3-none-any.whl", hash = "sha256:d97503976bb81f40a193d41ee6570868479c69d5068651eb039c40d850c59d67"}, - {file = "importlib_metadata-7.0.0.tar.gz", hash = "sha256:7fc841f8b8332803464e5dc1c63a2e59121f46ca186c0e2e182e80bf8c1319f7"}, + {file = "importlib_metadata-7.0.1-py3-none-any.whl", hash = "sha256:4805911c3a4ec7c3966410053e9ec6a1fecd629117df5adee56dfc9432a1081e"}, + {file = "importlib_metadata-7.0.1.tar.gz", hash = "sha256:f238736bb06590ae52ac1fab06a3a9ef1d8dce2b7a35b5ab329371d6c8f5d2cc"}, ] [package.dependencies] @@ -3245,7 +3245,7 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] name = "statsmodels" version = "0.14.1" description = "Statistical computations and models for Python" -category = "main" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -3596,4 +3596,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "a1c39abdc1661bef83fab83072a653d7cce8ccb140b4d1c7b6c82572de3ced55" +content-hash = "c699a4df0d4856701d659b5685708f5c3c16bc58e02f324abf8708fdd0a3c57d" diff --git a/pyproject.toml b/pyproject.toml index 5c0c28b..9914677 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,6 @@ readme = "README.md" [tool.poetry.dependencies] python = "^3.9" pandas = "^2.0.2" -statsmodels = ">0.14.0" numpy = "^1.25.0" patsy = ">0.5.0" polars = ">0.18.3" @@ -23,6 +22,7 @@ mkdocs = "^1.4.3" mkdocs-material = "^9.1.17" mkautodoc = ">=0.2.0" matplotlib = "^3.7.1" +statsmodels = ">0.14.0" typing-extensions = "^4.7.0" pytest-xdist = "^3.3.1" bandit = "^1.7.5"