From cb4d1583bbf865e933d30d15cfd4fdb2946ef0c6 Mon Sep 17 00:00:00 2001 From: "artiom.matvei" Date: Tue, 22 Oct 2024 21:49:30 -0400 Subject: [PATCH 1/5] ingest different df types + tests --- marginaleffects/comparisons.py | 2 +- marginaleffects/sanity.py | 15 ++++---- marginaleffects/utils.py | 12 +++++- pyproject.toml | 2 + tests/test_utils.py | 70 ++++++++++++++++++++++++++++++++++ uv.lock | 53 +++++++++++++++++++++++++ 6 files changed, 145 insertions(+), 9 deletions(-) create mode 100644 tests/test_utils.py diff --git a/marginaleffects/comparisons.py b/marginaleffects/comparisons.py index d5b4890..f4628ef 100644 --- a/marginaleffects/comparisons.py +++ b/marginaleffects/comparisons.py @@ -79,7 +79,7 @@ def comparisons( - Examples: + `variables = {"gear" = "pairwise", "hp" = 10}` + `variables = {"gear" = "sequential", "hp" = [100, 120]}` - - newdata (polars or pandas DataFrame, or str): a data frame or a string specifying where statistics are evaluated in the predictor space. If `None`, unit-level contrasts are computed for each observed value in the original dataset (empirical distribution). + - newdata (polars, pandas or any other ArrowstreamExportable DataFrame, or str): a data frame or a string specifying where statistics are evaluated in the predictor space. If `None`, unit-level contrasts are computed for each observed value in the original dataset (empirical distribution). - comparison (str): a string specifying how pairs of predictions should be compared. See the Comparisons section below for definitions of each transformation. - transform (function): a function specifying a transformation applied to unit-level estimates and confidence intervals just before the function returns results. Functions must accept a full column (series) of a Polars data frame and return a corresponding series of the same length. Ex: - `transform = numpy.exp` diff --git a/marginaleffects/sanity.py b/marginaleffects/sanity.py index 41667e1..bbc9212 100644 --- a/marginaleffects/sanity.py +++ b/marginaleffects/sanity.py @@ -6,6 +6,7 @@ from .datagrid import datagrid from .estimands import estimands +from .utils import ingest, ArrowStreamExportable def sanitize_vcov(vcov, model): @@ -67,14 +68,14 @@ def sanitize_newdata(model, newdata, wts, by=[]): else: try: - import pandas as pd - - if isinstance(newdata, pd.DataFrame): - out = pl.from_pandas(newdata) + if isinstance(newdata, ArrowStreamExportable): + out = ingest(newdata) else: - out = newdata - except ImportError: - out = newdata + raise RuntimeError( + "Unable to ingest newdata data provided. If it is a DataFrame, make sure it implements the ArrowStreamExportable interface." + ) + except Exception as e: + raise e reserved_names = { "rowid", diff --git a/marginaleffects/utils.py b/marginaleffects/utils.py index a1be0b6..9bc2274 100644 --- a/marginaleffects/utils.py +++ b/marginaleffects/utils.py @@ -1,7 +1,17 @@ import itertools - +import narwhals as nw import numpy as np import polars as pl +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class ArrowStreamExportable(Protocol): + def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: ... + + +def ingest(df: ArrowStreamExportable): + return nw.from_arrow(df, native_namespace=pl).to_native() def sort_columns(df, by=None, newdata=None): diff --git a/pyproject.toml b/pyproject.toml index 592694f..2f56544 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,8 @@ description = "Predictions, counterfactual comparisons, slopes, and hypothesis t readme = "README.md" requires-python = ">=3.10" dependencies = [ + "duckdb >=1.1.2", + "narwhals >=1.10.0", "numpy >=2.0.0", "patsy >=0.5.6", "polars >=1.7.0", diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..dfe4256 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,70 @@ +import pytest +import pandas as pd +import polars as pl +import duckdb +import pyarrow as pa +import narwhals as nw +from marginaleffects.utils import ingest +from typing import Callable + + +def get_sample_data(): + return pd.DataFrame( + { + "id": [1, 2, 3], + "name": ["Alice", "Bob", "Charlie"], + "age": [25, 30, 35], + "score": [85.5, 90.0, 95.5], + } + ) + + +sample_data = get_sample_data() + + +@pytest.fixture +def sample_pandas_df(): + return get_sample_data() + + +@pytest.fixture +def sample_polars_df(): + pd_df = get_sample_data() + return pl.from_pandas(pd_df) + + +@pytest.fixture +def sample_duckdb_df(): + # Using DuckDB to create a DataFrame + con = duckdb.connect() + return con.execute("SELECT * FROM sample_data").df() + + +def test_ingest_pandas(sample_pandas_df): + result = ingest(sample_pandas_df) + assert isinstance(result, pl.DataFrame), "Result should be a Polars DataFrame" + # Verify contents + expected = pl.from_pandas(sample_pandas_df) + assert result.equals( + expected + ), "Ingested DataFrame does not match expected Polars DataFrame" + + +def test_ingest_polars(sample_polars_df): + result = ingest(sample_polars_df) + assert isinstance(result, pl.DataFrame), "Result should be a Polars DataFrame" + # Verify contents + expected = sample_polars_df + assert result.equals( + expected + ), "Ingested DataFrame does not match expected Polars DataFrame" + + +def test_ingest_duckdb(sample_duckdb_df): + result = ingest(sample_duckdb_df) + assert isinstance(result, pl.DataFrame), "Result should be a Polars DataFrame" + # Verify contents + expected = pl.from_pandas(sample_duckdb_df) + assert result.equals( + expected + ), "Ingested DataFrame does not match expected Polars DataFrame" diff --git a/uv.lock b/uv.lock index 300a4d1..f0dbd29 100644 --- a/uv.lock +++ b/uv.lock @@ -104,6 +104,46 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30", size = 8321 }, ] +[[package]] +name = "duckdb" +version = "1.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/37/0c/6c6093fba60e5b8ac2abaee9b6a2b379e77419fe6102a36aa383944610fe/duckdb-1.1.2.tar.gz", hash = "sha256:c8232861dc8ec6daa29067056d5a0e5789919f2ab22ab792787616d7cd52f02a", size = 12237077 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/1e/4a7073909ed10cc6fdc5a101267d09e52b57054af137c63fb7040536e3ae/duckdb-1.1.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:91e7f99cf5cab1d26f92cb014429153497d805e79689baa44f4c4585a8cb243f", size = 15464881 }, + { url = "https://files.pythonhosted.org/packages/5d/f4/0c94ed5635b348f8f8f3a315d2139640239f3d9cca87768fe7591fecdd0b/duckdb-1.1.2-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:0107de622fe208142a1108263a03c43956048dcc99be3702d8e5d2aeaf99554c", size = 32301705 }, + { url = "https://files.pythonhosted.org/packages/70/39/2fc821b1b587a6589d18c1c07665ff5b08cd2497ea39db2f047d13520bd4/duckdb-1.1.2-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:8a09610f780857677725897856f8cdf3cafd8a991f871e6cb8ba88b2dbc8d737", size = 16924272 }, + { url = "https://files.pythonhosted.org/packages/8d/09/913d4e5334d62ec57d0261589a775bc0870f90e919d26c511c6df3ac4067/duckdb-1.1.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c0f0ddac0482f0f3fece54d720d13819e82ae26c01a939ffa66a87be53f7f665", size = 18486251 }, + { url = "https://files.pythonhosted.org/packages/dd/33/1a38837b4b0fc1a33ff5e5e623cea053f24a33fbedacacd04ab2bae4d615/duckdb-1.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:84103373e818758dfa361d27781d0f096553843c5ffb9193260a0786c5248270", size = 20140537 }, + { url = "https://files.pythonhosted.org/packages/36/ec/1339d8c3431c3c77d7f3f9272a66128e3535a5bff03fd47c40f179a1c3f2/duckdb-1.1.2-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bfdfd23e2bf58014ad0673973bd0ed88cd048dfe8e82420814a71d7d52ef2288", size = 18283601 }, + { url = "https://files.pythonhosted.org/packages/4d/ae/90ecc2f4391a96851dc7e69e442dd5f69501394088006d00209f4785d0fb/duckdb-1.1.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:25889e6e29b87047b1dd56385ac08156e4713c59326cc6fff89657d01b2c417b", size = 21598783 }, + { url = "https://files.pythonhosted.org/packages/54/e0/e611af7f72c6fbe0ff9d0440134139b94f1ed144a3d4d7841b545b3f205e/duckdb-1.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:312570fa5277c3079de18388b86c2d87cbe1044838bb152b235c0227581d5d42", size = 10950443 }, + { url = "https://files.pythonhosted.org/packages/96/23/fe7cc36ac4db1fd3b6433a698096ea14128d8e76b633d68425a062d35ac7/duckdb-1.1.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:568439ea4fce8cb72ec1f767cd510686a9e7e29a011fc7c56d990059a6e94e48", size = 15467057 }, + { url = "https://files.pythonhosted.org/packages/d2/ca/220fccba2220d62c4d64264cc960fcfde083e71b98721144b889b2aab914/duckdb-1.1.2-cp311-cp311-macosx_12_0_universal2.whl", hash = "sha256:74974f2d7210623a5d61b1fb0cb589c6e5ffcbf7dbb757a04c5ba24adcfc8cac", size = 32308152 }, + { url = "https://files.pythonhosted.org/packages/c0/9a/4a79edab02cd5070b3fe581385c703cd591e26a74c4b0222225aec499d74/duckdb-1.1.2-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:e26422a3358c816d764639070945b73eef55d1b4df990989e3492c85ef725c21", size = 16926668 }, + { url = "https://files.pythonhosted.org/packages/21/98/d785e3a845ee06c2138baaa1699486f31aec467cda5a2e1d57b70a8d185b/duckdb-1.1.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:87e972bd452eeeab197fe39dcaeecdb7c264b1f75a0ee67e532e235fe45b84df", size = 18490749 }, + { url = "https://files.pythonhosted.org/packages/b9/91/f09147562d7d70ea985c632321daaa86746088fa68ca9d3e42cddeed60ef/duckdb-1.1.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a6b73e70b73c8df85da383f6e557c03cad5c877868b9a7e41715761e8166c1e", size = 20144525 }, + { url = "https://files.pythonhosted.org/packages/0e/a4/da374d7c8ee777ccb70f751492e0a459123313e4c9784eedc9aea8676dc4/duckdb-1.1.2-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:623cb1952466aae5907af84107bcdec25a5ca021a8b6441e961f41edc724f6f2", size = 18286202 }, + { url = "https://files.pythonhosted.org/packages/67/48/e62a17c7ad31ec6af2fbeb026df8a27c8839883dccbf1d8ea9c8e5c89db3/duckdb-1.1.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d9fc0b550f96901fa7e76dc70a13f6477ad3e18ef1cb21d414c3a5569de3f27e", size = 21601216 }, + { url = "https://files.pythonhosted.org/packages/03/97/7678626c03317ff3004e066223eef8304adf534e5bf77388c9cd8560f637/duckdb-1.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:181edb1973bd8f493bcb6ecfa035f1a592dff4667758592f300619012ba251c0", size = 10951571 }, + { url = "https://files.pythonhosted.org/packages/8d/c6/946a714a4aa285aeeec74ac827eeb37c9b29102c2c1c27a1a98cb2cc7c9d/duckdb-1.1.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:83372b1b411086cac01ab2071122772fa66170b1b41ddbc37527464066083668", size = 15471960 }, + { url = "https://files.pythonhosted.org/packages/97/a8/e346d35d51fef06018485386a02ba68e1777bdecca06ea3d1251559af35f/duckdb-1.1.2-cp312-cp312-macosx_12_0_universal2.whl", hash = "sha256:db37441deddfee6ac35a0c742d2f9e90e4e50b9e76d586a060d122b8fc56dada", size = 32343109 }, + { url = "https://files.pythonhosted.org/packages/8c/00/6ec504a8a41d296c0b2cccdde730ff974a9620b275917de3746b31f46866/duckdb-1.1.2-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:19142a77e72874aeaa6fda30aeb13612c6de5e8c60fbcc3392cea6ef0694eeaf", size = 16947327 }, + { url = "https://files.pythonhosted.org/packages/0b/b6/f60be01d29b87d4497dd9eed0e82a89859ac5a772030645a9b2728ab73eb/duckdb-1.1.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:099d99dd48d6e4682a3dd6233ceab73d977ebe1a87afaac54cf77c844e24514a", size = 18485234 }, + { url = "https://files.pythonhosted.org/packages/c2/28/ecc5f8ab0e7b2b00e2b8a3385f0a3d0bcdecdfef44719fcdc32e744ac6f1/duckdb-1.1.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be86e586ca7af7e807f72479a2b8d0983565360b19dbda4ef8a9d7b3909b8e2c", size = 20141584 }, + { url = "https://files.pythonhosted.org/packages/48/e1/d6e4abbdf20b498f78ddb4e66ef2689318bac4f35a3c71a461e1a22ea3ca/duckdb-1.1.2-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:578e0953e4d8ba8da0cd69fb2930c45f51ce47d213b77d8a4cd461f9c0960b87", size = 18286704 }, + { url = "https://files.pythonhosted.org/packages/7b/ec/4beb866ced0e0db8e998530c0df841881fe7c26e4cc52baf909911852529/duckdb-1.1.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:72b5eb5762c1a5e68849c7143f3b3747a9f15c040e34e41559f233a1569ad16f", size = 21612773 }, + { url = "https://files.pythonhosted.org/packages/19/30/6c1ad1c3db49be118a42565b524b359aabdce2e1ea56138e0252a6f69f7a/duckdb-1.1.2-cp312-cp312-win_amd64.whl", hash = "sha256:9b4c6b6a08180261d98330d97355503961a25ca31cd9ef296e0681f7895b4a2c", size = 10953379 }, + { url = "https://files.pythonhosted.org/packages/11/15/0ea64233ef3a9eb2cf1b276c2dd5f29e80f35f7a903ebc168b4c6a099bbf/duckdb-1.1.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:695dcbc561374b126e86659709feadf883c9969ed718e94713edd4ba15d16619", size = 15472119 }, + { url = "https://files.pythonhosted.org/packages/4d/2a/aea827229c4d2da5896f46619bc6ad3cf4c8c7e72ffc0c4f68ffc73989e1/duckdb-1.1.2-cp313-cp313-macosx_12_0_universal2.whl", hash = "sha256:ada29be1e889f486c6cf1f6dffd15463e748faf361f33996f2e862779edc24a9", size = 32343664 }, + { url = "https://files.pythonhosted.org/packages/06/b4/ef938871fb90762702f7120b572d16a1f71ab8d1bad5cbd171d9eb8820a5/duckdb-1.1.2-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:6ca722738fa9eb6218619740631de29acfdd132de6f6a6350fee5e291c2f6117", size = 16947866 }, + { url = "https://files.pythonhosted.org/packages/8e/eb/54f4f95fb709c456e2b90c2a68e6c3b55f2b45000c1acd01fe5669f5f23c/duckdb-1.1.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c796d33f1e5a0c8c570d22da0c0b1db8578687e427029e1ce2c8ce3f9fffa6a3", size = 18486824 }, + { url = "https://files.pythonhosted.org/packages/f1/7b/b5194ec90ff8070fe24721542d78b6293e9087df2cc14fa7552077806f1b/duckdb-1.1.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5c0996988a70dd3bc8111d9b9aeab7e38ed1999a52607c5f1b528e362b4dd1c", size = 20142187 }, + { url = "https://files.pythonhosted.org/packages/69/cb/b36b871df43d39ab14941f5936205d15a5d302295daa9dab2c3c7aae3ccc/duckdb-1.1.2-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6c37b039f6d6fed14d89450f5ccf54922b3304192d7412e12d6cc8d9e757f7a2", size = 18287193 }, + { url = "https://files.pythonhosted.org/packages/e6/26/0d3ad1790c543923027ea20aff987c66049163331d77326d2953a03a3b52/duckdb-1.1.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e8c766b87f675c76d6d17103bf6fb9fb1a9e2fcb3d9b25c28bbc634bde31223e", size = 21611680 }, + { url = "https://files.pythonhosted.org/packages/c8/2f/b81d855287b8bb44da1454a0d6e154dcd0e40f1c3654d120002cbde31479/duckdb-1.1.2-cp313-cp313-win_amd64.whl", hash = "sha256:e3e6300b7ccaf64b609f4f0780a6e1d25ab8cf34cceed46e62c35b6c4c5cb63b", size = 10953563 }, +] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -374,6 +414,8 @@ name = "marginaleffects" version = "0.0.13.1" source = { virtual = "." } dependencies = [ + { name = "duckdb" }, + { name = "narwhals" }, { name = "numpy" }, { name = "patsy" }, { name = "plotnine" }, @@ -400,7 +442,9 @@ dev = [ [package.metadata] requires-dist = [ + { name = "duckdb" }, { name = "matplotlib", marker = "extra == 'test'", specifier = ">=3.7.1" }, + { name = "narwhals", specifier = ">=1.10.0" }, { name = "numpy", specifier = ">=2.0.0" }, { name = "pandas", marker = "extra == 'test'", specifier = ">=2.2.2" }, { name = "patsy", specifier = ">=0.5.6" }, @@ -483,6 +527,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/11/f3777ad46c5d92e3ead121c22ea45fafb6c3b2c1edca0c0c6494969c125c/mizani-0.11.4-py3-none-any.whl", hash = "sha256:5b6271dc3da2c88694dca2e0e0a7e1879f0e2fb046c789776f54d090a5243735", size = 127428 }, ] +[[package]] +name = "narwhals" +version = "1.10.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a6/43/5f03a6c1976f8f8053f98d33892be6534fcde908dd1827fe69eb3b9ad90f/narwhals-1.10.0.tar.gz", hash = "sha256:a380e64110c3169c4b0b592c5b64ae6dc4cce76e9d3c56edc608a8ae5994cfc1", size = 161509 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/82/45/6bb08a5e5ac2d9a95f4116567d0aafb1e5f86c4bbcf982e209e5ae872b3d/narwhals-1.10.0-py3-none-any.whl", hash = "sha256:c83a378960651c391e5f3d68af3a821eda74c9713073518fe0c39aefc5ad8f8e", size = 193600 }, +] + [[package]] name = "numba" version = "0.60.0" From 65f9d19274025ea72500a2b8c2e28af031dd418f Mon Sep 17 00:00:00 2001 From: "artiom.matvei" Date: Sat, 26 Oct 2024 20:44:09 -0400 Subject: [PATCH 2/5] removed joint_hypotheses option for obj being of type dataframe --- marginaleffects/hypotheses_joint.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/marginaleffects/hypotheses_joint.py b/marginaleffects/hypotheses_joint.py index bd76dc2..e3455c4 100644 --- a/marginaleffects/hypotheses_joint.py +++ b/marginaleffects/hypotheses_joint.py @@ -1,7 +1,6 @@ import numpy as np import scipy.stats as stats import polars as pl -import pandas as pd from .sanity import sanitize_hypothesis_null from .classes import MarginaleffectsDataFrame @@ -10,9 +9,6 @@ def joint_hypotheses(obj, joint_index=None, joint_test="f", hypothesis=0): assert joint_test in ["f", "chisq"], "`joint_test` must be `f` or `chisq`" - if isinstance(obj, pd.DataFrame): - obj = pl.DataFrame(obj) - # theta_hat: P x 1 vector of estimated parameters theta_hat = obj.get_coef() From f8840a77ad6f4521105b732e7c560c1c40d3a19d Mon Sep 17 00:00:00 2001 From: "artiom.matvei" Date: Sat, 26 Oct 2024 21:01:52 -0400 Subject: [PATCH 3/5] updated pyproject.toml to reduce dependencies duckdv, pyarrow, scipy --- pyproject.toml | 12 ++++++------ uv.lock | 8 ++++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2f56544..c13c9eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,23 +5,23 @@ description = "Predictions, counterfactual comparisons, slopes, and hypothesis t readme = "README.md" requires-python = ">=3.10" dependencies = [ - "duckdb >=1.1.2", "narwhals >=1.10.0", "numpy >=2.0.0", "patsy >=0.5.6", "polars >=1.7.0", - "pyarrow >=17.0.0", - "scipy >=1.14.1", "plotnine >=0.13.6", + "scipy >=1.14.1", ] [project.optional-dependencies] test = [ - "pandas >=2.2.2", + "duckdb >=1.1.2", "matplotlib >=3.7.1", - "typing-extensions >=4.7.0", - "statsmodels >=0.14.0", + "pandas >=2.2.2", + "pyarrow >=17.0.0", "pyfixest >=0.24.2", + "statsmodels >=0.14.0", + "typing-extensions >=4.7.0", ] [tool.uv] diff --git a/uv.lock b/uv.lock index f0dbd29..560594c 100644 --- a/uv.lock +++ b/uv.lock @@ -414,20 +414,20 @@ name = "marginaleffects" version = "0.0.13.1" source = { virtual = "." } dependencies = [ - { name = "duckdb" }, { name = "narwhals" }, { name = "numpy" }, { name = "patsy" }, { name = "plotnine" }, { name = "polars" }, - { name = "pyarrow" }, { name = "scipy" }, ] [package.optional-dependencies] test = [ + { name = "duckdb" }, { name = "matplotlib" }, { name = "pandas" }, + { name = "pyarrow" }, { name = "pyfixest" }, { name = "statsmodels" }, { name = "typing-extensions" }, @@ -442,7 +442,7 @@ dev = [ [package.metadata] requires-dist = [ - { name = "duckdb" }, + { name = "duckdb", marker = "extra == 'test'", specifier = ">=1.1.2" }, { name = "matplotlib", marker = "extra == 'test'", specifier = ">=3.7.1" }, { name = "narwhals", specifier = ">=1.10.0" }, { name = "numpy", specifier = ">=2.0.0" }, @@ -450,7 +450,7 @@ requires-dist = [ { name = "patsy", specifier = ">=0.5.6" }, { name = "plotnine", specifier = ">=0.13.6" }, { name = "polars", specifier = ">=1.7.0" }, - { name = "pyarrow", specifier = ">=17.0.0" }, + { name = "pyarrow", marker = "extra == 'test'", specifier = ">=17.0.0" }, { name = "pyfixest", marker = "extra == 'test'", specifier = ">=0.24.2" }, { name = "scipy", specifier = ">=1.14.1" }, { name = "statsmodels", marker = "extra == 'test'", specifier = ">=0.14.0" }, From 0eb894b1f8addd61d679d9aa8035976efc6d7f5d Mon Sep 17 00:00:00 2001 From: "artiom.matvei" Date: Sat, 26 Oct 2024 21:32:07 -0400 Subject: [PATCH 4/5] comparisons documentation homogenization --- marginaleffects/comparisons.py | 121 ++++++++++++++++++--------------- 1 file changed, 65 insertions(+), 56 deletions(-) diff --git a/marginaleffects/comparisons.py b/marginaleffects/comparisons.py index f4628ef..f5b50d0 100644 --- a/marginaleffects/comparisons.py +++ b/marginaleffects/comparisons.py @@ -40,8 +40,69 @@ def comparisons( ): """ `comparisons()` and `avg_comparisons()` are functions for predicting the outcome variable at different regressor values and comparing those predictions by computing a difference, ratio, or some other function. These functions can return many quantities of interest, such as contrasts, differences, risk ratios, changes in log odds, lift, slopes, elasticities, etc. - - # Usage: + + Parameters + ---------- + model : object + Model object fitted using the `statsmodels` formula API. + variables : str, list, dictionary + - a string, list of strings, or dictionary of variables to compute comparisons for. If `None`, comparisons are computed for all regressors in the model object. Acceptable values depend on the variable type. See the examples below. + - Dictionary: keys identify the subset of variables of interest, and values define the type of contrast to compute. Acceptable values depend on the variable type: + - Categorical variables: + * "reference": Each factor level is compared to the factor reference (base) level + * "all": All combinations of observed levels + * "sequential": Each factor level is compared to the previous factor level + * "pairwise": Each factor level is compared to all other levels + * "minmax": The highest and lowest levels of a factor. + * "revpairwise", "revreference", "revsequential": inverse of the corresponding hypotheses. + * Vector of length 2 with the two values to compare. + - Boolean variables: + * `None`: contrast between True and False + - Numeric variables: + * Numeric of length 1: Contrast for a gap of `x`, computed at the observed value plus and minus `x / 2`. For example, estimating a `+1` contrast compares adjusted predictions when the regressor is equal to its observed value minus 0.5 and its observed value plus 0.5. + * Numeric of length equal to the number of rows in `newdata`: Same as above, but the contrast can be customized for each row of `newdata`. + * Numeric vector of length 2: Contrast between the 2nd element and the 1st element of the `x` vector. + * Data frame with the same number of rows as `newdata`, with two columns of "low" and "high" values to compare. + * Function which accepts a numeric vector and returns a data frame with two columns of "low" and "high" values to compare. See examples below. + * "iqr": Contrast across the interquartile range of the regressor. + * "sd": Contrast across one standard deviation around the regressor mean. + * "2sd": Contrast across two standard deviations around the regressor mean. + * "minmax": Contrast between the maximum and the minimum values of the regressor. + - Examples: + + `variables = {"gear" = "pairwise", "hp" = 10}` + + `variables = {"gear" = "sequential", "hp" = [100, 120]}` + newdata : polars or pandas DataFrame, or str + Data frame or string specifying where statistics are evaluated in the predictor space. If `None`, unit-level contrasts are computed for each observed value in the original dataset (empirical distribution). + comparison : str + String specifying how pairs of predictions should be compared. See the Comparisons section below for definitions of each transformation. + transform : function + Function specifying a transformation applied to unit-level estimates and confidence intervals just before the function returns results. Functions must accept a full column (series) of a Polars data frame and return a corresponding series of the same length. Ex: + - `transform = numpy.exp` + - `transform = lambda x: x.exp()` + - `transform = lambda x: x.map_elements()` + equivalence : list + List of 2 numeric values specifying the bounds used for the two-one-sided test (TOST) of equivalence, and for the non-inferiority and non-superiority tests. See the Details section below. + by : bool, str + Logical value, list of column names in `newdata`. If `True`, estimates are aggregated for each term. + hypothesis : str, numpy array + String specifying a numeric value specifying the null hypothesis used for computing p-values. + conf_level : float + Numeric value specifying the confidence level for the confidence intervals. Default is 0.95. + Returns + ------- + out : DataFrame + The functions return a data.frame with the following columns: + - term: the name of the variable. + - contrast: the comparison method used. + - estimate: the estimated contrast, difference, ratio, or other transformation between pairs of predictions. + - std_error: the standard error of the estimate. + - statistic: the test statistic (estimate / std.error). + - p_value: the p-value of the test. + - s_value: Shannon transform of the p value. + - conf_low: the lower confidence interval bound. + - conf_high: the upper confidence interval bound. + Examples + -------- comparisons(model, variables = NULL, newdata = NULL, comparison = "difference", transform = NULL, equivalence = NULL, by = FALSE, cross = FALSE, @@ -51,60 +112,8 @@ def comparisons( transform = NULL, equivalence = NULL, by = FALSE, cross = FALSE, type = "response", hypothesis = 0, conf.level = 0.95, ...) - # Args: - - - model (object): a model object fitted using the `statsmodels` formula API. - - variables (str, list, or dictionary): a string, list of strings, or dictionary of variables to compute comparisons for. If `None`, comparisons are computed for all regressors in the model object. Acceptable values depend on the variable type. See the examples below. - * Dictionary: keys identify the subset of variables of interest, and values define the type of contrast to compute. Acceptable values depend on the variable type: - - Categorical variables: - * "reference": Each factor level is compared to the factor reference (base) level - * "all": All combinations of observed levels - * "sequential": Each factor level is compared to the previous factor level - * "pairwise": Each factor level is compared to all other levels - * "minmax": The highest and lowest levels of a factor. - * "revpairwise", "revreference", "revsequential": inverse of the corresponding hypotheses. - * Vector of length 2 with the two values to compare. - - Boolean variables: - * `None`: contrast between True and False - - Numeric variables: - * Numeric of length 1: Contrast for a gap of `x`, computed at the observed value plus and minus `x / 2`. For example, estimating a `+1` contrast compares adjusted predictions when the regressor is equal to its observed value minus 0.5 and its observed value plus 0.5. - * Numeric of length equal to the number of rows in `newdata`: Same as above, but the contrast can be customized for each row of `newdata`. - * Numeric vector of length 2: Contrast between the 2nd element and the 1st element of the `x` vector. - * Data frame with the same number of rows as `newdata`, with two columns of "low" and "high" values to compare. - * Function which accepts a numeric vector and returns a data frame with two columns of "low" and "high" values to compare. See examples below. - * "iqr": Contrast across the interquartile range of the regressor. - * "sd": Contrast across one standard deviation around the regressor mean. - * "2sd": Contrast across two standard deviations around the regressor mean. - * "minmax": Contrast between the maximum and the minimum values of the regressor. - - Examples: - + `variables = {"gear" = "pairwise", "hp" = 10}` - + `variables = {"gear" = "sequential", "hp" = [100, 120]}` - - newdata (polars, pandas or any other ArrowstreamExportable DataFrame, or str): a data frame or a string specifying where statistics are evaluated in the predictor space. If `None`, unit-level contrasts are computed for each observed value in the original dataset (empirical distribution). - - comparison (str): a string specifying how pairs of predictions should be compared. See the Comparisons section below for definitions of each transformation. - - transform (function): a function specifying a transformation applied to unit-level estimates and confidence intervals just before the function returns results. Functions must accept a full column (series) of a Polars data frame and return a corresponding series of the same length. Ex: - - `transform = numpy.exp` - - `transform = lambda x: x.exp()` - - `transform = lambda x: x.map_elements()` - - equivalence (list): a list of 2 numeric values specifying the bounds used for the two-one-sided test (TOST) of equivalence, and for the non-inferiority and non-superiority tests. See the Details section below. - - by (bool, str): a logical value, a list of column names in `newdata`. If `True`, estimates are aggregated for each term. - - hypothesis (str, numpy array): a string specifying a numeric value specifying the null hypothesis used for computing p-values. - - conf.level (float): a numeric value specifying the confidence level for the confidence intervals. Default is 0.95. - - # Returns: - - The functions return a data.frame with the following columns: - - - term: the name of the variable. - - contrast: the comparison method used. - - estimate: the estimated contrast, difference, ratio, or other transformation between pairs of predictions. - - std_error: the standard error of the estimate. - - statistic: the test statistic (estimate / std.error). - - p_value: the p-value of the test. - - s_value: Shannon transform of the p value. - - conf_low: the lower confidence interval bound. - - conf_high: the upper confidence interval bound. - - # Details: + Details + ------- The `equivalence` argument specifies the bounds used for the two-one-sided test (TOST) of equivalence, and for the non-inferiority and non-superiority tests. The first element specifies the lower bound, and the second element specifies the upper bound. If `None`, equivalence tests are not performed. """ From a04d140b4da659541ee62a8b58a1c1825767a025 Mon Sep 17 00:00:00 2001 From: "artiom.matvei" Date: Sat, 26 Oct 2024 21:38:37 -0400 Subject: [PATCH 5/5] removed unused test data file --- tests/r/test_predictions_newdata_balanced_01.csv | 10 ---------- 1 file changed, 10 deletions(-) delete mode 100644 tests/r/test_predictions_newdata_balanced_01.csv diff --git a/tests/r/test_predictions_newdata_balanced_01.csv b/tests/r/test_predictions_newdata_balanced_01.csv deleted file mode 100644 index e53e987..0000000 --- a/tests/r/test_predictions_newdata_balanced_01.csv +++ /dev/null @@ -1,10 +0,0 @@ -rowid,estimate,std_error,statistic,p_value,s_value,conf_low,conf_high,flipper_length_mm,species,bill_length_mm,island,body_mass_g -0,4312.079454483751,106.62161470412332,40.44282640485084,0.0,inf,4103.104929690163,4521.053979277339,200.96696696696696,Adelie,43.99279279279279,Biscoe,4207.057057057057 -1,4310.295242940644,105.76163416585023,40.75480940641906,0.0,inf,4103.006249029477,4517.5842368518115,200.96696696696696,Adelie,43.99279279279279,Dream,4207.057057057057 -2,4244.471757854611,100.48707206266751,42.238983291379,0.0,inf,4047.5207156999013,4441.42280000932,200.96696696696696,Adelie,43.99279279279279,Torgersen,4207.057057057057 -3,3604.5409352779025,152.29379548926335,23.668337398105106,0.0,inf,3306.0505810500376,3903.0312895057673,200.96696696696696,Chinstrap,43.99279279279279,Biscoe,4207.057057057057 -4,3602.756723734794,136.31056743913638,26.430501988361613,0.0,inf,3335.5929208418684,3869.9205266277195,200.96696696696696,Chinstrap,43.99279279279279,Dream,4207.057057057057 -5,3536.9332386487586,151.76843639357855,23.304801200405652,0.0,inf,3239.4725693273867,3834.3939079701304,200.96696696696696,Chinstrap,43.99279279279279,Torgersen,4207.057057057057 -6,4057.4186754644143,122.42114140927929,33.14312077764102,0.0,inf,3817.477647355942,4297.359703572887,200.96696696696696,Gentoo,43.99279279279279,Biscoe,4207.057057057057 -7,4055.6344639213094,139.9986390090601,28.96909921859202,0.0,inf,3781.242173578927,4330.026754263691,200.96696696696696,Gentoo,43.99279279279279,Dream,4207.057057057057 -8,3989.8109788352776,141.51067786729124,28.19441641412335,0.0,inf,3712.4551467875376,4267.166810883018,200.96696696696696,Gentoo,43.99279279279279,Torgersen,4207.057057057057