Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ingest different df types + tests #134

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion marginaleffects/comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def comparisons(
- Examples:
+ `variables = {"gear" = "pairwise", "hp" = 10}`
+ `variables = {"gear" = "sequential", "hp" = [100, 120]}`
- newdata (polars or pandas DataFrame, or str): a data frame or a string specifying where statistics are evaluated in the predictor space. If `None`, unit-level contrasts are computed for each observed value in the original dataset (empirical distribution).
- newdata (polars, pandas or any other ArrowstreamExportable DataFrame, or str): a data frame or a string specifying where statistics are evaluated in the predictor space. If `None`, unit-level contrasts are computed for each observed value in the original dataset (empirical distribution).
- comparison (str): a string specifying how pairs of predictions should be compared. See the Comparisons section below for definitions of each transformation.
- transform (function): a function specifying a transformation applied to unit-level estimates and confidence intervals just before the function returns results. Functions must accept a full column (series) of a Polars data frame and return a corresponding series of the same length. Ex:
- `transform = numpy.exp`
Expand Down
4 changes: 0 additions & 4 deletions marginaleffects/hypotheses_joint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
import scipy.stats as stats
import polars as pl
import pandas as pd

from .sanity import sanitize_hypothesis_null
from .classes import MarginaleffectsDataFrame
Expand All @@ -10,9 +9,6 @@
def joint_hypotheses(obj, joint_index=None, joint_test="f", hypothesis=0):
assert joint_test in ["f", "chisq"], "`joint_test` must be `f` or `chisq`"

if isinstance(obj, pd.DataFrame):
obj = pl.DataFrame(obj)

Comment on lines -13 to -15
Copy link
Contributor Author

@artiom-matvei artiom-matvei Oct 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems like obj is of type model and it shouldn't ever be of type dataframe. It was introduced here #94 but not sure why. I also checked in R and obj is not a dataframe in R either.

If it were to be a dataframe all the following operations would fail so I think we can remove it. @vincentarelbundock do you agree?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems right but I won't have time to dig deep, so I'll trust you here.

# theta_hat: P x 1 vector of estimated parameters
theta_hat = obj.get_coef()

Expand Down
15 changes: 8 additions & 7 deletions marginaleffects/sanity.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from .datagrid import datagrid
from .estimands import estimands
from .utils import ingest, ArrowStreamExportable


def sanitize_vcov(vcov, model):
Expand Down Expand Up @@ -67,14 +68,14 @@ def sanitize_newdata(model, newdata, wts, by=[]):

else:
try:
import pandas as pd

if isinstance(newdata, pd.DataFrame):
out = pl.from_pandas(newdata)
if isinstance(newdata, ArrowStreamExportable):
out = ingest(newdata)
else:
out = newdata
except ImportError:
out = newdata
raise RuntimeError(
"Unable to ingest newdata data provided. If it is a DataFrame, make sure it implements the ArrowStreamExportable interface."
)
except Exception as e:
raise e

reserved_names = {
"rowid",
Expand Down
12 changes: 11 additions & 1 deletion marginaleffects/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
import itertools

import narwhals as nw
import numpy as np
import polars as pl
from typing import Protocol, runtime_checkable


@runtime_checkable
class ArrowStreamExportable(Protocol):
def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: ...


def ingest(df: ArrowStreamExportable):
return nw.from_arrow(df, native_namespace=pl).to_native()


def sort_columns(df, by=None, newdata=None):
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ description = "Predictions, counterfactual comparisons, slopes, and hypothesis t
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"duckdb >=1.1.2",
"narwhals >=1.10.0",
"numpy >=2.0.0",
"patsy >=0.5.6",
"polars >=1.7.0",
Expand Down
70 changes: 70 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import pytest
import pandas as pd
import polars as pl
import duckdb
import pyarrow as pa
import narwhals as nw
from marginaleffects.utils import ingest
from typing import Callable


def get_sample_data():
return pd.DataFrame(
{
"id": [1, 2, 3],
"name": ["Alice", "Bob", "Charlie"],
"age": [25, 30, 35],
"score": [85.5, 90.0, 95.5],
}
)


sample_data = get_sample_data()


@pytest.fixture
def sample_pandas_df():
return get_sample_data()


@pytest.fixture
def sample_polars_df():
pd_df = get_sample_data()
return pl.from_pandas(pd_df)


@pytest.fixture
def sample_duckdb_df():
# Using DuckDB to create a DataFrame
con = duckdb.connect()
return con.execute("SELECT * FROM sample_data").df()


def test_ingest_pandas(sample_pandas_df):
result = ingest(sample_pandas_df)
assert isinstance(result, pl.DataFrame), "Result should be a Polars DataFrame"
# Verify contents
expected = pl.from_pandas(sample_pandas_df)
assert result.equals(
expected
), "Ingested DataFrame does not match expected Polars DataFrame"


def test_ingest_polars(sample_polars_df):
result = ingest(sample_polars_df)
assert isinstance(result, pl.DataFrame), "Result should be a Polars DataFrame"
# Verify contents
expected = sample_polars_df
assert result.equals(
expected
), "Ingested DataFrame does not match expected Polars DataFrame"


def test_ingest_duckdb(sample_duckdb_df):
result = ingest(sample_duckdb_df)
assert isinstance(result, pl.DataFrame), "Result should be a Polars DataFrame"
# Verify contents
expected = pl.from_pandas(sample_duckdb_df)
assert result.equals(
expected
), "Ingested DataFrame does not match expected Polars DataFrame"
53 changes: 53 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading