Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix issue #100 fixed row labels names for hypothesis= revpairwise, pairwise, revreference, reference, revsequential, sequential #137

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion marginaleffects/comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def applyfun(x, by, wts=None):
function=lambda x: applyfun(x, by=by, wts=wts)
)

tmp = get_hypothesis(tmp, hypothesis=hypothesis)
tmp = get_hypothesis(tmp, hypothesis=hypothesis, by=by)

return tmp

Expand Down
75 changes: 52 additions & 23 deletions marginaleffects/hypothesis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re

from itertools import compress
import numpy as np
import polars as pl

Expand Down Expand Up @@ -46,7 +46,7 @@ def eval_string_function(vec, hypothesis, rowlabels):


# function extracts the estimate column from a data frame and sets it to x. If `hypothesis` argument is a numpy array, it feeds it directly to lincome_multiply. If lincome is a string, it checks if the string is valid, and then calls the corresponding function.
def get_hypothesis(x, hypothesis):
def get_hypothesis(x, hypothesis, by=None):
msg = f"Invalid hypothesis argument: {hypothesis}. Valid arguments are: 'reference', 'revreference', 'sequential', 'revsequential', 'pairwise', 'revpairwise' or a numpy array or a float."

if hypothesis is None or isinstance(hypothesis, (int, float)):
Expand All @@ -59,17 +59,17 @@ def get_hypothesis(x, hypothesis):
out = eval_string_hypothesis(x, hypothesis, lab=hypothesis)
elif isinstance(hypothesis, str):
if hypothesis == "reference":
hypmat = lincom_reference(x)
hypmat = lincom_reference(x, by)
elif hypothesis == "revreference":
hypmat = lincom_revreference(x)
hypmat = lincom_revreference(x, by)
elif hypothesis == "sequential":
hypmat = lincom_sequential(x)
hypmat = lincom_sequential(x, by)
elif hypothesis == "revsequential":
hypmat = lincom_revsequential(x)
hypmat = lincom_revsequential(x, by)
elif hypothesis == "pairwise":
hypmat = lincom_pairwise(x)
hypmat = lincom_pairwise(x, by)
elif hypothesis == "revpairwise":
hypmat = lincom_revpairwise(x)
hypmat = lincom_revpairwise(x, by)
else:
raise ValueError(msg)
out = lincom_multiply(x, hypmat.to_numpy())
Expand All @@ -86,15 +86,44 @@ def lincom_multiply(x, lincom):
return out


# TODO: improve labeling
def get_hypothesis_row_labels(x):
return ["i" for i in range(len(x))]
def get_hypothesis_row_labels(x, by=None):
pattern = re.compile(r"^(term|by|group|value|contrast|contrast_)$")
lab = [col for col in x.columns if pattern.match(col)]

# Step 2: Filter columns with more than one unique value
lab = [col for col in lab if len(x[col].unique()) > 1]

# Step 3: Include additional columns from "by" if provided
if by is not None:
if isinstance(by, str):
by = [by]
lab = [e for e in list(set(lab) | set(by)) if e != "group"]

# Step 4: If no columns left, return default
if len(lab) == 0:
return [f"{i}" for i in range(len(x))]

# Step 5: Create a sub-dataframe with selected columns
lab_df = x[lab]
idx = [x[col].n_unique() > 1 for col in lab_df.columns]

# Step 6: Create row labels by concatenating values
if any(idx):
lab_df = lab_df.select(list(compress(lab_df.columns, idx)))
lab = lab_df.select(
pl.concat_str(lab_df.columns, separator=", ").alias("concatenated")
)["concatenated"].to_list()

# Step 7: Wrap labels containing "-" in parentheses
lab = [f"({label})" if "-" in label else label for label in lab]

return lab


def lincom_revreference(x):
def lincom_revreference(x, by):
lincom = -1 * np.identity(len(x))
lincom[0] = 1
lab = get_hypothesis_row_labels(x)
lab = get_hypothesis_row_labels(x, by)
if len(lab) == 0 or len(set(lab)) != len(lab):
lab = [f"Row 1 - Row {i+1}" for i in range(len(lincom))]
else:
Expand All @@ -104,10 +133,10 @@ def lincom_revreference(x):
return lincom


def lincom_reference(x):
def lincom_reference(x, by):
lincom = np.identity(len(x))
lincom[0, :] = -1
lab = get_hypothesis_row_labels(x)
lab = get_hypothesis_row_labels(x, by)
if len(lab) == 0 or len(set(lab)) != len(lab):
lab = [f"Row {i+1} - Row 1" for i in range(len(lincom))]
else:
Expand All @@ -120,9 +149,9 @@ def lincom_reference(x):
return lincom


def lincom_revsequential(x):
def lincom_revsequential(x, by):
lincom = np.zeros((len(x), len(x) - 1))
lab = get_hypothesis_row_labels(x)
lab = get_hypothesis_row_labels(x, by)
if len(lab) == 0 or len(set(lab)) != len(lab):
lab = [f"Row {i+1} - Row {i+2}" for i in range(lincom.shape[1])]
else:
Expand All @@ -136,9 +165,9 @@ def lincom_revsequential(x):
return lincom


def lincom_sequential(x):
def lincom_sequential(x, by):
lincom = np.zeros((len(x), len(x) - 1))
lab = get_hypothesis_row_labels(x)
lab = get_hypothesis_row_labels(x, by)
if len(lab) == 0 or len(set(lab)) != len(lab):
lab = [f"Row {i+2} - Row {i+1}" for i in range(lincom.shape[1])]
else:
Expand All @@ -152,8 +181,8 @@ def lincom_sequential(x):
return lincom


def lincom_revpairwise(x):
lab_row = get_hypothesis_row_labels(x)
def lincom_revpairwise(x, by):
lab_row = get_hypothesis_row_labels(x, by)
lab_col = []
flag = len(lab_row) == 0 or len(set(lab_row)) != len(lab_row)
mat = []
Expand All @@ -175,8 +204,8 @@ def lincom_revpairwise(x):
return lincom


def lincom_pairwise(x):
lab_row = get_hypothesis_row_labels(x)
def lincom_pairwise(x, by):
lab_row = get_hypothesis_row_labels(x, by)
lab_col = []
flag = len(lab_row) == 0 or len(set(lab_row)) != len(lab_row)
mat = []
Expand Down
2 changes: 1 addition & 1 deletion marginaleffects/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def inner(x):
raise ValueError("Something went wrong")

out = get_by(model, out, newdata=newdata, by=by, wts=wts)
out = get_hypothesis(out, hypothesis=hypothesis)
out = get_hypothesis(out, hypothesis=hypothesis, by=by)
return out

out = inner(model.get_coef())
Expand Down
7 changes: 7 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ def guerry_mod():
return smf.ols("Literacy ~ Pop1831 * Desertion", guerry).fit()


@pytest.fixture(scope="session")
def impartiality_model():
return smf.logit(
"impartial ~ equal * democracy + continent", data=impartiality_df.to_pandas()
).fit()


@pytest.fixture(scope="session")
def penguins_model():
mod = smf.ols(
Expand Down
Loading
Loading