From d43417ce38a62c6235fe95a40ff421cbee3b6963 Mon Sep 17 00:00:00 2001 From: Vincent Arel-Bundock Date: Tue, 26 Dec 2023 09:10:41 -0500 Subject: [PATCH] hypothesis shape --- marginaleffects/hypothesis.py | 29 +++++++++++++++++++++-------- tests/test_jss.py | 24 ++++++++++++++++++++++-- 2 files changed, 43 insertions(+), 10 deletions(-) diff --git a/marginaleffects/hypothesis.py b/marginaleffects/hypothesis.py index a86e173..b97ef32 100644 --- a/marginaleffects/hypothesis.py +++ b/marginaleffects/hypothesis.py @@ -73,7 +73,7 @@ def get_hypothesis(x, hypothesis): else: raise ValueError(msg) out = lincom_multiply(x, hypmat.to_numpy()) - out = out.with_columns(pl.Series(hypothesis.columns).alias("term")) + out = out.with_columns(pl.Series(hypmat.columns).alias("term")) else: raise ValueError(msg) return out @@ -112,7 +112,10 @@ def lincom_reference(x): lab = [f"Row {i+1} - Row 1" for i in range(len(lincom))] else: lab = [f"{la} - {lab[0]}" for la in lab] - lincom = pl.DataFrame(lincom.T, schema=lab) + if lincom.shape[1] == 1: + lincom = pl.DataFrame(lincom, schema=lab) + else: + lincom = pl.DataFrame(lincom.T, schema=lab) lincom = lincom.select(lab[1:]) return lincom @@ -126,7 +129,10 @@ def lincom_revsequential(x): lab = [f"{lab[i]} - {lab[i+1]}" for i in range(lincom.shape[1])] for i in range(lincom.shape[1]): lincom[i : i + 2, i] = [1, -1] - lincom = pl.DataFrame(lincom.T, schema=lab) + if lincom.shape[1] == 1: + lincom = pl.DataFrame(lincom, schema=lab) + else: + lincom = pl.DataFrame(lincom.T, schema=lab) return lincom @@ -139,7 +145,10 @@ def lincom_sequential(x): lab = [f"{lab[i+1]} - {lab[i]}" for i in range(lincom.shape[1])] for i in range(lincom.shape[1]): lincom[i : i + 2, i] = [-1, 1] - lincom = pl.DataFrame(lincom.T, schema=lab) + if lincom.shape[1] == 1: + lincom = pl.DataFrame(lincom, schema=lab) + else: + lincom = pl.DataFrame(lincom.T, schema=lab) return lincom @@ -159,8 +168,10 @@ def lincom_revpairwise(x): lab_col.append(f"Row {j+1} - Row {i+1}") else: lab_col.append(f"{lab_row[j]} - {lab_row[i]}") - lincom = np.hstack(mat) - lincom = pl.DataFrame(lincom.T, schema=lab_col) + if len(mat) == 1: + lincom = pl.DataFrame(mat[0], schema=lab_col) + else: + lincom = pl.DataFrame(np.hstack(mat).T, schema=lab_col) return lincom @@ -180,6 +191,8 @@ def lincom_pairwise(x): lab_col.append(f"Row {i+1} - Row {j+1}") else: lab_col.append(f"{lab_row[i]} - {lab_row[j]}") - lincom = np.hstack(mat) - lincom = pl.DataFrame(lincom.T, schema=lab_col) + if len(mat) == 1: + lincom = pl.DataFrame(mat[0], schema=lab_col) + else: + lincom = pl.DataFrame(np.hstack(mat).T, schema=lab_col) return lincom diff --git a/tests/test_jss.py b/tests/test_jss.py index 4f3ccff..543e50f 100644 --- a/tests/test_jss.py +++ b/tests/test_jss.py @@ -94,6 +94,25 @@ def test_hypotheses(): assert c["contrast"][0] == 'mean(Democracy) / mean(Autocracy)' +def test_hypothesis_shape(): + m = smf.logit( + "impartial ~ equal * democracy + continent", + data = dat.to_pandas() + ).fit() + + for h in ["reference", "revreference", "sequential", "revsequential", "pairwise", "revpairwise"]: + for b in ["democracy", "continent"]: + c = comparisons(m, + by = b, + variables = {"equal": [30, 90]}, + hypothesis = h, + ) + assert isinstance(c, pl.DataFrame) + if b == "democracy": + assert c.shape[0] == 1 + else: + assert c.shape[0] > 1 + def test_transform(): c1 = avg_comparisons(m, comparison = "lnor") @@ -120,8 +139,9 @@ def test_misc(): # TODO: broken cmp = comparisons(m, by = "democracy", - variables = list(equal = c(30, 90)), - hypothesis = "pairwise") + variables = {"equal": [30, 90]}, + hypothesis = "pairwise", + ) cmp s = slopes(m, variables = "equal", newdata = datagrid(equal=[25, 50], model=m))