diff --git a/marginaleffects/comparisons.py b/marginaleffects/comparisons.py index 7888cee..5ff4942 100644 --- a/marginaleffects/comparisons.py +++ b/marginaleffects/comparisons.py @@ -289,7 +289,7 @@ def applyfun(x, by, wts=None): function=lambda x: applyfun(x, by=by, wts=wts) ) - tmp = get_hypothesis(tmp, hypothesis=hypothesis) + tmp = get_hypothesis(tmp, hypothesis=hypothesis, by=by) return tmp diff --git a/marginaleffects/hypothesis.py b/marginaleffects/hypothesis.py index e3558f6..870c3b5 100644 --- a/marginaleffects/hypothesis.py +++ b/marginaleffects/hypothesis.py @@ -1,5 +1,5 @@ import re - +from itertools import compress import numpy as np import polars as pl @@ -46,7 +46,7 @@ def eval_string_function(vec, hypothesis, rowlabels): # function extracts the estimate column from a data frame and sets it to x. If `hypothesis` argument is a numpy array, it feeds it directly to lincome_multiply. If lincome is a string, it checks if the string is valid, and then calls the corresponding function. -def get_hypothesis(x, hypothesis): +def get_hypothesis(x, hypothesis, by=None): msg = f"Invalid hypothesis argument: {hypothesis}. Valid arguments are: 'reference', 'revreference', 'sequential', 'revsequential', 'pairwise', 'revpairwise' or a numpy array or a float." if hypothesis is None or isinstance(hypothesis, (int, float)): @@ -59,17 +59,17 @@ def get_hypothesis(x, hypothesis): out = eval_string_hypothesis(x, hypothesis, lab=hypothesis) elif isinstance(hypothesis, str): if hypothesis == "reference": - hypmat = lincom_reference(x) + hypmat = lincom_reference(x, by) elif hypothesis == "revreference": - hypmat = lincom_revreference(x) + hypmat = lincom_revreference(x, by) elif hypothesis == "sequential": - hypmat = lincom_sequential(x) + hypmat = lincom_sequential(x, by) elif hypothesis == "revsequential": - hypmat = lincom_revsequential(x) + hypmat = lincom_revsequential(x, by) elif hypothesis == "pairwise": - hypmat = lincom_pairwise(x) + hypmat = lincom_pairwise(x, by) elif hypothesis == "revpairwise": - hypmat = lincom_revpairwise(x) + hypmat = lincom_revpairwise(x, by) else: raise ValueError(msg) out = lincom_multiply(x, hypmat.to_numpy()) @@ -86,15 +86,44 @@ def lincom_multiply(x, lincom): return out -# TODO: improve labeling -def get_hypothesis_row_labels(x): - return ["i" for i in range(len(x))] +def get_hypothesis_row_labels(x, by=None): + pattern = re.compile(r"^(term|by|group|value|contrast|contrast_)$") + lab = [col for col in x.columns if pattern.match(col)] + + # Step 2: Filter columns with more than one unique value + lab = [col for col in lab if len(x[col].unique()) > 1] + + # Step 3: Include additional columns from "by" if provided + if by is not None: + if isinstance(by, str): + by = [by] + lab = [e for e in list(set(lab) | set(by)) if e != "group"] + + # Step 4: If no columns left, return default + if len(lab) == 0: + return [f"{i}" for i in range(len(x))] + + # Step 5: Create a sub-dataframe with selected columns + lab_df = x[lab] + idx = [x[col].n_unique() > 1 for col in lab_df.columns] + + # Step 6: Create row labels by concatenating values + if any(idx): + lab_df = lab_df.select(list(compress(lab_df.columns, idx))) + lab = lab_df.select( + pl.concat_str(lab_df.columns, separator=", ").alias("concatenated") + )["concatenated"].to_list() + + # Step 7: Wrap labels containing "-" in parentheses + lab = [f"({label})" if "-" in label else label for label in lab] + + return lab -def lincom_revreference(x): +def lincom_revreference(x, by): lincom = -1 * np.identity(len(x)) lincom[0] = 1 - lab = get_hypothesis_row_labels(x) + lab = get_hypothesis_row_labels(x, by) if len(lab) == 0 or len(set(lab)) != len(lab): lab = [f"Row 1 - Row {i+1}" for i in range(len(lincom))] else: @@ -104,10 +133,10 @@ def lincom_revreference(x): return lincom -def lincom_reference(x): +def lincom_reference(x, by): lincom = np.identity(len(x)) lincom[0, :] = -1 - lab = get_hypothesis_row_labels(x) + lab = get_hypothesis_row_labels(x, by) if len(lab) == 0 or len(set(lab)) != len(lab): lab = [f"Row {i+1} - Row 1" for i in range(len(lincom))] else: @@ -120,9 +149,9 @@ def lincom_reference(x): return lincom -def lincom_revsequential(x): +def lincom_revsequential(x, by): lincom = np.zeros((len(x), len(x) - 1)) - lab = get_hypothesis_row_labels(x) + lab = get_hypothesis_row_labels(x, by) if len(lab) == 0 or len(set(lab)) != len(lab): lab = [f"Row {i+1} - Row {i+2}" for i in range(lincom.shape[1])] else: @@ -136,9 +165,9 @@ def lincom_revsequential(x): return lincom -def lincom_sequential(x): +def lincom_sequential(x, by): lincom = np.zeros((len(x), len(x) - 1)) - lab = get_hypothesis_row_labels(x) + lab = get_hypothesis_row_labels(x, by) if len(lab) == 0 or len(set(lab)) != len(lab): lab = [f"Row {i+2} - Row {i+1}" for i in range(lincom.shape[1])] else: @@ -152,8 +181,8 @@ def lincom_sequential(x): return lincom -def lincom_revpairwise(x): - lab_row = get_hypothesis_row_labels(x) +def lincom_revpairwise(x, by): + lab_row = get_hypothesis_row_labels(x, by) lab_col = [] flag = len(lab_row) == 0 or len(set(lab_row)) != len(lab_row) mat = [] @@ -175,8 +204,8 @@ def lincom_revpairwise(x): return lincom -def lincom_pairwise(x): - lab_row = get_hypothesis_row_labels(x) +def lincom_pairwise(x, by): + lab_row = get_hypothesis_row_labels(x, by) lab_col = [] flag = len(lab_row) == 0 or len(set(lab_row)) != len(lab_row) mat = [] diff --git a/marginaleffects/predictions.py b/marginaleffects/predictions.py index c705d07..b01b843 100644 --- a/marginaleffects/predictions.py +++ b/marginaleffects/predictions.py @@ -154,7 +154,7 @@ def inner(x): raise ValueError("Something went wrong") out = get_by(model, out, newdata=newdata, by=by, wts=wts) - out = get_hypothesis(out, hypothesis=hypothesis) + out = get_hypothesis(out, hypothesis=hypothesis, by=by) return out out = inner(model.get_coef()) diff --git a/marginaleffects/sanity.py b/marginaleffects/sanity.py index bbc9212..40a74e0 100644 --- a/marginaleffects/sanity.py +++ b/marginaleffects/sanity.py @@ -131,26 +131,26 @@ def sanitize_comparison(comparison, by, wts=None): lab = { "difference": "{hi} - {lo}", - "differenceavg": "mean({hi}) - mean({lo})", - "differenceavgwts": "mean({hi}) - mean({lo})", + "differenceavg": "{hi} - {lo}", + "differenceavgwts": "{hi} - {lo}", "dydx": "dY/dX", "eyex": "eY/eX", "eydx": "eY/dX", "dyex": "dY/eX", - "dydxavg": "mean(dY/dX)", - "eyexavg": "mean(eY/eX)", - "eydxavg": "mean(eY/dX)", - "dyexavg": "mean(dY/eX)", - "dydxavgwts": "mean(dY/dX)", - "eyexavgwts": "mean(eY/eX)", - "eydxavgwts": "mean(eY/dX)", - "dyexavgwts": "mean(dY/eX)", + "dydxavg": "dY/dX", + "eyexavg": "eY/eX", + "eydxavg": "eY/dX", + "dyexavg": "dY/eX", + "dydxavgwts": "dY/dX", + "eyexavgwts": "eY/eX", + "eydxavgwts": "eY/dX", + "dyexavgwts": "dY/eX", "ratio": "{hi} / {lo}", - "ratioavg": "mean({hi}) / mean({lo})", - "ratioavgwts": "mean({hi}) / mean({lo})", + "ratioavg": "{hi} / {lo}", + "ratioavgwts": "{hi} / {lo}", "lnratio": "ln({hi} / {lo})", - "lnratioavg": "ln(mean({hi}) / mean({lo}))", - "lnratioavgwts": "ln(mean({hi}) / mean({lo}))", + "lnratioavg": "ln({hi} / {lo})", + "lnratioavgwts": "ln({hi} / {lo})", "lnor": "ln(odds({hi}) / odds({lo}))", "lnoravg": "ln(odds({hi}) / odds({lo}))", "lnoravgwts": "ln(odds({hi}) / odds({lo}))", diff --git a/tests/conftest.py b/tests/conftest.py index e421fe7..297b287 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -41,6 +41,13 @@ def guerry_mod(): return smf.ols("Literacy ~ Pop1831 * Desertion", guerry).fit() +@pytest.fixture(scope="session") +def impartiality_model(): + return smf.logit( + "impartial ~ equal * democracy + continent", data=impartiality_df.to_pandas() + ).fit() + + @pytest.fixture(scope="session") def penguins_model(): mod = smf.ols( diff --git a/tests/images/plot_comparisons/discrete_02.png b/tests/images/plot_comparisons/discrete_02.png index e50a490..da6bf69 100644 Binary files a/tests/images/plot_comparisons/discrete_02.png and b/tests/images/plot_comparisons/discrete_02.png differ diff --git a/tests/images/plot_slopes/by_01.png b/tests/images/plot_slopes/by_01.png index 4a30fab..12ae4a5 100644 Binary files a/tests/images/plot_slopes/by_01.png and b/tests/images/plot_slopes/by_01.png differ diff --git a/tests/test_jss.py b/tests/test_jss.py index 13e42da..f19e6bd 100644 --- a/tests/test_jss.py +++ b/tests/test_jss.py @@ -6,114 +6,178 @@ import pytest from tests.conftest import impartiality_df as dat -m = smf.logit("impartial ~ equal * democracy + continent", data=dat.to_pandas()).fit() - @pytest.mark.skip(reason="to be fixed") -def test_predictions(): - p = predictions(m) +def test_predictions(impartiality_model): + p = predictions(impartiality_model) + assert isinstance(p, pl.DataFrame) - p = predictions(m, newdata=dat.head()) + p = predictions(impartiality_model, newdata=dat.head()) + assert isinstance(p, pl.DataFrame) - p = predictions(m, newdata="mean") + p = predictions(impartiality_model, newdata="mean") + assert isinstance(p, pl.DataFrame) p = predictions( - m, - newdata=datagrid(model=m, democracy=dat["democracy"].unique(), equal=[30, 90]), + impartiality_model, + newdata=datagrid( + model=impartiality_model, + democracy=dat["democracy"].unique(), + equal=[30, 90], + ), ) assert isinstance(p, pl.DataFrame) assert p.shape[0] == 4 - p1 = avg_predictions(m) - p2 = np.mean(m.predict(dat.to_pandas()).to_numpy()) + p1 = avg_predictions(impartiality_model) + p2 = np.mean(impartiality_model.predict(dat.to_pandas()).to_numpy()) assert isinstance(p1, pl.DataFrame) assert p1.shape[0] == 1 assert p1["estimate"][0] == p2 - p = predictions(m, by="democracy") + p = predictions(impartiality_model, by="democracy") assert isinstance(p, pl.DataFrame) assert p.shape[0] == 2 - p = plot_predictions(m, by=["democracy", "continent"]) + p = plot_predictions(impartiality_model, by=["democracy", "continent"]) assert assert_image(p, label="jss_01", file="jss") is None -def test_hypotheses(): +def test_hypotheses(impartiality_model): # hypotheses(m, hypothesis = "continentAsia = continentAmericas") - h = hypotheses(m, hypothesis="b4 = b3") + + h = hypotheses(impartiality_model, hypothesis="b4 = b3") + assert isinstance(h, pl.DataFrame) assert h.shape[0] == 1 - avg_predictions(m, by="democracy", hypothesis="revpairwise") + avg_predictions(impartiality_model, by="democracy", hypothesis="revpairwise") - p = predictions(m, by="democracy", hypothesis="b1 = b0 * 2") + p = predictions(impartiality_model, by="democracy", hypothesis="b1 = b0 * 2") assert isinstance(p, pl.DataFrame) assert p.shape[0] == 1 p = predictions( - m, by="democracy", hypothesis="b1 = b0 * 2", equivalence=[-0.2, 0.2] + impartiality_model, + by="democracy", + hypothesis="b1 = b0 * 2", + equivalence=[-0.2, 0.2], ) assert isinstance(p, pl.DataFrame) assert p.shape[0] == 1 - c = comparisons(m, variables="democracy") + c = comparisons(impartiality_model, variables="democracy") assert isinstance(c, pl.DataFrame) assert c.shape[0] == 166 - c = avg_comparisons(m) + c = avg_comparisons(impartiality_model) assert isinstance(c, pl.DataFrame) assert c.shape[0] == 5 - c = avg_comparisons(m, variables={"equal": 4}) + c = avg_comparisons(impartiality_model, variables={"equal": 4}) assert isinstance(c, pl.DataFrame) assert c.shape[0] == 1 - c = avg_comparisons(m, variables={"equal": "sd"}) + c = avg_comparisons(impartiality_model, variables={"equal": "sd"}) assert isinstance(c, pl.DataFrame) assert c.shape[0] == 1 - c = avg_comparisons(m, variables={"equal": [30, 90]}) + c = avg_comparisons(impartiality_model, variables={"equal": [30, 90]}) assert isinstance(c, pl.DataFrame) assert c.shape[0] == 1 - c = avg_comparisons(m, variables={"equal": "iqr"}) + c = avg_comparisons(impartiality_model, variables={"equal": "iqr"}) assert isinstance(c, pl.DataFrame) assert c.shape[0] == 1 - c = avg_comparisons(m, variables="democracy", comparison="ratio") - assert c["contrast"][0] == "mean(Democracy) / mean(Autocracy)" - - -def test_hypothesis_shape(): - m = smf.logit( - "impartial ~ equal * democracy + continent", data=dat.to_pandas() - ).fit() - - for h in [ - "reference", - "revreference", - "sequential", - "revsequential", - "pairwise", - "revpairwise", - ]: - for b in ["democracy", "continent"]: - c = comparisons( - m, - by=b, - variables={"equal": [30, 90]}, - hypothesis=h, - ) - assert isinstance(c, pl.DataFrame) - if b == "democracy": - assert c.shape[0] == 1 - else: - assert c.shape[0] > 1 - - -def test_transform(): - c1 = avg_comparisons(m, comparison="lnor") - c2 = avg_comparisons(m, comparison="lnor", transform=np.exp) + c = avg_comparisons(impartiality_model, variables="democracy", comparison="ratio") + assert c["contrast"][0] == "Democracy / Autocracy" + + c = avg_comparisons(impartiality_model, variables="democracy", comparison="differenceavg") + assert c["contrast"][0] == "Democracy - Autocracy" + + +@pytest.mark.parametrize( + "h, label", + [ + ( + "reference", + { + "democracy": "Democracy - Autocracy", + "continent": ["Americas - Africa", "Asia - Africa", "Europe - Africa"], + }, + ), + ( + "revreference", + { + "democracy": "Autocracy - Democracy", + "continent": ["Africa - Americas", "Africa - Asia", "Africa - Europe"], + }, + ), + ( + "sequential", + { + "democracy": "Democracy - Autocracy", + "continent": ["Americas - Africa", "Asia - Americas", "Europe - Asia"], + }, + ), + ( + "revsequential", + { + "democracy": "Autocracy - Democracy", + "continent": ["Africa - Americas", "Americas - Asia", "Asia - Europe"], + }, + ), + ( + "pairwise", + { + "democracy": "Autocracy - Democracy", + "continent": [ + "Africa - Americas", + "Africa - Asia", + "Africa - Europe", + "Americas - Asia", + "Americas - Europe", + "Asia - Europe", + ], + }, + ), + ( + "revpairwise", + { + "democracy": "Democracy - Autocracy", + "continent": [ + "Americas - Africa", + "Asia - Africa", + "Europe - Africa", + "Asia - Americas", + "Europe - Americas", + "Europe - Asia", + ], + }, + ), + ], +) +def test_hypothesis_shape_and_row_labels(h, label, impartiality_model): + for b in ["democracy", "continent"]: + c = comparisons( + impartiality_model, + by=b, + variables={"equal": [30, 90]}, + hypothesis=h, + ) + assert isinstance(c, pl.DataFrame) + if b == "democracy": + assert c.shape[0] == 1 + assert c["term"][0] == label[b] + else: + assert c.shape[0] > 1 + assert (c["term"] == label[b]).all() + + +def test_transform(impartiality_model): + c1 = avg_comparisons(impartiality_model, comparison="lnor") + c2 = avg_comparisons(impartiality_model, comparison="lnor", transform=np.exp) all(np.exp(c1["estimate"]) == c2["estimate"]) @@ -122,13 +186,13 @@ def test_transform(): # comparison = lambda hi, lo: np.mean(hi) / np.ean(lo)) -def test_misc(): - cmp = comparisons(m, by="democracy", variables={"equal": [30, 90]}) +def test_misc(impartiality_model): + cmp = comparisons(impartiality_model, by="democracy", variables={"equal": [30, 90]}) assert isinstance(cmp, pl.DataFrame) assert cmp.shape[0] == 2 cmp = comparisons( - m, + impartiality_model, by="democracy", variables={"equal": [30, 90]}, hypothesis="pairwise", @@ -136,19 +200,24 @@ def test_misc(): assert isinstance(cmp, pl.DataFrame) assert cmp.shape[0] == 1 - s = slopes(m, variables="equal", newdata=datagrid(equal=[25, 50], model=m)) + s = slopes( + impartiality_model, + variables="equal", + newdata=datagrid(equal=[25, 50], model=impartiality_model), + ) + assert isinstance(s, pl.DataFrame) assert s.shape[0] == 2 - s = avg_slopes(m, variables="equal") + s = avg_slopes(impartiality_model, variables="equal") assert isinstance(s, pl.DataFrame) assert s.shape[0] == 1 - s = slopes(m, variables="equal", newdata="mean") + s = slopes(impartiality_model, variables="equal", newdata="mean") assert isinstance(s, pl.DataFrame) assert s.shape[0] == 1 - s = avg_slopes(m, variables="equal", slope="eyex") + s = avg_slopes(impartiality_model, variables="equal", slope="eyex") assert isinstance(s, pl.DataFrame) assert s.shape[0] == 1 @@ -208,12 +277,12 @@ def test_titanic(): assert c.shape[0] == 1 -def test_python_section(): - p = avg_predictions(m, by="continent") +def test_python_section(impartiality_model): + p = avg_predictions(impartiality_model, by="continent") assert isinstance(p, pl.DataFrame) assert p.shape[0] == 4 - s = slopes(m, newdata="mean") + s = slopes(impartiality_model, newdata="mean") assert isinstance(s, pl.DataFrame) assert s.shape[0] == 5 diff --git a/uv.lock b/uv.lock index 560594c..b32d731 100644 --- a/uv.lock +++ b/uv.lock @@ -411,7 +411,7 @@ wheels = [ [[package]] name = "marginaleffects" -version = "0.0.13.1" +version = "0.0.14" source = { virtual = "." } dependencies = [ { name = "narwhals" },