Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
artiom-matvei committed Nov 19, 2024
1 parent f0a5f30 commit 24f7d1c
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 33 deletions.
30 changes: 30 additions & 0 deletions marginaleffects/uncertainty.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,36 @@ def get_jacobian(func, coefs, eps_vcov=None):


def get_se(J, V):
# J are different in python versus R
'''Python
array([[ 3.26007929e-03, -1.11398030e+05, -1.07005746e+07,
7.82170746e+07, 2.78495302e+07, 1.34757804e+00],
[ 3.00860803e+05, 1.11398009e+05, 3.15899457e+07,
4.00666771e+07, 1.42659178e+07, -5.02687187e+00],
[-3.00860806e+05, 2.03275902e-02, -2.08893711e+07,
-1.18283752e+08, -4.21154480e+07, 3.67929389e+00],
...,
[ 1.94566192e-02, 6.31280874e+04, 6.06390551e+06,
6.95824154e+07, 2.47751212e+07, 4.27175339e-01],
[ 1.19234156e+05, -6.31280929e+04, 1.25194134e+07,
-2.27053711e+07, -8.08434379e+06, -1.51166669e+00],
[-1.19234176e+05, 5.47722901e-03, -1.85833189e+07,
-4.68770444e+07, -1.66907774e+07, 1.08449133e+00]])
'''
'''R
[,1] [,2] [,3] [,4] [,5] [,6]
[1,] 4.211085e-03 0.1129098874 7.622065e-01 6.449953e-03 2.390449e-01 1.167441e+00
[2,] 2.317994e-03 -0.0002889116 4.311471e-01 1.240511e-02 4.599770e-01 2.307350e+00
[3,] -1.458913e-02 -0.7474941310 -2.844880e+00 2.018562e-02 7.334876e-01 3.936196e+00
[4,] -1.768363e-02 -0.7516309380 -3.412940e+00 1.977302e-02 5.853414e-01 3.816193e+00
[5,] -6.148701e-03 -0.3694016210 -1.168253e+00 1.883557e-02 6.798748e-01 3.578758e+00
[6,] 3.851085e-03 0.0970890213 6.970463e-01 6.842971e-03 2.519078e-01 1.238578e+00
[7,] -1.807982e-02 -0.8520217380 -3.525564e+00 2.090703e-02 7.173019e-01 4.076872e+00
[8,] -1.324527e-02 -0.5110488046 -2.556336e+00 1.326720e-02 2.654043e-01 2.560569e+00
[9,] 7.328301e-03 0.1888463791 1.392377e+00 1.094683e-02 4.361391e-01 2.079898e+00
[10,] -4.511196e-03 -0.2639279320 -8.390823e-01 1.642779e-02 5.682850e-01 3.055570e+00
...
'''
se = np.sqrt(np.sum((J @ V) * J, axis=1))
return se

Expand Down
84 changes: 51 additions & 33 deletions tests/statsmodels/test_statsmodels_mnlogit.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,41 +82,59 @@ def test_predictions_02():
assert_series_equal(known["estimate"], unknown["estimate"], rtol=1e-1)


# def test_comparisons_01():
# penguins_clean = penguins_with_nulls.select(
# ["island", "bill_length_mm", "flipper_length_mm"]
# ).drop_nulls()

# # Define island categories and create a mapping
# island_categories = ["Biscoe", "Dream", "Torgersen"]
# island_mapping = {island: code for code, island in enumerate(island_categories)}

# # Map 'island' to integer codes
# penguins_clean = penguins_clean.with_columns(
# pl.col("island").replace_strict(island_mapping)
# )
def test_comparisons_01():
penguins_clean = penguins_with_nulls.select(
["island", "bill_length_mm", "flipper_length_mm"]
).drop_nulls()

# mod = smf.mnlogit(
# "island ~ bill_length_mm + flipper_length_mm", data=penguins_clean
# ).fit()
# unknown = (
# comparisons(mod)
# .with_columns(pl.col("group").replace(island_mapping))
# .sort(["rowid", "term", "group"])
# )
# known = (
# pl.read_csv("tests/r/test_statsmodels_mnlogit_comparisons_01.csv")
# .with_columns(pl.col("group").replace(island_mapping))
# .sort(["rowid", "term", "group"])
# )
# print(known)
# print(unknown)
# assert_series_equal(known["estimate"].head(), unknown["estimate"].head(), rtol=2)

# unknown = comparisons(mod)
# known = pl.read_csv("tests/r/test_statsmodels_mnlogit_comparisons_01.csv")
# assert_series_equal(known["estimate"], unknown["estimate"], rtol=1e-2)
# Define island categories and create a mapping
island_categories = ["Biscoe", "Dream", "Torgersen"]
island_mapping = {island: code for code, island in enumerate(island_categories)}

# Map 'island' to integer codes
penguins_clean = penguins_clean.with_columns(
pl.col("island").replace_strict(island_mapping)
)

mod = smf.mnlogit(
"island ~ bill_length_mm + flipper_length_mm", data=penguins_clean
).fit()
unknown = (
comparisons(mod)
.with_columns(pl.col("group").replace(island_mapping))
.sort(["rowid", "term", "group"])
)
known = (
pl.read_csv("tests/r/test_statsmodels_mnlogit_comparisons_01.csv")
.with_columns(pl.col("group").replace(island_mapping))
.sort(["rowid", "term", "group"])
)
new_column_names = {col: col.replace('.', '_') for col in known.columns}
known = known.rename(new_column_names)
print(known.head())
print(unknown.head())
print(compare_polars_tables(known, unknown, index=0))
assert_series_equal(known["estimate"].head(), unknown["estimate"].head(), rtol=2)

unknown = comparisons(mod)
known = pl.read_csv("tests/r/test_statsmodels_mnlogit_comparisons_01.csv")
assert_series_equal(known["estimate"], unknown["estimate"], rtol=1e-2)

# Function to print visual comparison
def compare_polars_tables(known, unknown, index=0):
headers = ["Column", "Table known Value", "Table unknown Value", "Difference"]
row_format = "{:<25} {:<25} {:<25} {:<15}"

print(row_format.format(*headers))
print("-" * 60)

for col in known.columns:
val1 = known[col][index]
val2 = unknown[col][index]
difference = "Yes" if val1 != val2 else "No"

print(row_format.format(col, val1, val2, difference))
print(row_format.format("index", index, index, "No"))

# # @pytest.mark.skip(reason="statsmodels vcov is weird")
# def test_comparisons_02():
Expand Down

0 comments on commit 24f7d1c

Please sign in to comment.