Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Joint hypotheses #94

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 42 additions & 24 deletions marginaleffects/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@

class MarginaleffectsDataFrame(pl.DataFrame):
def __init__(
self, data=None, by=None, conf_level=0.95, jacobian=None, newdata=None
self,
data=None,
by=None,
conf_level=0.95,
jacobian=None,
newdata=None,
mapping=None,
print_head="",
):
if isinstance(data, pl.DataFrame):
self._df = data._df
Expand All @@ -14,42 +21,52 @@ def __init__(
self.datagrid_explicit = newdata.datagrid_explicit
else:
self.datagrid_explicit = []

self.print_head = print_head

default_mapping = {
"term": "Term",
"contrast": "Contrast",
"estimate": "Estimate",
"std_error": "Std.Error",
"statistic": "z",
"p_value": "P(>|z|)",
"s_value": "S",
}
if mapping is None:
self.mapping = default_mapping
else:
for key, val in default_mapping.items():
if key not in mapping.keys():
mapping[key] = val
self.mapping = mapping

return
super().__init__(data)

def __str__(self):
mapping = {
"term": "Term",
"contrast": "Contrast",
"estimate": "Estimate",
"std_error": "Std.Error",
"statistic": "z",
"p_value": "P(>|z|)",
"s_value": "S",
}

if hasattr(self, "conf_level"):
mapping["conf_low"] = f"{(1 - self.conf_level) / 2 * 100:.1f}%"
mapping["conf_high"] = f"{(1 - (1 - self.conf_level) / 2) * 100:.1f}%"
self.mapping["conf_low"] = f"{(1 - self.conf_level) / 2 * 100:.1f}%"
self.mapping["conf_high"] = f"{(1 - (1 - self.conf_level) / 2) * 100:.1f}%"
else:
mapping["conf_low"] = "["
mapping["conf_high"] = "]"
self.mapping["conf_low"] = "["
self.mapping["conf_high"] = "]"

if hasattr(self, "by"):
if self.by is None:
valid = list(mapping.keys())
valid = list(self.mapping.keys())
elif self.by is True:
valid = list(mapping.keys())
valid = list(self.mapping.keys())
elif self.by is False:
valid = list(mapping.keys())
valid = list(self.mapping.keys())
elif isinstance(self.by, list):
valid = self.by + list(mapping.keys())
valid = self.by + list(self.mapping.keys())
elif isinstance(self.by, str):
valid = [self.by] + list(mapping.keys())
valid = [self.by] + list(self.mapping.keys())
else:
raise ValueError("by must be None or a string or a list of strings")
else:
valid = list(mapping.keys())
valid = list(self.mapping.keys())

valid = self.datagrid_explicit + valid
valid = [x for x in valid if x in self.columns]
Expand All @@ -58,11 +75,12 @@ def __str__(self):
valid = dict.fromkeys(valid)
valid = list(valid.keys())

mapping = {key: mapping[key] for key in mapping if key in valid}
tmp = self.select(valid).rename(mapping)
out = self.print_head
self.mapping = {key: self.mapping[key] for key in self.mapping if key in valid}
tmp = self.select(valid).rename(self.mapping)
for col in tmp.columns:
if tmp[col].dtype.is_numeric():
tmp = tmp.with_columns(pl.col(col).map_elements(lambda x: f"{x:.3g}"))
out = tmp.__str__()
out += tmp.__str__()
out = out + f"\n\nColumns: {', '.join(self.columns)}\n"
return out
18 changes: 16 additions & 2 deletions marginaleffects/hypotheses.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,18 @@
from .sanitize_model import sanitize_model
from .uncertainty import get_jacobian, get_se, get_z_p_ci
from .utils import sort_columns
from .hypotheses_joint import joint_hypotheses


def hypotheses(
model, hypothesis=None, conf_level=0.95, vcov=True, equivalence=None, eps_vcov=None
model,
hypothesis=None,
conf_level=0.95,
vcov=True,
equivalence=None,
eps_vcov=None,
joint=False,
joint_test="f",
):
"""
(Non-)Linear Tests for Null Hypotheses, Joint Hypotheses, Equivalence, Non Superiority, and Non Inferiority.
Expand Down Expand Up @@ -73,9 +81,15 @@ def hypotheses(
"""

model = sanitize_model(model)
V = sanitize_vcov(vcov, model)

if joint:
out = joint_hypotheses(
model, joint_index=joint, joint_test=joint_test, hypothesis=hypothesis
)
return out

hypothesis_null = sanitize_hypothesis_null(hypothesis)
V = sanitize_vcov(vcov, model)

# estimands
def fun(x):
Expand Down
109 changes: 109 additions & 0 deletions marginaleffects/hypotheses_joint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import numpy as np
import scipy.stats as stats
import polars as pl
import pandas as pd

from .sanity import sanitize_hypothesis_null
from .classes import MarginaleffectsDataFrame


def joint_hypotheses(obj, joint_index=None, joint_test="f", hypothesis=0):
assert joint_test in ["f", "chisq"], "`joint_test` must be `f` or `chisq`"

if isinstance(obj, pd.DataFrame):
obj = pl.DataFrame(obj)

# theta_hat: P x 1 vector of estimated parameters
theta_hat = obj.get_coef()

var_names = obj.get_variables_names()

if len(theta_hat) == len(var_names) + 1:
var_names = ["Intercept"] + var_names

if isinstance(joint_index, bool):
joint_index = range(len(theta_hat))
else:
if not isinstance(joint_index, list):
joint_index = [joint_index]
if all(isinstance(i, str) for i in joint_index):
joint_index = [
i for i in range(len(var_names)) if var_names[i] in joint_index
]
assert min(joint_index) >= 0 and max(joint_index) <= len(
var_names
), "`joint_index` contain invalid indices"

V_hat = obj.get_vcov()

# R: Q x P matrix for testing Q hypotheses on P parameters
# build R matrix based on joint_index
R = np.zeros((len(joint_index), len(theta_hat)))
for i in range(len(joint_index)):
R[i, joint_index[i]] = 1

if not isinstance(hypothesis, list):
if hypothesis is None:
hypothesis = 0
hypothesis = np.ones(R.shape[0]) * hypothesis
hypothesis = [sanitize_hypothesis_null(h) for h in hypothesis]

# Calculate the difference between R*theta_hat and r
diff = R @ theta_hat - hypothesis

# Calculate the inverse of R*(V_hat/n)*R'
inv = np.linalg.inv(R @ V_hat @ R.T)

# Calculate the Wald test statistic
if joint_test == "f":
wald_statistic = (
diff.T @ inv @ diff / R.shape[0]
) # Q is the number of rows in R
elif joint_test == "chisq":
wald_statistic = (
diff.T @ inv @ diff
) # Not normalized for chi-squared joint_test

# Degrees of freedom
df1 = R.shape[0] # Q
df2 = obj.get_df() # n - P

# Calculate the p-value
if joint_test == "f":
p_value = 1 - stats.f.cdf(wald_statistic, df1, df2)
elif joint_test == "chisq":
p_value = 1 - stats.chi2.cdf(wald_statistic, df1)
df2 = None

# Return the Wald joint_test statistic and p-value
mapping = {}
if joint_test == "f":
out = pl.DataFrame({"statistic": wald_statistic})
mapping["statistic"] = "F"
mapping["p_value"] = "P(>|F|)"
elif joint_test == "chisq":
out = pl.DataFrame({"statistic": wald_statistic})
mapping["statistic"] = "ChiSq"
mapping["p_value"] = "P(>|ChiSq|)"
out = out.with_columns(p_value=p_value)

# degrees of freedom print
if joint_test == "f":
out = out.with_columns(df1=df1, df2=df2)
out = out.cast({"df1": pl.Int64, "df2": pl.Int64})
mapping["df1"] = "Df 1"
mapping["df2"] = "Df 2"
elif joint_test == "chisq":
out = out.with_columns(df=df1)
out = out.cast({"df": pl.Int64})
mapping["df"] = "Df"

# Create the print_head string
print_head = "Joint hypothesis test:\n"
for i, j in enumerate(joint_index):
print_head += var_names[j] + f" = {hypothesis[i]}\n"
print_head += "\n"

out = MarginaleffectsDataFrame(out, mapping=mapping, print_head=print_head)

return out
16 changes: 9 additions & 7 deletions marginaleffects/model_statsmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,12 @@ def get_variables_names(self, variables=None, newdata=None):
if variables is None:
formula = self.formula
columns = self.modeldata.columns
variables = list(
{
var
for var in columns
if re.search(rf"\b{re.escape(var)}\b", formula.split("~")[1])
}
)
order = {}
for var in columns:
match = re.search(rf"\b{re.escape(var)}\b", formula.split("~")[1])
if match:
order[var] = match.start()
variables = sorted(order, key=lambda i: order[i])

if isinstance(variables, (str, dict)):
variables = [variables] if isinstance(variables, str) else variables
Expand Down Expand Up @@ -108,3 +107,6 @@ def get_predict(self, params, newdata: pl.DataFrame):

def get_formula(self):
return self.model.model.formula

def get_df(self):
return self.model.df_resid
2 changes: 2 additions & 0 deletions tests/r/test_hypotheses_joint_01.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"statistic","p.value","df1","df2"
7.72769631208158,0.00212618448310864,2,28
2 changes: 2 additions & 0 deletions tests/r/test_hypotheses_joint_02.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"statistic","p.value","df"
6.77031659722067,0.0338722799864077,2
2 changes: 2 additions & 0 deletions tests/r/test_hypotheses_joint_03.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"statistic","p.value","df1","df2"
7.72769631208158,0.00212618448310864,2,28
2 changes: 2 additions & 0 deletions tests/r/test_hypotheses_joint_04.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"statistic","p.value","df1","df2"
811073.258088452,0,3,28
2 changes: 2 additions & 0 deletions tests/r/test_hypotheses_joint_05.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"statistic","p.value","df1","df2"
8520340.17301277,0,3,28
2 changes: 2 additions & 0 deletions tests/r/test_hypotheses_joint_06.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"statistic","p.value","df1","df2"
11.1162545397227,5.00882326577301e-05,3,29
2 changes: 2 additions & 0 deletions tests/r/test_hypotheses_joint_07.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"statistic","p.value","df1","df2"
12.8439392466833,0.000101239587237067,2,29
59 changes: 59 additions & 0 deletions tests/test_hypotheses_joint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import polars as pl
import statsmodels.formula.api as smf
from polars.testing import assert_frame_equal

from marginaleffects import *

mtcars = pl.read_csv(
"https://vincentarelbundock.github.io/Rdatasets/csv/datasets/mtcars.csv"
)

mod = smf.ols("am ~ hp + wt + disp", data=mtcars).fit()

mod_without_intercept = smf.ols("am ~ 0 + hp + wt + disp", data=mtcars).fit()


def test_hypotheses_joint():
hypo_py = hypotheses(mod, joint=["hp", "wt"])
hypo_r = pl.read_csv("tests/r/test_hypotheses_joint_01.csv").rename(
{"p.value": "p_value"}
)
assert_frame_equal(hypo_py, hypo_r)

hypo_py = hypotheses(mod, joint=["hp", "disp"], joint_test="chisq")
hypo_r = pl.read_csv("tests/r/test_hypotheses_joint_02.csv").rename(
{"p.value": "p_value"}
)
assert_frame_equal(hypo_py, hypo_r)

hypo_py = hypotheses(mod, joint=[1, 2])
hypo_r = pl.read_csv("tests/r/test_hypotheses_joint_03.csv").rename(
{"p.value": "p_value"}
)
assert_frame_equal(hypo_py, hypo_r)

hypo_py = hypotheses(mod, joint=[0, 1, 2], hypothesis=[1, 2, 3])
hypo_r = pl.read_csv("tests/r/test_hypotheses_joint_04.csv").rename(
{"p.value": "p_value"}
)
hypo_r = hypo_r.cast({"p_value": pl.Float64})
assert_frame_equal(hypo_py, hypo_r, check_exact=False, atol=0.0001)

hypo_py = hypotheses(mod, joint=["Intercept", "disp", "wt"], hypothesis=4)
hypo_r = pl.read_csv("tests/r/test_hypotheses_joint_05.csv").rename(
{"p.value": "p_value"}
)
hypo_r = hypo_r.cast({"p_value": pl.Float64})
assert_frame_equal(hypo_py, hypo_r, check_exact=False, atol=0.0001)

hypo_py = hypotheses(mod_without_intercept, joint=[0, 1, 2])
hypo_r = pl.read_csv("tests/r/test_hypotheses_joint_06.csv").rename(
{"p.value": "p_value"}
)
assert_frame_equal(hypo_py, hypo_r)

hypo_py = hypotheses(mod_without_intercept, joint=["hp", "wt"])
hypo_r = pl.read_csv("tests/r/test_hypotheses_joint_07.csv").rename(
{"p.value": "p_value"}
)
assert_frame_equal(hypo_py, hypo_r)
Loading