Skip to content

Commit

Permalink
ingest different df types + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
artiom-matvei committed Oct 26, 2024
1 parent 1df56df commit cb4d158
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 9 deletions.
2 changes: 1 addition & 1 deletion marginaleffects/comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def comparisons(
- Examples:
+ `variables = {"gear" = "pairwise", "hp" = 10}`
+ `variables = {"gear" = "sequential", "hp" = [100, 120]}`
- newdata (polars or pandas DataFrame, or str): a data frame or a string specifying where statistics are evaluated in the predictor space. If `None`, unit-level contrasts are computed for each observed value in the original dataset (empirical distribution).
- newdata (polars, pandas or any other ArrowstreamExportable DataFrame, or str): a data frame or a string specifying where statistics are evaluated in the predictor space. If `None`, unit-level contrasts are computed for each observed value in the original dataset (empirical distribution).
- comparison (str): a string specifying how pairs of predictions should be compared. See the Comparisons section below for definitions of each transformation.
- transform (function): a function specifying a transformation applied to unit-level estimates and confidence intervals just before the function returns results. Functions must accept a full column (series) of a Polars data frame and return a corresponding series of the same length. Ex:
- `transform = numpy.exp`
Expand Down
15 changes: 8 additions & 7 deletions marginaleffects/sanity.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from .datagrid import datagrid
from .estimands import estimands
from .utils import ingest, ArrowStreamExportable


def sanitize_vcov(vcov, model):
Expand Down Expand Up @@ -67,14 +68,14 @@ def sanitize_newdata(model, newdata, wts, by=[]):

else:
try:
import pandas as pd

if isinstance(newdata, pd.DataFrame):
out = pl.from_pandas(newdata)
if isinstance(newdata, ArrowStreamExportable):
out = ingest(newdata)
else:
out = newdata
except ImportError:
out = newdata
raise RuntimeError(
"Unable to ingest newdata data provided. If it is a DataFrame, make sure it implements the ArrowStreamExportable interface."
)
except Exception as e:
raise e

reserved_names = {
"rowid",
Expand Down
12 changes: 11 additions & 1 deletion marginaleffects/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
import itertools

import narwhals as nw
import numpy as np
import polars as pl
from typing import Protocol, runtime_checkable


@runtime_checkable
class ArrowStreamExportable(Protocol):
def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: ...


def ingest(df: ArrowStreamExportable):
return nw.from_arrow(df, native_namespace=pl).to_native()


def sort_columns(df, by=None, newdata=None):
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ description = "Predictions, counterfactual comparisons, slopes, and hypothesis t
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"duckdb >=1.1.2",
"narwhals >=1.10.0",
"numpy >=2.0.0",
"patsy >=0.5.6",
"polars >=1.7.0",
Expand Down
70 changes: 70 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import pytest
import pandas as pd
import polars as pl
import duckdb
import pyarrow as pa
import narwhals as nw
from marginaleffects.utils import ingest
from typing import Callable


def get_sample_data():
return pd.DataFrame(
{
"id": [1, 2, 3],
"name": ["Alice", "Bob", "Charlie"],
"age": [25, 30, 35],
"score": [85.5, 90.0, 95.5],
}
)


sample_data = get_sample_data()


@pytest.fixture
def sample_pandas_df():
return get_sample_data()


@pytest.fixture
def sample_polars_df():
pd_df = get_sample_data()
return pl.from_pandas(pd_df)


@pytest.fixture
def sample_duckdb_df():
# Using DuckDB to create a DataFrame
con = duckdb.connect()
return con.execute("SELECT * FROM sample_data").df()


def test_ingest_pandas(sample_pandas_df):
result = ingest(sample_pandas_df)
assert isinstance(result, pl.DataFrame), "Result should be a Polars DataFrame"
# Verify contents
expected = pl.from_pandas(sample_pandas_df)
assert result.equals(
expected
), "Ingested DataFrame does not match expected Polars DataFrame"


def test_ingest_polars(sample_polars_df):
result = ingest(sample_polars_df)
assert isinstance(result, pl.DataFrame), "Result should be a Polars DataFrame"
# Verify contents
expected = sample_polars_df
assert result.equals(
expected
), "Ingested DataFrame does not match expected Polars DataFrame"


def test_ingest_duckdb(sample_duckdb_df):
result = ingest(sample_duckdb_df)
assert isinstance(result, pl.DataFrame), "Result should be a Polars DataFrame"
# Verify contents
expected = pl.from_pandas(sample_duckdb_df)
assert result.equals(
expected
), "Ingested DataFrame does not match expected Polars DataFrame"
53 changes: 53 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit cb4d158

Please sign in to comment.