diff --git a/NEWS.md b/NEWS.md index 46470c1..5b722af 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,8 @@ +# dev + +* New `eps_vcov` argument to control the step size in the computation of the Jacobian used for standard errors. +* Refactor and several bug fixes in the `plot_*()` functions. + # 0.0.6 * `hypothesis` accepts a float or integer to specify a different null hypothesis. diff --git a/marginaleffects/comparisons.py b/marginaleffects/comparisons.py index 3402807..86fd8e4 100644 --- a/marginaleffects/comparisons.py +++ b/marginaleffects/comparisons.py @@ -35,6 +35,7 @@ def comparisons( equivalence=None, transform=None, eps=1e-4, + eps_vcov=None, ): """ `comparisons()` and `avg_comparisons()` are functions for predicting the outcome variable at different regressor values and comparing those predictions by computing a difference, ratio, or some other function. These functions can return many quantities of interest, such as contrasts, differences, risk ratios, changes in log odds, lift, slopes, elasticities, etc. @@ -275,7 +276,7 @@ def outer(x): out = outer(model.coef) if vcov is not None and vcov is not False: - J = get_jacobian(func=outer, coefs=model.coef) + J = get_jacobian(func=outer, coefs=model.coef, eps_vcov=eps_vcov) se = get_se(J, V) out = out.with_columns(pl.Series(se).alias("std_error")) out = get_z_p_ci( diff --git a/marginaleffects/hypotheses.py b/marginaleffects/hypotheses.py index bb5b881..035e045 100644 --- a/marginaleffects/hypotheses.py +++ b/marginaleffects/hypotheses.py @@ -9,7 +9,9 @@ from .utils import sort_columns -def hypotheses(model, hypothesis=None, conf_level=0.95, vcov=True, equivalence=None): +def hypotheses( + model, hypothesis=None, conf_level=0.95, vcov=True, equivalence=None, eps_vcov=None +): """ (Non-)Linear Tests for Null Hypotheses, Joint Hypotheses, Equivalence, Non Superiority, and Non Inferiority. @@ -83,7 +85,7 @@ def fun(x): out = fun(model.coef) if vcov is not None: - J = get_jacobian(fun, model.coef) + J = get_jacobian(fun, model.coef, eps_vcov=eps_vcov) se = get_se(J, V) out = out.with_columns(pl.Series(se).alias("std_error")) out = get_z_p_ci( diff --git a/marginaleffects/plot_common.py b/marginaleffects/plot_common.py index 06894b1..09d1a84 100644 --- a/marginaleffects/plot_common.py +++ b/marginaleffects/plot_common.py @@ -58,24 +58,14 @@ def dt_on_condition(model, condition): modeldata[key], [0, 25, 50, 75, 100], method="midpoint" ).tolist() - elif variable_type == "boolean" or variable_type == "character": - to_datagrid[key] = modeldata[key].unique().to_list() + elif variable_type in ["boolean", "character", "binary"]: + to_datagrid[key] = modeldata[key].unique().sort().to_list() assert ( len(to_datagrid[key]) <= 10 ), f"Character type variables of more than 10 unique values are not supported. {key} variable has {len(to_datagrid[key])} unique values." - dt_code = "datagrid(newdata=modeldata" - for key, value in to_datagrid.items(): - dt_code += ", " + key + "=" - if isinstance(value, str): - dt_code += "'" + value + "'" - else: - dt_code += str(value) - dt_code += ")" - - # TODO: this is weird. I'd prefer someting more standard than evaluating text - exec("global dt; dt = " + dt_code) - + to_datagrid["newdata"] = modeldata + dt = datagrid(**to_datagrid) return dt # noqa: F821 @@ -99,7 +89,7 @@ def plotter(dt, x_name, x_type, fig=None, axe=None, label=None, color=None): plot_obj.fill_between(x, y_low, y_high, color=color, alpha=0.2) plot_obj.plot(x, y, color=color, label=label) - elif x_type == "character" or x_type == "boolean": + elif x_type in ["character", "binary", "boolean"]: y_low = np.absolute(y - y_low) y_high = np.absolute(y_high - y) if color is None: diff --git a/marginaleffects/plot_comparisons.py b/marginaleffects/plot_comparisons.py index e3646d9..7a6ea36 100644 --- a/marginaleffects/plot_comparisons.py +++ b/marginaleffects/plot_comparisons.py @@ -85,64 +85,42 @@ def plot_comparisons( wts is not None and not by ), "The `wts` argument requires a `by` argument." - if by: - if newdata is not None: - dt = comparisons( - model, - variables=variables, - newdata=newdata, - comparison=comparison, - vcov=vcov, - conf_level=conf_level, - by=by, - wts=wts, - hypothesis=hypothesis, - equivalence=equivalence, - transform=transform, - eps=eps, - ) - else: - dt = comparisons( - model, - variables=variables, - comparison=comparison, - vcov=vcov, - conf_level=conf_level, - by=by, - wts=wts, - hypothesis=hypothesis, - equivalence=equivalence, - transform=transform, - eps=eps, - ) - - var_list = [by] if isinstance(by, str) else by - - elif condition is not None: - dt_condition = dt_on_condition(model, condition) - if isinstance(condition, str): - var_list = [condition] - elif isinstance(condition, list): - var_list = condition - elif isinstance(condition, dict): - var_list = list(condition.keys()) - dt = comparisons( - model, - variables=variables, - newdata=dt_condition, - comparison=comparison, - vcov=vcov, - conf_level=conf_level, - by=var_list, - wts=wts, - hypothesis=hypothesis, - equivalence=equivalence, - transform=transform, - eps=eps, - ) - - dt = dt.drop_nulls(var_list[0]) - dt = dt.sort(var_list[0]) + if condition is not None: + newdata = dt_on_condition(model, condition) + + dt = comparisons( + model, + variables=variables, + newdata=newdata, + comparison=comparison, + vcov=vcov, + conf_level=conf_level, + by=by, + wts=wts, + hypothesis=hypothesis, + equivalence=equivalence, + transform=transform, + eps=eps, + ) + + if not draw: + return dt + + if isinstance(condition, str): + var_list = [condition] + elif isinstance(condition, list): + var_list = condition + elif isinstance(condition, dict): + var_list = list(condition.keys()) + elif isinstance(by, str): + var_list = [by] + elif isinstance(by, list): + var_list = by + elif isinstance(by, dict): + var_list = list(by.keys()) + + # not sure why these get appended + var_list = [x for x in var_list if x not in ["newdata", "model"]] if not draw: return dt diff --git a/marginaleffects/plot_predictions.py b/marginaleffects/plot_predictions.py index c3c3e29..1d26939 100644 --- a/marginaleffects/plot_predictions.py +++ b/marginaleffects/plot_predictions.py @@ -68,58 +68,44 @@ def plot_predictions( not by and newdata is not None ), "The `newdata` argument requires a `by` argument." - assert (condition is None and by) or ( - condition is not None and not by - ), "One of the `condition` and `by` arguments must be supplied, but not both." - assert not ( wts is not None and not by ), "The `wts` argument requires a `by` argument." - if by: - if newdata is not None: - dt = predictions( - model, - by=by, - newdata=newdata, - conf_level=conf_level, - vcov=vcov, - transform=transform, - wts=wts, - ) - else: - dt = predictions( - model, - by=by, - conf_level=conf_level, - vcov=vcov, - transform=transform, - wts=wts, - ) - - var_list = [by] if isinstance(by, str) else by + assert not ( + condition is None and by is None + ), "One of the `condition` and `by` arguments must be supplied, but not both." if condition is not None: - dt_condition = dt_on_condition(model, condition) - if isinstance(condition, str): - var_list = [condition] - elif isinstance(condition, list): - var_list = condition - elif isinstance(condition, dict): - var_list = list(condition.keys()) - dt = predictions( - model, - by=var_list, - newdata=dt_condition, - conf_level=conf_level, - vcov=vcov, - transform=transform, - ) - - dt = dt.drop_nulls(var_list[0]) - dt = dt.sort(var_list[0]) + newdata = dt_on_condition(model, condition) + + dt = predictions( + model, + by=by, + newdata=newdata, + conf_level=conf_level, + vcov=vcov, + transform=transform, + wts=wts, + ) if not draw: return dt - return plot_common(dt, model.response_name, var_list) + if isinstance(condition, str): + var_list = [condition] + elif isinstance(condition, list): + var_list = condition + elif isinstance(condition, dict): + var_list = list(condition.keys()) + elif isinstance(by, str): + var_list = [by] + elif isinstance(by, list): + var_list = by + elif isinstance(by, dict): + var_list = list(by.keys()) + + # not sure why these get appended + var_list = [x for x in var_list if x not in ["newdata", "model"]] + + return plot_common(dt, model.response_name, var_list=var_list) diff --git a/marginaleffects/plot_slopes.py b/marginaleffects/plot_slopes.py index d48b3b6..6a20a37 100644 --- a/marginaleffects/plot_slopes.py +++ b/marginaleffects/plot_slopes.py @@ -14,6 +14,8 @@ def plot_slopes( by=False, wts=None, draw=True, + eps=1e-4, + eps_vcov=None, ): """ Plot slopes on the y-axis against values of one or more predictors (x-axis, colors/shapes, and facets). @@ -77,54 +79,51 @@ def plot_slopes( wts is not None and not by ), "The `wts` argument requires a `by` argument." - if by: - if newdata is not None: - dt = slopes( - model, - variables=variables, - newdata=newdata, - slope=slope, - vcov=vcov, - conf_level=conf_level, - by=by, - wts=wts, - ) - else: - dt = slopes( - model, - variables=variables, - slope=slope, - vcov=vcov, - conf_level=conf_level, - by=by, - wts=wts, - ) - - var_list = [by] if isinstance(by, str) else by - - elif condition is not None: - dt_condition = dt_on_condition(model, condition) - if isinstance(condition, str): - var_list = [condition] - elif isinstance(condition, list): - var_list = condition - elif isinstance(condition, dict): - var_list = list(condition.keys()) - dt = slopes( - model, - variables=variables, - newdata=dt_condition, - slope=slope, - vcov=vcov, - conf_level=conf_level, - by=var_list, - wts=wts, - ) - - dt = dt.drop_nulls(var_list[0]) - dt = dt.sort(var_list[0]) + assert not ( + not by and newdata is not None + ), "The `newdata` argument requires a `by` argument." + + assert not ( + wts is not None and not by + ), "The `wts` argument requires a `by` argument." + + assert not ( + condition is None and by is None + ), "One of the `condition` and `by` arguments must be supplied, but not both." + + if condition is not None: + newdata = dt_on_condition(model, condition) + + dt = slopes( + model, + variables=variables, + newdata=newdata, + slope=slope, + vcov=vcov, + conf_level=conf_level, + by=by, + wts=wts, + eps=eps, + eps_vcov=eps_vcov, + ) if not draw: return dt + if isinstance(condition, str): + var_list = [condition] + elif isinstance(condition, list): + var_list = condition + elif isinstance(condition, dict): + var_list = list(condition.keys()) + elif isinstance(by, str): + var_list = [by] + elif isinstance(by, list): + var_list = by + elif isinstance(by, dict): + var_list = list(by.keys()) + + # not sure why these get appended + var_list = [x for x in var_list if x not in ["newdata", "model"]] + return plot_common(dt, "Slope", var_list) diff --git a/marginaleffects/predictions.py b/marginaleffects/predictions.py index 76e22c5..514ad7f 100644 --- a/marginaleffects/predictions.py +++ b/marginaleffects/predictions.py @@ -29,6 +29,7 @@ def predictions( equivalence=None, transform=None, wts=None, + eps_vcov=None, ): """ Predict outcomes using a fitted model on a specified scale for given combinations of values @@ -161,8 +162,8 @@ def inner(x): out = inner(model.get_coef()) - if vcov is not None: - J = get_jacobian(inner, model.get_coef()) + if V is not None: + J = get_jacobian(inner, model.get_coef(), eps_vcov=eps_vcov) se = get_se(J, V) out = out.with_columns(pl.Series(se).alias("std_error")) out = get_z_p_ci( diff --git a/marginaleffects/slopes.py b/marginaleffects/slopes.py index 3bbb58f..88a395c 100644 --- a/marginaleffects/slopes.py +++ b/marginaleffects/slopes.py @@ -13,6 +13,7 @@ def slopes( equivalence=None, wts=None, eps=1e-4, + eps_vcov=None, ): assert isinstance(eps, float) @@ -31,6 +32,7 @@ def slopes( equivalence=equivalence, wts=wts, eps=eps, + eps_vcov=eps_vcov, ) return out @@ -47,6 +49,7 @@ def avg_slopes( hypothesis=None, equivalence=None, eps=1e-4, + eps_vcov=None, ): if slope not in ["dydx", "eyex", "eydx", "dyex"]: raise ValueError("slope must be one of 'dydx', 'eyex', 'eydx', 'dyex'") @@ -62,6 +65,7 @@ def avg_slopes( hypothesis=hypothesis, equivalence=equivalence, eps=eps, + eps_vcov=eps_vcov, ) return out diff --git a/marginaleffects/uncertainty.py b/marginaleffects/uncertainty.py index 2ded36f..15b1616 100644 --- a/marginaleffects/uncertainty.py +++ b/marginaleffects/uncertainty.py @@ -5,7 +5,7 @@ import scipy.stats as stats -def get_jacobian(func, coefs): +def get_jacobian(func, coefs, eps_vcov=None): # forward finite difference (faster) if coefs.ndim == 2: if isinstance(coefs, np.ndarray): @@ -15,7 +15,10 @@ def get_jacobian(func, coefs): baseline = func(coefs)["estimate"].to_numpy() jac = np.empty((baseline.shape[0], len(coefs_flat)), dtype=np.float64) for i, xi in enumerate(coefs_flat): - h = max(abs(xi) * np.sqrt(np.finfo(float).eps), 1e-10) + if eps_vcov is not None: + h = eps_vcov + else: + h = max(abs(xi) * np.sqrt(np.finfo(float).eps), 1e-10) dx = np.copy(coefs_flat) dx[i] = dx[i] + h tmp = dx.reshape(coefs.shape) @@ -25,7 +28,10 @@ def get_jacobian(func, coefs): baseline = func(coefs)["estimate"].to_numpy() jac = np.empty((baseline.shape[0], len(coefs)), dtype=np.float64) for i, xi in enumerate(coefs): - h = max(abs(xi) * np.sqrt(np.finfo(float).eps), 1e-10) + if eps_vcov is not None: + h = eps_vcov + else: + h = max(abs(xi) * np.sqrt(np.finfo(float).eps), 1e-10) dx = np.copy(coefs) dx[i] = dx[i] + h jac[:, i] = (func(dx)["estimate"].to_numpy() - baseline) / h diff --git a/tests/images/plot_comparisons/Figure_2.png b/tests/images/plot_comparisons/Figure_2.png index 4449d16..5d7dd6a 100644 Binary files a/tests/images/plot_comparisons/Figure_2.png and b/tests/images/plot_comparisons/Figure_2.png differ diff --git a/tests/images/plot_comparisons/Figure_3.png b/tests/images/plot_comparisons/Figure_3.png index d4ff0f5..ff9a4e5 100644 Binary files a/tests/images/plot_comparisons/Figure_3.png and b/tests/images/plot_comparisons/Figure_3.png differ diff --git a/tests/images/plot_comparisons/Figure_4.png b/tests/images/plot_comparisons/Figure_4.png index 0a03e78..6248853 100644 Binary files a/tests/images/plot_comparisons/Figure_4.png and b/tests/images/plot_comparisons/Figure_4.png differ diff --git a/tests/images/plot_comparisons/Figure_5.png b/tests/images/plot_comparisons/Figure_5.png index 49766b0..137a6ff 100644 Binary files a/tests/images/plot_comparisons/Figure_5.png and b/tests/images/plot_comparisons/Figure_5.png differ diff --git a/tests/images/plot_comparisons/Figure_6.png b/tests/images/plot_comparisons/Figure_6.png index bd47a73..7ec25df 100644 Binary files a/tests/images/plot_comparisons/Figure_6.png and b/tests/images/plot_comparisons/Figure_6.png differ diff --git a/tests/images/plot_predictions/Figure_2.png b/tests/images/plot_predictions/Figure_2.png deleted file mode 100644 index 3967fd4..0000000 Binary files a/tests/images/plot_predictions/Figure_2.png and /dev/null differ diff --git a/tests/images/plot_predictions/Figure_3.png b/tests/images/plot_predictions/Figure_3.png deleted file mode 100644 index f019092..0000000 Binary files a/tests/images/plot_predictions/Figure_3.png and /dev/null differ diff --git a/tests/images/plot_predictions/Figure_4.png b/tests/images/plot_predictions/Figure_4.png deleted file mode 100644 index 0fba868..0000000 Binary files a/tests/images/plot_predictions/Figure_4.png and /dev/null differ diff --git a/tests/images/plot_predictions/Figure_5.png b/tests/images/plot_predictions/Figure_5.png deleted file mode 100644 index 40c9ce6..0000000 Binary files a/tests/images/plot_predictions/Figure_5.png and /dev/null differ diff --git a/tests/images/plot_predictions/Figure_6.png b/tests/images/plot_predictions/Figure_6.png deleted file mode 100644 index 143a487..0000000 Binary files a/tests/images/plot_predictions/Figure_6.png and /dev/null differ diff --git a/tests/images/plot_predictions/Figure_1.png b/tests/images/plot_predictions/by_01.png similarity index 62% rename from tests/images/plot_predictions/Figure_1.png rename to tests/images/plot_predictions/by_01.png index d2cbd68..665a538 100644 Binary files a/tests/images/plot_predictions/Figure_1.png and b/tests/images/plot_predictions/by_01.png differ diff --git a/tests/images/plot_predictions/by_02.png b/tests/images/plot_predictions/by_02.png new file mode 100644 index 0000000..6407398 Binary files /dev/null and b/tests/images/plot_predictions/by_02.png differ diff --git a/tests/images/plot_predictions/condition_01.png b/tests/images/plot_predictions/condition_01.png new file mode 100644 index 0000000..62e049b Binary files /dev/null and b/tests/images/plot_predictions/condition_01.png differ diff --git a/tests/images/plot_predictions/condition_02.png b/tests/images/plot_predictions/condition_02.png new file mode 100644 index 0000000..610a80c Binary files /dev/null and b/tests/images/plot_predictions/condition_02.png differ diff --git a/tests/images/plot_predictions/issue_57_01.png b/tests/images/plot_predictions/issue_57_01.png new file mode 100644 index 0000000..08f322c Binary files /dev/null and b/tests/images/plot_predictions/issue_57_01.png differ diff --git a/tests/images/plot_predictions/issue_57_02.png b/tests/images/plot_predictions/issue_57_02.png new file mode 100644 index 0000000..d54a594 Binary files /dev/null and b/tests/images/plot_predictions/issue_57_02.png differ diff --git a/tests/images/plot_predictions/issue_57_03.png b/tests/images/plot_predictions/issue_57_03.png new file mode 100644 index 0000000..8370950 Binary files /dev/null and b/tests/images/plot_predictions/issue_57_03.png differ diff --git a/tests/images/plot_predictions/issue_57_04.png b/tests/images/plot_predictions/issue_57_04.png new file mode 100644 index 0000000..f5a988d Binary files /dev/null and b/tests/images/plot_predictions/issue_57_04.png differ diff --git a/tests/images/plot_slopes/Figure_1.png b/tests/images/plot_slopes/Figure_1.png deleted file mode 100644 index 6615a88..0000000 Binary files a/tests/images/plot_slopes/Figure_1.png and /dev/null differ diff --git a/tests/images/plot_slopes/Figure_2.png b/tests/images/plot_slopes/Figure_2.png deleted file mode 100644 index eb7c111..0000000 Binary files a/tests/images/plot_slopes/Figure_2.png and /dev/null differ diff --git a/tests/images/plot_slopes/Figure_3.png b/tests/images/plot_slopes/Figure_3.png deleted file mode 100644 index 3e971cb..0000000 Binary files a/tests/images/plot_slopes/Figure_3.png and /dev/null differ diff --git a/tests/images/plot_slopes/Figure_4.png b/tests/images/plot_slopes/Figure_4.png deleted file mode 100644 index 32074f9..0000000 Binary files a/tests/images/plot_slopes/Figure_4.png and /dev/null differ diff --git a/tests/images/plot_slopes/Figure_5.png b/tests/images/plot_slopes/Figure_5.png deleted file mode 100644 index c539c16..0000000 Binary files a/tests/images/plot_slopes/Figure_5.png and /dev/null differ diff --git a/tests/images/plot_slopes/Figure_6.png b/tests/images/plot_slopes/Figure_6.png deleted file mode 100644 index 233bedc..0000000 Binary files a/tests/images/plot_slopes/Figure_6.png and /dev/null differ diff --git a/tests/images/plot_slopes/by_01.png b/tests/images/plot_slopes/by_01.png new file mode 100644 index 0000000..9eebec4 Binary files /dev/null and b/tests/images/plot_slopes/by_01.png differ diff --git a/tests/images/plot_slopes/by_02.png b/tests/images/plot_slopes/by_02.png new file mode 100644 index 0000000..a8990a4 Binary files /dev/null and b/tests/images/plot_slopes/by_02.png differ diff --git a/tests/images/plot_slopes/condition_01.png b/tests/images/plot_slopes/condition_01.png new file mode 100644 index 0000000..8f2773c Binary files /dev/null and b/tests/images/plot_slopes/condition_01.png differ diff --git a/tests/images/plot_slopes/condition_02.png b/tests/images/plot_slopes/condition_02.png new file mode 100644 index 0000000..dc5e909 Binary files /dev/null and b/tests/images/plot_slopes/condition_02.png differ diff --git a/tests/images/plot_slopes/condition_03.png b/tests/images/plot_slopes/condition_03.png new file mode 100644 index 0000000..3445b18 Binary files /dev/null and b/tests/images/plot_slopes/condition_03.png differ diff --git a/tests/images/plot_slopes/condition_04.png b/tests/images/plot_slopes/condition_04.png new file mode 100644 index 0000000..1befba7 Binary files /dev/null and b/tests/images/plot_slopes/condition_04.png differ diff --git a/tests/test_plot_comparisons.py b/tests/test_plot_comparisons.py index 7cfdaf5..91068b4 100644 --- a/tests/test_plot_comparisons.py +++ b/tests/test_plot_comparisons.py @@ -3,106 +3,48 @@ import polars as pl import pytest import statsmodels.formula.api as smf -from matplotlib.testing.compare import compare_images from marginaleffects import * from marginaleffects.plot_comparisons import * from .utilities import * + + df = pl.read_csv( "https://vincentarelbundock.github.io/Rdatasets/csv/palmerpenguins/penguins.csv", - null_values="NA", -).drop_nulls() + null_values="NA",) \ + .drop_nulls() \ + .sort(pl.col("species")) mod = smf.ols( "body_mass_g ~ flipper_length_mm * species * bill_length_mm + island", df.to_pandas(), ).fit() -# @pytest.mark.skip(reason="statsmodels vcov is weird") def test_plot_comparisons(): - tolerance = 50 - - baseline_path = "./tests/images/plot_comparisons/" - - result_path = "./tests/images/.tmp_plot_comparisons/" - if os.path.isdir(result_path): - for root, dirs, files in os.walk(result_path): - for fname in files: - os.remove(os.path.join(root, fname)) - os.rmdir(result_path) - os.mkdir(result_path) - fig = plot_comparisons(mod, variables="species", by="island") - fig.savefig(result_path + "Figure_1.png") - assert ( - compare_images( - baseline_path + "Figure_1.png", result_path + "Figure_1.png", tolerance - ) - is None - ) - os.remove(result_path + "Figure_1.png") + assert assert_image(fig, "Figure_1", "plot_comparisons") is None fig = plot_comparisons( mod, variables="bill_length_mm", - newdata=datagrid(mod, bill_length_mm=[37, 39]), by="island", ) - fig.savefig(result_path + "Figure_2.png") - assert ( - compare_images( - baseline_path + "Figure_2.png", result_path + "Figure_2.png", tolerance - ) - is None - ) - os.remove(result_path + "Figure_2.png") + assert assert_image(fig, "Figure_2", "plot_comparisons") is None fig = plot_comparisons( mod, variables="bill_length_mm", condition=["flipper_length_mm", "species"] ) - fig.savefig(result_path + "Figure_3.png") - assert ( - compare_images( - baseline_path + "Figure_3.png", result_path + "Figure_3.png", tolerance - ) - is None - ) - os.remove(result_path + "Figure_3.png") + assert assert_image(fig, "Figure_3", "plot_comparisons") is None fig = plot_comparisons(mod, variables="species", condition="bill_length_mm") - fig.savefig(result_path + "Figure_4.png") - assert ( - compare_images( - baseline_path + "Figure_4.png", result_path + "Figure_4.png", tolerance - ) - is None - ) - os.remove(result_path + "Figure_4.png") + assert assert_image(fig, "Figure_4", "plot_comparisons") is None - fig = plot_comparisons(mod, variables="island", condition="bill_length_mm") - fig.savefig(result_path + "Figure_5.png") - assert ( - compare_images( - baseline_path + "Figure_5.png", result_path + "Figure_5.png", tolerance - ) - is None - ) - os.remove(result_path + "Figure_5.png") + fig = plot_comparisons(mod, variables="bill_length_mm", condition="species") + assert assert_image(fig, "Figure_5", "plot_comparisons") is None fig = plot_comparisons( mod, variables="species", condition=["bill_length_mm", "species", "island"] ) - fig.savefig(result_path + "Figure_6.png") - assert ( - compare_images( - baseline_path + "Figure_6.png", result_path + "Figure_6.png", tolerance - ) - is None - ) - os.remove(result_path + "Figure_6.png") - - os.rmdir(result_path) - - return + assert assert_image(fig, "Figure_6", "plot_comparisons") is None \ No newline at end of file diff --git a/tests/test_plot_predictions.py b/tests/test_plot_predictions.py index 78c05a2..60029fc 100644 --- a/tests/test_plot_predictions.py +++ b/tests/test_plot_predictions.py @@ -1,110 +1,65 @@ import os import polars as pl -import pytest import statsmodels.formula.api as smf from matplotlib.testing.compare import compare_images - from marginaleffects import * from marginaleffects.plot_predictions import * - from .utilities import * -df = pl.read_csv( +penguins = pl.read_csv( "https://vincentarelbundock.github.io/Rdatasets/csv/palmerpenguins/penguins.csv", null_values="NA", ).drop_nulls() -mod = smf.ols( - "body_mass_g ~ flipper_length_mm * species * bill_length_mm + island", - df.to_pandas(), -).fit() -# @pytest.mark.skip(reason="statsmodels vcov is weird") -def test_plot_predictions(): - tolerance = 50 +def test_by(): + mod = smf.ols( + "body_mass_g ~ flipper_length_mm * species * bill_length_mm + island", + penguins.to_pandas(), + ).fit() - baseline_path = "./tests/images/plot_predictions/" + fig = plot_predictions(mod, by="species") + assert assert_image(fig, "by_01", "plot_predictions") is None - result_path = "./tests/images/.tmp_plot_predictions/" - if os.path.isdir(result_path): - for root, dirs, files in os.walk(result_path): - for fname in files: - os.remove(os.path.join(root, fname)) - os.rmdir(result_path) - os.mkdir(result_path) + fig = plot_predictions(mod, by=["species", "island"]) + assert assert_image(fig, "by_02", "plot_predictions") is None - fig = plot_predictions(mod, by="species") - fig.savefig(result_path + "Figure_1.png") - assert ( - compare_images( - baseline_path + "Figure_1.png", result_path + "Figure_1.png", tolerance - ) - is None - ) - os.remove(result_path + "Figure_1.png") - fig = plot_predictions( - mod, by="bill_length_mm", newdata=datagrid(model=mod, bill_length_mm=[37, 39]) - ) - fig.savefig(result_path + "Figure_2.png") - assert ( - compare_images( - baseline_path + "Figure_2.png", result_path + "Figure_2.png", tolerance - ) - is None - ) - os.remove(result_path + "Figure_2.png") - fig = plot_predictions( - mod, - by=["bill_length_mm", "island", "species"], - newdata=datagrid( - model=mod, bill_length_mm=[72, 431], species=["Adelie", "Chinstrap", "Gentoo"] - ), - ) - fig.savefig(result_path + "Figure_3.png") - assert ( - compare_images( - baseline_path + "Figure_3.png", result_path + "Figure_3.png", tolerance - ) - is None - ) - os.remove(result_path + "Figure_3.png") - - fig = plot_predictions(mod, condition="bill_length_mm") - fig.savefig(result_path + "Figure_4.png") - assert ( - compare_images( - baseline_path + "Figure_4.png", result_path + "Figure_4.png", tolerance - ) - is None - ) - os.remove(result_path + "Figure_4.png") +def test_condition(): + mod = smf.ols( + "body_mass_g ~ flipper_length_mm * species * bill_length_mm + island", + penguins.to_pandas(), + ).fit() fig = plot_predictions( mod, - condition={"flipper_length_mm": [i for i in range(180, 220)], "species": None}, - ) - fig.savefig(result_path + "Figure_5.png") - assert ( - compare_images( - baseline_path + "Figure_5.png", result_path + "Figure_5.png", tolerance - ) - is None + condition={"flipper_length_mm": list(range(180, 220)), "species": None}, ) - os.remove(result_path + "Figure_5.png") + assert assert_image(fig, "condition_01", "plot_predictions") is None fig = plot_predictions(mod, condition=["bill_length_mm", "species", "island"]) - fig.savefig(result_path + "Figure_6.png") - assert ( - compare_images( - baseline_path + "Figure_6.png", result_path + "Figure_6.png", tolerance - ) - is None - ) - os.remove(result_path + "Figure_6.png") + assert assert_image(fig, "condition_02", "plot_predictions") is None + + + +def test_issue_57(): + mtcars = pl.read_csv("https://vincentarelbundock.github.io/Rdatasets/csv/datasets/mtcars.csv") + mod = smf.ols("mpg ~ wt + am + qsec", mtcars.to_pandas()).fit() + + fig = plot_predictions(mod, condition=["qsec", "am"]) + assert assert_image(fig, "issue_57_01", "plot_predictions") is None + + fig = plot_predictions(mod, condition={ + "am": None, + "qsec": [mtcars["qsec"].min(), mtcars["qsec"].max()], + }) + assert assert_image(fig, "issue_57_02", "plot_predictions") is None + + fig = plot_predictions(mod, condition={"wt": None, "am": None}) + assert assert_image(fig, "issue_57_03", "plot_predictions") is None - os.rmdir(result_path) + fig = plot_predictions(mod, condition={"am": None, "wt": None}) + assert assert_image(fig, "issue_57_04", "plot_predictions") is None - return diff --git a/tests/test_plot_slopes.py b/tests/test_plot_slopes.py index 1054102..497ee8e 100644 --- a/tests/test_plot_slopes.py +++ b/tests/test_plot_slopes.py @@ -20,72 +20,25 @@ ).fit() -def test_plot_slopes(): - tolerance = 50 - - baseline_path = "./tests/images/plot_slopes/" - - result_path = "./tests/images/.tmp_plot_slopes/" - if os.path.isdir(result_path): - for root, dirs, files in os.walk(result_path): - for fname in files: - os.remove(os.path.join(root, fname)) - os.rmdir(result_path) - os.mkdir(result_path) - +def test_by(): fig = plot_slopes(mod, variables="species", by="island") - fig.savefig(result_path + "Figure_1.png") - assert ( - compare_images( - baseline_path + "Figure_1.png", result_path + "Figure_1.png", tolerance - ) - is None - ) - os.remove(result_path + "Figure_1.png") + assert assert_image(fig, "by_01", "plot_slopes") is None - # Tese two tests are failling + fig = plot_slopes(mod, variables='bill_length_mm', by=['species', 'island']) + assert assert_image(fig, "by_02", "plot_slopes") is None - # fig = plot_slopes(mod, variables='bill_length_mm', newdata=datagrid(mod, bill_length_mm=[37,39]), by='island') - # fig.savefig(result_path + "Figure_2.png") - # assert compare_images(baseline_path + "Figure_2.png", result_path + "Figure_2.png", tolerance) is None - # os.remove(result_path + "Figure_2.png") - # fig = plot_slopes(mod, variables='bill_length_mm', condition=['flipper_length_mm', 'species']) - # fig.savefig(result_path + "Figure_3.png") - # assert compare_images(baseline_path + "Figure_3.png", result_path + "Figure_3.png", tolerance) is None - # os.remove(result_path + "Figure_3.png") +def test_condition(): + fig = plot_slopes(mod, variables='bill_length_mm', condition=['flipper_length_mm', 'species'], eps_vcov=1e-2) + assert assert_image(fig, "condition_01", "plot_slopes") is None fig = plot_slopes(mod, variables="species", condition="bill_length_mm") - fig.savefig(result_path + "Figure_4.png") - assert ( - compare_images( - baseline_path + "Figure_4.png", result_path + "Figure_4.png", tolerance - ) - is None - ) - os.remove(result_path + "Figure_4.png") + assert assert_image(fig, "condition_02", "plot_slopes") is None - fig = plot_slopes(mod, variables="island", condition="bill_length_mm") - fig.savefig(result_path + "Figure_5.png") - assert ( - compare_images( - baseline_path + "Figure_5.png", result_path + "Figure_5.png", tolerance - ) - is None - ) - os.remove(result_path + "Figure_5.png") + fig = plot_slopes(mod, variables="island", condition="bill_length_mm", eps = 1e-2) + assert assert_image(fig, "condition_03", "plot_slopes") is None fig = plot_slopes( mod, variables="species", condition=["bill_length_mm", "species", "island"] ) - fig.savefig(result_path + "Figure_6.png") - assert ( - compare_images( - baseline_path + "Figure_6.png", result_path + "Figure_6.png", tolerance - ) - is None - ) - os.remove(result_path + "Figure_6.png") - - os.rmdir(result_path) - return + assert assert_image(fig, "condition_04", "plot_slopes") is None \ No newline at end of file diff --git a/tests/test_predictions.py b/tests/test_predictions.py index 6d9d622..4230aa6 100644 --- a/tests/test_predictions.py +++ b/tests/test_predictions.py @@ -49,3 +49,8 @@ def issue_38(): assert p.shape[0] == 1 p = avg_predictions(mod_py) assert p.shape[0] == 1 + +def issue_59(): + p = predictions(mod_py, vcov = False) + assert p.shape[0] == df.shape[0] + assert p.shape[1] > 20 \ No newline at end of file diff --git a/tests/utilities.py b/tests/utilities.py index c8462e3..f62006d 100644 --- a/tests/utilities.py +++ b/tests/utilities.py @@ -1,4 +1,6 @@ +import os import re +from matplotlib.testing.compare import compare_images from marginaleffects import * @@ -20,3 +22,24 @@ def compare_r_to_py(r_obj, py_obj, tolr=1e-3, tola=1e-3, msg=""): gap_abs = (a - b).abs().max() flag = gap_rel <= tolr or gap_abs <= tola assert flag, f"{msg} trel: {gap_rel}. tabs: {gap_abs}" + + +def assert_image(fig, label, file, tolerance=5): + known_path = f"./tests/images/{file}/" + unknown_path = f"./tests/images/.tmp_{file}/" + if os.path.isdir(unknown_path): + for root, dirs, files in os.walk(unknown_path): + for fname in files: + os.remove(os.path.join(root, fname)) + os.rmdir(unknown_path) + os.mkdir(unknown_path) + unknown = f"{unknown_path}{label}.png" + known = f"{known_path}{label}.png" + if not os.path.exists(known): + fig.savefig(known) + raise FileExistsError(f"File {known} does not exist. Creating it now.") + fig.savefig(unknown) + out = compare_images(known, unknown, tol=tolerance) + # os.remove(unknown) + return out +