Skip to content

Commit

Permalink
column order in print with by or datagrid()
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentarelbundock committed Oct 1, 2023
1 parent f8e7382 commit 54d08d6
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 14 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# dev

* `hypothesis` accepts a float or integer to specify a different null hypothesis.
* Better column order in printout when using `datagrid()` or `by`

# 0.0.5

Expand Down
9 changes: 8 additions & 1 deletion marginaleffects/classes.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import polars as pl

class MarginaleffectsDataFrame(pl.DataFrame):
def __init__(self, data=None, by=None, conf_level=0.95):
def __init__(self, data=None, by=None, conf_level=0.95, newdata=None):
if isinstance(data, pl.DataFrame):
self._df = data._df
self.by = by
self.conf_level = conf_level
if hasattr(newdata, "datagrid_explicit"):
self.datagrid_explicit = newdata.datagrid_explicit
else:
self.datagrid_explicit = []
return
super().__init__(data)

Expand Down Expand Up @@ -42,6 +46,9 @@ def __str__(self):
raise ValueError("by must be None or a string or a list of strings")
else:
valid = list(mapping.keys())

valid = self.datagrid_explicit + valid

valid = [x for x in valid if x in self.columns]
mapping = {key: mapping[key] for key in mapping if key in valid}
tmp = self.select(valid).rename(mapping)
Expand Down
4 changes: 2 additions & 2 deletions marginaleffects/comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,9 @@ def outer(x):

out = get_transform(out, transform=transform)
out = get_equivalence(out, equivalence=equivalence, df=np.inf)
out = sort_columns(out, by=by)
out = sort_columns(out, by=by, newdata=newdata)

out = MarginaleffectsDataFrame(out, by=by, conf_level=conf_level)
out = MarginaleffectsDataFrame(out, by=by, conf_level=conf_level, newdata=newdata)
return out


Expand Down
4 changes: 4 additions & 0 deletions marginaleffects/datagrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def datagrid(

out = reduce(lambda x, y: x.join(y, how="cross"), out.values())

out.datagrid_explicit = list(kwargs.keys())

return out


Expand Down Expand Up @@ -132,4 +134,6 @@ def datagridcf(model=None, newdata=None, **kwargs):
# Create rowid and rowidcf
result = result.with_columns(pl.Series(range(result.shape[0])).alias("rowidcf"))

result.datagrid_explicit = list(kwargs.keys())

return result
26 changes: 21 additions & 5 deletions marginaleffects/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,27 +115,43 @@ def predictions(
y, exog = patsy.dmatrices(model.model.formula, newdata.to_pandas())

# estimands
def fun(x):
def inner(x):
out = get_predictions(model, np.array(x), exog)

if out.shape[0] == newdata.shape[0]:
cols = [x for x in newdata.columns if x not in out.columns]
out = pl.concat([out, newdata.select(cols)], how="horizontal")

# group
elif "group" in out.columns:
meta = newdata.join(out.select("group").unique(), how="cross")
cols = [x for x in meta.columns if x in out.columns]
out = meta.join(out, on=cols, how="left")

# not sure what happens here
else:
raise ValueError("Something went wrong")

out = get_by(model, out, newdata=newdata, by=by, wts=wts)
out = get_hypothesis(out, hypothesis=hypothesis)
return out

out = fun(model.params)
out = inner(model.params)

if vcov is not None:
J = get_jacobian(fun, model.params)
J = get_jacobian(inner, model.params)
se = get_se(J, V)
out = out.with_columns(pl.Series(se).alias("std_error"))
out = get_z_p_ci(out, model, conf_level=conf_level, hypothesis_null=hypothesis_null)
out = get_transform(out, transform=transform)
out = get_equivalence(out, equivalence=equivalence)
out = sort_columns(out, by=by)
out = sort_columns(out, by=by, newdata=newdata)

# unpad
if "rowid" in out.columns and pad.shape[0] > 0:
out = out[:-pad.shape[0]:]

out = MarginaleffectsDataFrame(out, by=by, conf_level=conf_level)
out = MarginaleffectsDataFrame(out, by=by, conf_level=conf_level, newdata=newdata)
return out


Expand Down
12 changes: 9 additions & 3 deletions marginaleffects/sanity.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def sanitize_by(by):
return by


def sanitize_newdata(model, newdata, wts, by = []):
def sanitize_newdata(model, newdata, wts, by=[]):
modeldata = get_modeldata(model)

if newdata is None:
Expand All @@ -63,10 +63,14 @@ def sanitize_newdata(model, newdata, wts, by = []):

elif isinstance(newdata, pd.DataFrame):
out = pl.from_pandas(newdata)

else:
out = newdata

datagrid_explicit = None
if isinstance(out, pl.DataFrame) and hasattr(out, "datagrid_explicit"):
datagrid_explicit = out.datagrid_explicit

if isinstance(by, list) and len(by) > 0:
by = [x for x in by if x in out.columns]
if len(by) > 0:
Expand All @@ -90,8 +94,10 @@ def sanitize_newdata(model, newdata, wts, by = []):
if any([isinstance(out[x], pl.Categorical) for x in out.columns]):
raise ValueError("Categorical type columns are not supported in `newdata`.")

return out
if datagrid_explicit is not None:
out.datagrid_explicit = datagrid_explicit

return out

def sanitize_comparison(comparison, by, wts=None):
out = comparison
Expand Down
7 changes: 6 additions & 1 deletion marginaleffects/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def get_modeldata(fit):
return out


def sort_columns(df, by=None):
def sort_columns(df, by=None, newdata=None):
cols = [
"rowid",
"group",
Expand All @@ -30,11 +30,16 @@ def sort_columns(df, by=None):
"conf_low",
"conf_high",
] + df.columns

if by is not None:
if isinstance(by, list):
cols = by + cols
else:
cols = [by] + cols

if isinstance(newdata, pl.DataFrame) and hasattr(newdata, "datagrid_explicit"):
cols = newdata.datagrid_explicit + cols

cols = [x for x in cols if x in df.columns]
cols_unique = []
for item in cols:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "marginaleffects"
version = "0.0.5.9001"
version = "0.0.5.9002"
description = ""
authors = ["Vincent Arel-Bundock <[email protected]>"]
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pytest import approx
import polars as pl
from marginaleffects import *
from .utilities import *
# from .utilities import *
import statsmodels.formula.api as smf

Guerry = pl.read_csv("https://vincentarelbundock.github.io/Rdatasets/csv/HistData/Guerry.csv", null_values = "NA").drop_nulls()
Expand Down

0 comments on commit 54d08d6

Please sign in to comment.