Skip to content

Commit

Permalink
hypothesis shape
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentarelbundock committed Dec 26, 2023
1 parent a6c4f48 commit d43417c
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 10 deletions.
29 changes: 21 additions & 8 deletions marginaleffects/hypothesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def get_hypothesis(x, hypothesis):
else:
raise ValueError(msg)
out = lincom_multiply(x, hypmat.to_numpy())
out = out.with_columns(pl.Series(hypothesis.columns).alias("term"))
out = out.with_columns(pl.Series(hypmat.columns).alias("term"))
else:
raise ValueError(msg)
return out
Expand Down Expand Up @@ -112,7 +112,10 @@ def lincom_reference(x):
lab = [f"Row {i+1} - Row 1" for i in range(len(lincom))]
else:
lab = [f"{la} - {lab[0]}" for la in lab]
lincom = pl.DataFrame(lincom.T, schema=lab)
if lincom.shape[1] == 1:
lincom = pl.DataFrame(lincom, schema=lab)
else:
lincom = pl.DataFrame(lincom.T, schema=lab)
lincom = lincom.select(lab[1:])
return lincom

Expand All @@ -126,7 +129,10 @@ def lincom_revsequential(x):
lab = [f"{lab[i]} - {lab[i+1]}" for i in range(lincom.shape[1])]
for i in range(lincom.shape[1]):
lincom[i : i + 2, i] = [1, -1]
lincom = pl.DataFrame(lincom.T, schema=lab)
if lincom.shape[1] == 1:
lincom = pl.DataFrame(lincom, schema=lab)
else:
lincom = pl.DataFrame(lincom.T, schema=lab)
return lincom


Expand All @@ -139,7 +145,10 @@ def lincom_sequential(x):
lab = [f"{lab[i+1]} - {lab[i]}" for i in range(lincom.shape[1])]
for i in range(lincom.shape[1]):
lincom[i : i + 2, i] = [-1, 1]
lincom = pl.DataFrame(lincom.T, schema=lab)
if lincom.shape[1] == 1:
lincom = pl.DataFrame(lincom, schema=lab)
else:
lincom = pl.DataFrame(lincom.T, schema=lab)
return lincom


Expand All @@ -159,8 +168,10 @@ def lincom_revpairwise(x):
lab_col.append(f"Row {j+1} - Row {i+1}")
else:
lab_col.append(f"{lab_row[j]} - {lab_row[i]}")
lincom = np.hstack(mat)
lincom = pl.DataFrame(lincom.T, schema=lab_col)
if len(mat) == 1:
lincom = pl.DataFrame(mat[0], schema=lab_col)
else:
lincom = pl.DataFrame(np.hstack(mat).T, schema=lab_col)
return lincom


Expand All @@ -180,6 +191,8 @@ def lincom_pairwise(x):
lab_col.append(f"Row {i+1} - Row {j+1}")
else:
lab_col.append(f"{lab_row[i]} - {lab_row[j]}")
lincom = np.hstack(mat)
lincom = pl.DataFrame(lincom.T, schema=lab_col)
if len(mat) == 1:
lincom = pl.DataFrame(mat[0], schema=lab_col)
else:
lincom = pl.DataFrame(np.hstack(mat).T, schema=lab_col)
return lincom
24 changes: 22 additions & 2 deletions tests/test_jss.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,25 @@ def test_hypotheses():
assert c["contrast"][0] == 'mean(Democracy) / mean(Autocracy)'


def test_hypothesis_shape():
m = smf.logit(
"impartial ~ equal * democracy + continent",
data = dat.to_pandas()
).fit()

for h in ["reference", "revreference", "sequential", "revsequential", "pairwise", "revpairwise"]:
for b in ["democracy", "continent"]:
c = comparisons(m,
by = b,
variables = {"equal": [30, 90]},
hypothesis = h,
)
assert isinstance(c, pl.DataFrame)
if b == "democracy":
assert c.shape[0] == 1
else:
assert c.shape[0] > 1


def test_transform():
c1 = avg_comparisons(m, comparison = "lnor")
Expand All @@ -120,8 +139,9 @@ def test_misc():
# TODO: broken
cmp = comparisons(m,
by = "democracy",
variables = list(equal = c(30, 90)),
hypothesis = "pairwise")
variables = {"equal": [30, 90]},
hypothesis = "pairwise",
)
cmp

s = slopes(m, variables = "equal", newdata = datagrid(equal=[25, 50], model=m))
Expand Down

0 comments on commit d43417c

Please sign in to comment.