From 6188c4054e5a3593e1dbabf68a011b20dfbc2ade Mon Sep 17 00:00:00 2001 From: Vincent Arel-Bundock Date: Fri, 22 Dec 2023 09:50:23 -0500 Subject: [PATCH] ruff format --- marginaleffects/__init__.py | 8 -- marginaleffects/by.py | 10 +- marginaleffects/classes.py | 4 +- marginaleffects/comparisons.py | 17 ++- marginaleffects/datagrid.py | 15 +- marginaleffects/estimands.py | 2 +- marginaleffects/getters.py | 9 +- marginaleffects/hypotheses.py | 49 +++---- marginaleffects/hypothesis.py | 1 - marginaleffects/plot_common.py | 158 ++++++++++++++------- marginaleffects/plot_comparisons.py | 98 +++++++------ marginaleffects/plot_predictions.py | 70 +++++---- marginaleffects/plot_slopes.py | 54 +++---- marginaleffects/predictions.py | 29 ++-- marginaleffects/sanity.py | 43 ++++-- marginaleffects/uncertainty.py | 6 +- marginaleffects/utils.py | 2 - tests/test_by.py | 2 - tests/test_comparisons.py | 2 - tests/test_datagrid.py | 1 - tests/test_slopes.py | 2 - tests/test_statsmodels_mixedlm.py | 2 - tests/test_statsmodels_negativebinomial.py | 1 - tests/test_statsmodels_wls.py | 1 - tests/utilities.py | 2 - 25 files changed, 336 insertions(+), 252 deletions(-) diff --git a/marginaleffects/__init__.py b/marginaleffects/__init__.py index c6ab264..e69de29 100644 --- a/marginaleffects/__init__.py +++ b/marginaleffects/__init__.py @@ -1,8 +0,0 @@ -from .comparisons import avg_comparisons, comparisons -from .datagrid import datagrid, datagridcf -from .hypotheses import hypotheses -from .plot_comparisons import plot_comparisons -from .plot_predictions import plot_predictions -from .plot_slopes import plot_slopes -from .predictions import avg_predictions, predictions -from .slopes import avg_slopes, slopes diff --git a/marginaleffects/by.py b/marginaleffects/by.py index ee843dd..0df522d 100644 --- a/marginaleffects/by.py +++ b/marginaleffects/by.py @@ -2,9 +2,13 @@ def get_by(model, estimand, newdata, by=None, wts=None): - # for predictions - if isinstance(by, list) and len(by) == 1 and by[0] == "group" and "group" not in estimand.columns: + if ( + isinstance(by, list) + and len(by) == 1 + and by[0] == "group" + and "group" not in estimand.columns + ): by = True if by is True: @@ -20,12 +24,10 @@ def get_by(model, estimand, newdata, by=None, wts=None): else: out = pl.DataFrame({"estimate": estimand["estimate"]}) - by = [x for x in by if x in out.columns] if isinstance(by, list) and len(by) == 0: return out - if wts is None: out = out.groupby(by, maintain_order=True).agg(pl.col("estimate").mean()) else: diff --git a/marginaleffects/classes.py b/marginaleffects/classes.py index 96ca992..ba4cab8 100644 --- a/marginaleffects/classes.py +++ b/marginaleffects/classes.py @@ -22,7 +22,7 @@ def __str__(self): "std_error": "Std.Error", "statistic": "z", "p_value": "P(>|z|)", - "s_value": "S" + "s_value": "S", } if hasattr(self, "conf_level"): @@ -49,7 +49,7 @@ def __str__(self): valid = list(mapping.keys()) valid = self.datagrid_explicit + valid - + valid = [x for x in valid if x in self.columns] mapping = {key: mapping[key] for key in mapping if key in valid} tmp = self.select(valid).rename(mapping) diff --git a/marginaleffects/comparisons.py b/marginaleffects/comparisons.py index 6190ece..6d0ae85 100644 --- a/marginaleffects/comparisons.py +++ b/marginaleffects/comparisons.py @@ -10,8 +10,13 @@ from .estimands import estimands from .getters import get_coef, get_modeldata, get_predict from .hypothesis import get_hypothesis -from .sanity import (sanitize_by, sanitize_hypothesis_null, sanitize_newdata, - sanitize_variables, sanitize_vcov) +from .sanity import ( + sanitize_by, + sanitize_hypothesis_null, + sanitize_newdata, + sanitize_variables, + sanitize_vcov, +) from .transform import get_transform from .uncertainty import get_jacobian, get_se, get_z_p_ci from .utils import get_pad, sort_columns, upcast @@ -190,9 +195,7 @@ def inner(coefs, by, hypothesis, wts, nd): # estimates tmp = [ - get_predict(model, get_coef(model), nd_X).rename( - {"estimate": "predicted"} - ), + get_predict(model, get_coef(model), nd_X).rename({"estimate": "predicted"}), get_predict(model, coefs, lo_X) .rename({"estimate": "predicted_lo"}) .select("predicted_lo"), @@ -270,7 +273,9 @@ def outer(x): J = get_jacobian(func=outer, coefs=get_coef(model)) se = get_se(J, V) out = out.with_columns(pl.Series(se).alias("std_error")) - out = get_z_p_ci(out, model, conf_level=conf_level, hypothesis_null=hypothesis_null) + out = get_z_p_ci( + out, model, conf_level=conf_level, hypothesis_null=hypothesis_null + ) out = get_transform(out, transform=transform) out = get_equivalence(out, equivalence=equivalence, df=np.inf) diff --git a/marginaleffects/datagrid.py b/marginaleffects/datagrid.py index ec90e93..293a58e 100644 --- a/marginaleffects/datagrid.py +++ b/marginaleffects/datagrid.py @@ -10,7 +10,7 @@ def datagrid( newdata=None, FUN_numeric=lambda x: x.mean(), FUN_other=lambda x: x.mode()[0], # mode can return multiple values - **kwargs + **kwargs, ): """ Data grids @@ -95,7 +95,6 @@ def datagrid( return out - def datagridcf(model=None, newdata=None, **kwargs): """ Data grids @@ -122,20 +121,22 @@ def datagridcf(model=None, newdata=None, **kwargs): newdata = get_modeldata(model) if "rowid" not in newdata.columns: - newdata = newdata.with_columns(pl.Series(range(newdata.shape[0])).alias("rowid")) - newdata = newdata.rename({"rowid" : "rowidcf"}) + newdata = newdata.with_columns( + pl.Series(range(newdata.shape[0])).alias("rowid") + ) + newdata = newdata.rename({"rowid": "rowidcf"}) # Create dataframe from kwargs dfs = [pl.DataFrame({k: v}) for k, v in kwargs.items()] # Perform cross join - df_cross = reduce(lambda df1, df2: df1.join(df2, how='cross'), dfs) + df_cross = reduce(lambda df1, df2: df1.join(df2, how="cross"), dfs) # Drop would-be duplicates newdata = newdata.drop(df_cross.columns) - result = newdata.join(df_cross, how = "cross") + result = newdata.join(df_cross, how="cross") result.datagrid_explicit = list(kwargs.keys()) - return result \ No newline at end of file + return result diff --git a/marginaleffects/estimands.py b/marginaleffects/estimands.py index 37630ae..7860a28 100644 --- a/marginaleffects/estimands.py +++ b/marginaleffects/estimands.py @@ -26,7 +26,7 @@ def prep(x): "dydxavg": lambda hi, lo, eps, x, y, w: prep(((hi - lo) / eps).mean()), "eyexavg": lambda hi, lo, eps, x, y, w: prep(((hi - lo) / eps * (x / y)).mean()), "eydxavg": lambda hi, lo, eps, x, y, w: prep((((hi - lo) / eps) / y).mean()), - "dyexavg": lambda hi, lo, eps, x, y, w: prep(((((hi - lo) / eps) * x)).mean()), + "dyexavg": lambda hi, lo, eps, x, y, w: prep((((hi - lo) / eps) * x).mean()), "dydxavgwts": lambda hi, lo, eps, x, y, w: prep( (((hi - lo) / eps) * w).sum() / w.sum() ), diff --git a/marginaleffects/getters.py b/marginaleffects/getters.py index 7227279..1d670dc 100644 --- a/marginaleffects/getters.py +++ b/marginaleffects/getters.py @@ -22,7 +22,6 @@ def get_coef(model): def get_vcov(model, vcov=True): - if isinstance(vcov, bool): if vcov is True: V = model.cov_params() @@ -37,8 +36,10 @@ def get_vcov(model, vcov=True): raise ValueError(f"The model object has no {lab} attribute.") else: - raise ValueError('`vcov` must be a boolean or a string like "HC3", which corresponds to an attribute of the model object such as "vcov_HC3".') - + raise ValueError( + '`vcov` must be a boolean or a string like "HC3", which corresponds to an attribute of the model object such as "vcov_HC3".' + ) + V = np.array(V) return V @@ -87,4 +88,4 @@ def get_predict(model, params, newdata: pl.DataFrame): "The `predict()` method must return an array with 1 or 2 dimensions." ) p = p.with_columns(pl.col("rowid").cast(pl.Int32)) - return p \ No newline at end of file + return p diff --git a/marginaleffects/hypotheses.py b/marginaleffects/hypotheses.py index 4e78ac3..c3a0a2e 100644 --- a/marginaleffects/hypotheses.py +++ b/marginaleffects/hypotheses.py @@ -1,4 +1,3 @@ -import numpy as np import polars as pl from .classes import MarginaleffectsDataFrame @@ -13,27 +12,27 @@ def hypotheses(model, hypothesis=None, conf_level=0.95, vcov=True, equivalence=None): """ (Non-)Linear Tests for Null Hypotheses, Joint Hypotheses, Equivalence, Non Superiority, and Non Inferiority. - - This function calculates uncertainty estimates as first-order approximate standard errors for linear or non-linear - functions of a vector of random variables with known or estimated covariance matrix. It emulates the behavior of - the excellent and well-established `car::deltaMethod` and `car::linearHypothesis` functions in R, but it supports - more models; requires fewer dependencies; expands the range of tests to equivalence and superiority/inferiority; + + This function calculates uncertainty estimates as first-order approximate standard errors for linear or non-linear + functions of a vector of random variables with known or estimated covariance matrix. It emulates the behavior of + the excellent and well-established `car::deltaMethod` and `car::linearHypothesis` functions in R, but it supports + more models; requires fewer dependencies; expands the range of tests to equivalence and superiority/inferiority; and offers convenience features like robust standard errors. - + To learn more, visit the package website: - - Warning #1: Tests are conducted directly on the scale defined by the `type` argument. For some models, it can make - sense to conduct hypothesis or equivalence tests on the `"link"` scale instead of the `"response"` scale which is + + Warning #1: Tests are conducted directly on the scale defined by the `type` argument. For some models, it can make + sense to conduct hypothesis or equivalence tests on the `"link"` scale instead of the `"response"` scale which is often the default. - - Warning #2: For hypothesis tests on objects produced by the `marginaleffects` package, it is safer to use the - `hypothesis` argument of the original function. Using `hypotheses()` may not work in certain environments, in lists, + + Warning #2: For hypothesis tests on objects produced by the `marginaleffects` package, it is safer to use the + `hypothesis` argument of the original function. Using `hypotheses()` may not work in certain environments, in lists, or when working programmatically with *apply style functions. - - Warning #3: The tests assume that the `hypothesis` expression is (approximately) normally distributed, which for - non-linear functions of the parameters may not be realistic. More reliable confidence intervals can be obtained using + + Warning #3: The tests assume that the `hypothesis` expression is (approximately) normally distributed, which for + non-linear functions of the parameters may not be realistic. More reliable confidence intervals can be obtained using the `inferences()` function with `method = "boot"`. - + Parameters: model : object Model object estimated by `statsmodels` @@ -45,11 +44,11 @@ def hypotheses(model, hypothesis=None, conf_level=0.95, vcov=True, equivalence=N Whether to use the covariance matrix in the hypothesis test. Default is True. equivalence : tuple, optional The equivalence range for the hypothesis test. Default is None. - + Returns: MarginaleffectsDataFrame A DataFrame containing the results of the hypothesis tests. - + Examples: # When `hypothesis` is `None`, `hypotheses()` returns a DataFrame of parameters @@ -57,16 +56,16 @@ def hypotheses(model, hypothesis=None, conf_level=0.95, vcov=True, equivalence=N # A different null hypothesis hypotheses(model, hypothesis = 3) - + # Test of equality between coefficients hypotheses(model, hypothesis="param1 = param2") - + # Non-linear function hypotheses(model, hypothesis="exp(param1 + param2) = 0.1") - + # Robust standard errors hypotheses(model, hypothesis="param1 = param2", vcov="HC3") - + # Equivalence, non-inferiority, and non-superiority tests hypotheses(model, equivalence=(0, 10)) """ @@ -86,7 +85,9 @@ def fun(x): J = get_jacobian(fun, get_coef(model)) se = get_se(J, V) out = out.with_columns(pl.Series(se).alias("std_error")) - out = get_z_p_ci(out, model, conf_level=conf_level, hypothesis_null=hypothesis_null) + out = get_z_p_ci( + out, model, conf_level=conf_level, hypothesis_null=hypothesis_null + ) out = get_equivalence(out, equivalence=equivalence) out = sort_columns(out, by=None) out = MarginaleffectsDataFrame(out, conf_level=conf_level) diff --git a/marginaleffects/hypothesis.py b/marginaleffects/hypothesis.py index fa8d30c..a86e173 100644 --- a/marginaleffects/hypothesis.py +++ b/marginaleffects/hypothesis.py @@ -71,7 +71,6 @@ def get_hypothesis(x, hypothesis): elif hypothesis == "revpairwise": hypmat = lincom_revpairwise(x) else: - raise ValueError(msg) out = lincom_multiply(x, hypmat.to_numpy()) out = out.with_columns(pl.Series(hypothesis.columns).alias("term")) diff --git a/marginaleffects/plot_common.py b/marginaleffects/plot_common.py index 6b9185e..7ba19e0 100644 --- a/marginaleffects/plot_common.py +++ b/marginaleffects/plot_common.py @@ -3,58 +3,63 @@ import polars as pl from matplotlib.lines import Line2D -from .datagrid import datagrid from .getters import get_modeldata from .utils import get_variable_type def dt_on_condition(model, condition): - - modeldata = get_modeldata(model) if isinstance(condition, str): condition = [condition] - assert 1 <= len(condition) <= 3, f"Lenght of condition must be inclusively between 1 and 3. Got : {len(condition)}." - + assert ( + 1 <= len(condition) <= 3 + ), f"Lenght of condition must be inclusively between 1 and 3. Got : {len(condition)}." to_datagrid = {} first_key = "" # special case when the first element is numeric if isinstance(condition, list): - assert all(ele in modeldata.columns for ele in condition), "All elements of condition must be columns of the model." + assert all( + ele in modeldata.columns for ele in condition + ), "All elements of condition must be columns of the model." first_key = condition[0] to_datagrid = {key: None for key in condition} elif isinstance(condition, dict): - assert all(key in modeldata.columns for key in condition.keys()), "All keys of condition must be columns of the model." + assert all( + key in modeldata.columns for key in condition.keys() + ), "All keys of condition must be columns of the model." first_key = next(iter(condition)) - to_datagrid = condition - - + to_datagrid = condition for key, value in to_datagrid.items(): - variable_type = get_variable_type(key, modeldata) # Check type of user-supplied dict values if value is not None: test_df = pl.DataFrame({key: value}) - assert variable_type == get_variable_type(key, test_df), f"Supplied data type of {key} column ({get_variable_type(key, test_df)}) does not match the type of the variable ({variable_type})." + assert ( + variable_type == get_variable_type(key, test_df) + ), f"Supplied data type of {key} column ({get_variable_type(key, test_df)}) does not match the type of the variable ({variable_type})." continue - if variable_type == 'numeric': + if variable_type == "numeric": if key == first_key: - to_datagrid[key] = np.linspace(modeldata[key].min(), modeldata[key].max(), 100).tolist() + to_datagrid[key] = np.linspace( + modeldata[key].min(), modeldata[key].max(), 100 + ).tolist() else: - to_datagrid[key] = np.percentile(modeldata[key], [0, 25, 50, 75, 100], method="midpoint").tolist() + to_datagrid[key] = np.percentile( + modeldata[key], [0, 25, 50, 75, 100], method="midpoint" + ).tolist() - elif variable_type == 'boolean' or variable_type == 'character': + elif variable_type == "boolean" or variable_type == "character": to_datagrid[key] = modeldata[key].unique().to_list() - assert len(to_datagrid[key]) <= 10, f"Character type variables of more than 10 unique values are not supported. {ele} variable has {len(to_datagrid[ele])} unique values." - - + assert ( + len(to_datagrid[key]) <= 10 + ), f"Character type variables of more than 10 unique values are not supported. {ele} variable has {len(to_datagrid[ele])} unique values." dt_code = "datagrid(newdata=modeldata" for key, value in to_datagrid.items(): @@ -71,7 +76,6 @@ def dt_on_condition(model, condition): def plotter(dt, x_name, x_type, fig=None, axe=None, label=None, color=None): - x = dt.select(x_name).to_numpy().flatten() y = dt.select("estimate").to_numpy().flatten() y_low = dt.select("conf_low").to_numpy().flatten() @@ -82,7 +86,7 @@ def plotter(dt, x_name, x_type, fig=None, axe=None, label=None, color=None): else: fig = plt.figure() plot_obj = plt - + if x_type == "numeric": if color is None: plot_obj.fill_between(x, y_low, y_high, alpha=0.2) @@ -95,15 +99,16 @@ def plotter(dt, x_name, x_type, fig=None, axe=None, label=None, color=None): y_low = np.absolute(y - y_low) y_high = np.absolute(y_high - y) if color is None: - plot_obj.errorbar(x, y, yerr=(y_low, y_high), fmt='o', label=label) + plot_obj.errorbar(x, y, yerr=(y_low, y_high), fmt="o", label=label) else: - plot_obj.errorbar(x, y, yerr=(y_low, y_high), fmt='o', color=color, label=label) + plot_obj.errorbar( + x, y, yerr=(y_low, y_high), fmt="o", color=color, label=label + ) return fig def plot_common(dt, y_label, var_list): - titles_fontsize = 12 x_name = var_list[0] @@ -114,23 +119,25 @@ def plot_common(dt, y_label, var_list): subplot = [None] elif len(var_list) == 3: color = var_list[1] - subplot = dt.select(var_list[2]).unique(maintain_order=True).to_numpy().flatten() + subplot = ( + dt.select(var_list[2]).unique(maintain_order=True).to_numpy().flatten() + ) else: color = None subplot = [None] - # when 'contrast' is a column containing more than 1 unique value, we subplot all intersections + # when 'contrast' is a column containing more than 1 unique value, we subplot all intersections # of these values with explicit subplots - if 'contrast' in dt.columns: - contrast = dt.select('contrast').unique(maintain_order=True).to_numpy().flatten() + if "contrast" in dt.columns: + contrast = ( + dt.select("contrast").unique(maintain_order=True).to_numpy().flatten() + ) if len(contrast) == 1: contrast = [None] else: contrast = [None] - if subplot[0] is not None or contrast[0] is not None: - color_i = 0 color_dict = {} @@ -138,25 +145,28 @@ def plot_common(dt, y_label, var_list): dim_max = subplot dim_min = contrast max_name = var_list[2] if len(var_list) == 3 else None - min_name = 'contrast' + min_name = "contrast" else: dim_max = contrast dim_min = subplot - max_name = 'contrast' + max_name = "contrast" min_name = var_list[2] if len(var_list) == 3 else None max_len = len(dim_max) min_len = len(dim_min) - figsize_def = plt.rcParams.get('figure.figsize') - figsize = [max(figsize_def[0], (2/3)*figsize_def[0]*max_len), max(figsize_def[1], (2/3)*figsize_def[1]*min_len)] - - fig, axes = plt.subplots(min_len, max_len, squeeze=False, layout="constrained", figsize=figsize) + figsize_def = plt.rcParams.get("figure.figsize") + figsize = [ + max(figsize_def[0], (2 / 3) * figsize_def[0] * max_len), + max(figsize_def[1], (2 / 3) * figsize_def[1] * min_len), + ] + fig, axes = plt.subplots( + min_len, max_len, squeeze=False, layout="constrained", figsize=figsize + ) for i, dim_min_i in enumerate(dim_min): for j, dim_max_j in enumerate(dim_max): - subplot_dt = dt subplot_dt = subplot_dt.filter(pl.col(max_name) == dim_max_j) @@ -165,52 +175,92 @@ def plot_common(dt, y_label, var_list): axe = max_len * i + j - if color is None: plotter(subplot_dt, x_name, x_type, fig=fig, axe=axe) else: - - for color_val in subplot_dt.select(color).unique(maintain_order=True).to_numpy().flatten(): + for color_val in ( + subplot_dt.select(color) + .unique(maintain_order=True) + .to_numpy() + .flatten() + ): if color_val not in color_dict: - color_dict[color_val] = plt.rcParams['axes.prop_cycle'].by_key()['color'][color_i] + color_dict[color_val] = plt.rcParams[ + "axes.prop_cycle" + ].by_key()["color"][color_i] color_i += 1 color_dt = subplot_dt.filter(pl.col(color) == color_val) - plotter(color_dt, x_name, x_type, fig=fig, axe=axe, label=color_val, color=color_dict[color_val]) - - if max_name == 'contrast': + plotter( + color_dt, + x_name, + x_type, + fig=fig, + axe=axe, + label=color_val, + color=color_dict[color_val], + ) + + if max_name == "contrast": title = dim_min_i if dim_min_i is not None else "" - title += "\n" + subplot_dt.select(pl.first('term')).item() + ", " + dim_max_j if dim_max_j is not None else "" + title += ( + "\n" + + subplot_dt.select(pl.first("term")).item() + + ", " + + dim_max_j + if dim_max_j is not None + else "" + ) else: title = dim_max_j - title += "\n" + subplot_dt.select(pl.first('term')).item() + ", " + dim_min_i if dim_min_i is not None else "" + title += ( + "\n" + + subplot_dt.select(pl.first("term")).item() + + ", " + + dim_min_i + if dim_min_i is not None + else "" + ) fig.axes[axe].set_title(title, fontsize=titles_fontsize) - if color is not None: - legend_elements = [Line2D([0], [0], color=val, label=key) for key,val in color_dict.items()] - + legend_elements = [ + Line2D([0], [0], color=val, label=key) + for key, val in color_dict.items() + ] elif color is not None: fig = plt.figure(layout="constrained") - for color_val in dt.select(color).unique(maintain_order=True).to_numpy().flatten(): + for color_val in ( + dt.select(color).unique(maintain_order=True).to_numpy().flatten() + ): color_dt = dt.filter(pl.col(color) == color_val) plotter(color_dt, x_name, x_type, fig=fig, label=color_val) - fig.legend(loc='outside center right', title=color, fontsize=titles_fontsize, title_fontsize=titles_fontsize) - + fig.legend( + loc="outside center right", + title=color, + fontsize=titles_fontsize, + title_fontsize=titles_fontsize, + ) else: fig = plotter(dt, x_name, x_type) if (subplot[0] is not None or contrast[0] is not None) and color is not None: - fig.legend(handles=legend_elements, loc='outside center right', title=color, fontsize=titles_fontsize, title_fontsize=titles_fontsize) + fig.legend( + handles=legend_elements, + loc="outside center right", + title=color, + fontsize=titles_fontsize, + title_fontsize=titles_fontsize, + ) fig.supxlabel(x_name, fontsize=titles_fontsize) fig.supylabel(y_label, fontsize=titles_fontsize) - return plt \ No newline at end of file + return plt diff --git a/marginaleffects/plot_comparisons.py b/marginaleffects/plot_comparisons.py index a952d5e..37201b6 100644 --- a/marginaleffects/plot_comparisons.py +++ b/marginaleffects/plot_comparisons.py @@ -1,10 +1,5 @@ -import numpy as np -import polars as pl - from .comparisons import comparisons -from .getters import get_modeldata from .plot_common import dt_on_condition, plot_common -from .utils import get_variable_type def plot_comparisons( @@ -64,7 +59,7 @@ def plot_comparisons( Aggregate unit-level estimates (aka, marginalize, average over). newdata : dataframe - When newdata is NULL, the grid is determined by the condition argument. When newdata is not NULL, the argument behaves in the same way as in the predictions() function. + When newdata is NULL, the grid is determined by the condition argument. When newdata is not NULL, the argument behaves in the same way as in the predictions() function. wts: Column name of weights to use for marginalization. Must be a column in `newdata` @@ -75,39 +70,48 @@ def plot_comparisons( draw : True returns a matplotlib plot. False returns a dataframe of the underlying data. """ - assert not (not by and newdata is not None), "The `newdata` argument requires a `by` argument." + assert not ( + not by and newdata is not None + ), "The `newdata` argument requires a `by` argument." - assert (condition is None and by) or (condition is not None and not by), "One of the `condition` and `by` arguments must be supplied, but not both." + assert (condition is None and by) or ( + condition is not None and not by + ), "One of the `condition` and `by` arguments must be supplied, but not both." - assert not (wts is not None and not by), "The `wts` argument requires a `by` argument." + assert not ( + wts is not None and not by + ), "The `wts` argument requires a `by` argument." if by: - if newdata is not None: - dt = comparisons(model, - variables=variables, - newdata=newdata, - comparison=comparison, - vcov=vcov, - conf_level=conf_level, - by=by, - wts=wts, - hypothesis=hypothesis, - equivalence=equivalence, - transform=transform, - eps=eps) + dt = comparisons( + model, + variables=variables, + newdata=newdata, + comparison=comparison, + vcov=vcov, + conf_level=conf_level, + by=by, + wts=wts, + hypothesis=hypothesis, + equivalence=equivalence, + transform=transform, + eps=eps, + ) else: - dt = comparisons(model, - variables=variables, - comparison=comparison, - vcov=vcov, - conf_level=conf_level, - by=by, - wts=wts, - hypothesis=hypothesis, - equivalence=equivalence, - transform=transform, - eps=eps) + dt = comparisons( + model, + variables=variables, + comparison=comparison, + vcov=vcov, + conf_level=conf_level, + by=by, + wts=wts, + hypothesis=hypothesis, + equivalence=equivalence, + transform=transform, + eps=eps, + ) var_list = [by] if isinstance(by, str) else by @@ -119,18 +123,20 @@ def plot_comparisons( var_list = condition elif isinstance(condition, dict): var_list = list(condition.keys()) - dt = comparisons(model, - variables=variables, - newdata=dt_condition, - comparison=comparison, - vcov=vcov, - conf_level=conf_level, - by=var_list, - wts=wts, - hypothesis=hypothesis, - equivalence=equivalence, - transform=transform, - eps=eps) + dt = comparisons( + model, + variables=variables, + newdata=dt_condition, + comparison=comparison, + vcov=vcov, + conf_level=conf_level, + by=var_list, + wts=wts, + hypothesis=hypothesis, + equivalence=equivalence, + transform=transform, + eps=eps, + ) dt = dt.drop_nulls(var_list[0]) dt = dt.sort(var_list[0]) @@ -138,4 +144,4 @@ def plot_comparisons( if not draw: return dt - return plot_common(dt, "Comparison", var_list) \ No newline at end of file + return plot_common(dt, "Comparison", var_list) diff --git a/marginaleffects/plot_predictions.py b/marginaleffects/plot_predictions.py index 20ff014..4b15371 100644 --- a/marginaleffects/plot_predictions.py +++ b/marginaleffects/plot_predictions.py @@ -1,6 +1,3 @@ -import numpy as np -import polars as pl - from .getters import find_response from .plot_common import dt_on_condition, plot_common from .predictions import predictions @@ -15,7 +12,7 @@ def plot_predictions( conf_level=0.95, transform=None, draw=True, - wts=None + wts=None, ): """ Plot predictions on the y-axis against values of one or more predictors (x-axis, colors, and facets). @@ -54,7 +51,7 @@ def plot_predictions( Names of the categorical predictors to marginalize across. newdata : dataframe - When newdata is NULL, the grid is determined by the condition argument. When newdata is not NULL, the argument behaves in the same way as in the predictions() function. + When newdata is NULL, the grid is determined by the condition argument. When newdata is not NULL, the argument behaves in the same way as in the predictions() function. wts: Column name of weights to use for marginalization. Must be a column in `newdata` @@ -64,30 +61,39 @@ def plot_predictions( draw : True returns a matplotlib plot. False returns a dataframe of the underlying data. """ - - assert not (not by and newdata is not None), "The `newdata` argument requires a `by` argument." - assert (condition is None and by) or (condition is not None and not by), "One of the `condition` and `by` arguments must be supplied, but not both." + assert not ( + not by and newdata is not None + ), "The `newdata` argument requires a `by` argument." - assert not (wts is not None and not by), "The `wts` argument requires a `by` argument." + assert (condition is None and by) or ( + condition is not None and not by + ), "One of the `condition` and `by` arguments must be supplied, but not both." - if by: + assert not ( + wts is not None and not by + ), "The `wts` argument requires a `by` argument." + if by: if newdata is not None: - dt = predictions(model, - by=by, - newdata=newdata, - conf_level=conf_level, - vcov=vcov, - transform=transform, - wts=wts) + dt = predictions( + model, + by=by, + newdata=newdata, + conf_level=conf_level, + vcov=vcov, + transform=transform, + wts=wts, + ) else: - dt = predictions(model, - by=by, - conf_level=conf_level, - vcov=vcov, - transform=transform, - wts=wts) + dt = predictions( + model, + by=by, + conf_level=conf_level, + vcov=vcov, + transform=transform, + wts=wts, + ) var_list = [by] if isinstance(by, str) else by @@ -99,17 +105,19 @@ def plot_predictions( var_list = condition elif isinstance(condition, dict): var_list = list(condition.keys()) - dt = predictions(model, - by=var_list, - newdata=dt_condition, - conf_level=conf_level, - vcov=vcov, - transform=transform) + dt = predictions( + model, + by=var_list, + newdata=dt_condition, + conf_level=conf_level, + vcov=vcov, + transform=transform, + ) dt = dt.drop_nulls(var_list[0]) dt = dt.sort(var_list[0]) - + if not draw: return dt - return plot_common(dt, find_response(model), var_list) \ No newline at end of file + return plot_common(dt, find_response(model), var_list) diff --git a/marginaleffects/plot_slopes.py b/marginaleffects/plot_slopes.py index 89fd9d8..9838465 100644 --- a/marginaleffects/plot_slopes.py +++ b/marginaleffects/plot_slopes.py @@ -1,10 +1,5 @@ -import numpy as np -import polars as pl - -from .getters import get_modeldata from .plot_common import dt_on_condition, plot_common from .slopes import slopes -from .utils import get_variable_type def plot_slopes( @@ -60,38 +55,47 @@ def plot_slopes( Names of the categorical predictors to marginalize across. newdata : dataframe - When newdata is NULL, the grid is determined by the condition argument. When newdata is not NULL, the argument behaves in the same way as in the predictions() function. + When newdata is NULL, the grid is determined by the condition argument. When newdata is not NULL, the argument behaves in the same way as in the predictions() function. wts: Column name of weights to use for marginalization. Must be a column in `newdata` draw : True returns a matplotlib plot. False returns a dataframe of the underlying data. """ - assert not (not by and newdata is not None), "The `newdata` argument requires a `by` argument." + assert not ( + not by and newdata is not None + ), "The `newdata` argument requires a `by` argument." - assert (condition is None and by) or (condition is not None and not by), "One of the `condition` and `by` arguments must be supplied, but not both." + assert (condition is None and by) or ( + condition is not None and not by + ), "One of the `condition` and `by` arguments must be supplied, but not both." - assert not (wts is not None and not by), "The `wts` argument requires a `by` argument." + assert not ( + wts is not None and not by + ), "The `wts` argument requires a `by` argument." if by: - if newdata is not None: - dt = slopes(model, + dt = slopes( + model, variables=variables, newdata=newdata, slope=slope, vcov=vcov, conf_level=conf_level, by=by, - wts=wts) + wts=wts, + ) else: - dt = slopes(model, + dt = slopes( + model, variables=variables, slope=slope, vcov=vcov, conf_level=conf_level, by=by, - wts=wts) + wts=wts, + ) var_list = [by] if isinstance(by, str) else by @@ -103,19 +107,21 @@ def plot_slopes( var_list = condition elif isinstance(condition, dict): var_list = list(condition.keys()) - dt = slopes(model, - variables=variables, - newdata=dt_condition, - slope=slope, - vcov=vcov, - conf_level=conf_level, - by=var_list, - wts=wts) + dt = slopes( + model, + variables=variables, + newdata=dt_condition, + slope=slope, + vcov=vcov, + conf_level=conf_level, + by=var_list, + wts=wts, + ) dt = dt.drop_nulls(var_list[0]) dt = dt.sort(var_list[0]) - + if not draw: return dt - return plot_common(dt, "Slope", var_list) \ No newline at end of file + return plot_common(dt, "Slope", var_list) diff --git a/marginaleffects/predictions.py b/marginaleffects/predictions.py index dd381a5..ba3b658 100644 --- a/marginaleffects/predictions.py +++ b/marginaleffects/predictions.py @@ -7,8 +7,12 @@ from .equivalence import get_equivalence from .getters import get_coef, get_modeldata, get_predict, get_variables_names from .hypothesis import get_hypothesis -from .sanity import (sanitize_by, sanitize_hypothesis_null, sanitize_newdata, - sanitize_vcov) +from .sanity import ( + sanitize_by, + sanitize_hypothesis_null, + sanitize_newdata, + sanitize_vcov, +) from .transform import get_transform from .uncertainty import get_jacobian, get_se, get_z_p_ci from .utils import get_pad, sort_columns, upcast @@ -91,7 +95,10 @@ def predictions( mean = modeldata[variable].mean() val = [mean - std, mean + std] elif value == "iqr": - val = [np.percentile(newdata[variable], 75), np.percentile(newdata[variable], 25)] + val = [ + np.percentile(newdata[variable], 75), + np.percentile(newdata[variable], 25), + ] elif value == "minmax": val = [np.max(newdata[variable]), np.min(newdata[variable])] elif value == "threenum": @@ -99,19 +106,21 @@ def predictions( mean = modeldata[variable].mean() val = [mean - std / 2, mean, mean + std / 2] elif value == "fivenum": - val = np.percentile(modeldata[variable], [0, 25, 50, 75, 100], method="midpoint") + val = np.percentile( + modeldata[variable], [0, 25, 50, 75, 100], method="midpoint" + ) else: val = value newdata = newdata.drop(variable) - newdata = newdata.join(pl.DataFrame({variable:val}), how = "cross") + newdata = newdata.join(pl.DataFrame({variable: val}), how="cross") newdata = newdata.sort(variable) newdata.datagrid_explicit = list(variables.keys()) # pad pad = [] - vs = get_variables_names(variables = None, model = model, newdata = modeldata) + vs = get_variables_names(variables=None, model=model, newdata=modeldata) for v in vs: if not newdata[v].is_numeric(): uniqs = modeldata[v].unique() @@ -120,7 +129,7 @@ def predictions( if len(pad) > 0: pad = pl.concat(pad) tmp = upcast([newdata, pad]) - newdata = pl.concat(tmp, how = "diagonal") + newdata = pl.concat(tmp, how="diagonal") else: pad = pl.DataFrame() @@ -155,14 +164,16 @@ def inner(x): J = get_jacobian(inner, get_coef(model)) se = get_se(J, V) out = out.with_columns(pl.Series(se).alias("std_error")) - out = get_z_p_ci(out, model, conf_level=conf_level, hypothesis_null=hypothesis_null) + out = get_z_p_ci( + out, model, conf_level=conf_level, hypothesis_null=hypothesis_null + ) out = get_transform(out, transform=transform) out = get_equivalence(out, equivalence=equivalence) out = sort_columns(out, by=by, newdata=newdata) # unpad if "rowid" in out.columns and pad.shape[0] > 0: - out = out[:-pad.shape[0]:] + out = out[: -pad.shape[0] :] out = MarginaleffectsDataFrame(out, by=by, conf_level=conf_level, newdata=newdata) return out diff --git a/marginaleffects/sanity.py b/marginaleffects/sanity.py index 34d1433..bda1f5e 100644 --- a/marginaleffects/sanity.py +++ b/marginaleffects/sanity.py @@ -14,7 +14,9 @@ def sanitize_vcov(vcov, model): V = get_vcov(model, vcov) if V is not None: - assert isinstance(V, np.ndarray), "get_vcov(model) must return None or a NumPy array" + assert isinstance( + V, np.ndarray + ), "get_vcov(model) must return None or a NumPy array" return V @@ -28,7 +30,9 @@ def sanitize_by(by): elif by is False: by = False else: - raise ValueError("The `by` argument must be True, False, a string, or a list of strings.") + raise ValueError( + "The `by` argument must be True, False, a string, or a list of strings." + ) return by @@ -59,9 +63,7 @@ def sanitize_newdata(model, newdata, wts, by=[]): if len(by) > 0: out = out.sort(by) - out = out.with_columns( - pl.Series(range(out.height), dtype=pl.Int32).alias("rowid") - ) + out = out.with_columns(pl.Series(range(out.height), dtype=pl.Int32).alias("rowid")) if wts is not None: if (isinstance(wts, str) is False) or (wts not in out.columns): @@ -82,6 +84,7 @@ def sanitize_newdata(model, newdata, wts, by=[]): return out + def sanitize_comparison(comparison, by, wts=None): out = comparison if by is not False: @@ -128,7 +131,6 @@ def sanitize_comparison(comparison, by, wts=None): HiLo = namedtuple("HiLo", ["variable", "hi", "lo", "lab", "pad", "comparison"]) - def clean_global(k, n): if ( not isinstance(k, list) @@ -143,7 +145,9 @@ def clean_global(k, n): return out -def get_one_variable_hi_lo(variable, value, newdata, comparison, eps, by, wts=None, modeldata=None): +def get_one_variable_hi_lo( + variable, value, newdata, comparison, eps, by, wts=None, modeldata=None +): msg = "`value` must be a numeric, a list of length two, or 'sd'" vartype = get_variable_type(variable, newdata) @@ -284,7 +288,6 @@ def clean(k): return out - def get_categorical_combinations( variable, uniqs, newdata, comparison, combo="reference" ): @@ -388,7 +391,9 @@ def sanitize_variables(variables, model, newdata, comparison, eps, by, wts=None) vlist.sort() for v in vlist: out.append( - get_one_variable_hi_lo(v, None, newdata, comparison, eps, by, wts, modeldata=modeldata) + get_one_variable_hi_lo( + v, None, newdata, comparison, eps, by, wts, modeldata=modeldata + ) ) elif isinstance(variables, dict): @@ -399,7 +404,14 @@ def sanitize_variables(variables, model, newdata, comparison, eps, by, wts=None) else: out.append( get_one_variable_hi_lo( - v, variables[v], newdata, comparison, eps, by, wts, modeldata=modeldata + v, + variables[v], + newdata, + comparison, + eps, + by, + wts, + modeldata=modeldata, ) ) @@ -407,7 +419,9 @@ def sanitize_variables(variables, model, newdata, comparison, eps, by, wts=None) if variables not in newdata.columns: raise ValueError(f"Variable {variables} is not in newdata.") out.append( - get_one_variable_hi_lo(variables, None, newdata, comparison, eps, by, wts, modeldata=modeldata) + get_one_variable_hi_lo( + variables, None, newdata, comparison, eps, by, wts, modeldata=modeldata + ) ) elif isinstance(variables, list): @@ -416,7 +430,9 @@ def sanitize_variables(variables, model, newdata, comparison, eps, by, wts=None) warn(f"Variable {v} is not in newdata.") else: out.append( - get_one_variable_hi_lo(v, None, newdata, comparison, eps, by, wts, modeldata=modeldata) + get_one_variable_hi_lo( + v, None, newdata, comparison, eps, by, wts, modeldata=modeldata + ) ) # unnest list of list of HiLo @@ -425,10 +441,9 @@ def sanitize_variables(variables, model, newdata, comparison, eps, by, wts=None) return out - def sanitize_hypothesis_null(hypothesis): if isinstance(hypothesis, (int, float)): hypothesis_null = hypothesis else: hypothesis_null = 0 - return hypothesis_null \ No newline at end of file + return hypothesis_null diff --git a/marginaleffects/uncertainty.py b/marginaleffects/uncertainty.py index e071a7d..00485e5 100644 --- a/marginaleffects/uncertainty.py +++ b/marginaleffects/uncertainty.py @@ -41,7 +41,9 @@ def get_z_p_ci(df, model, conf_level, hypothesis_null=0): if "std_error" not in df.columns: return df df = df.with_columns( - ((pl.col("estimate") - float(hypothesis_null)) / pl.col("std_error")).alias("statistic") + ((pl.col("estimate") - float(hypothesis_null)) / pl.col("std_error")).alias( + "statistic" + ) ) if hasattr(model, "df_resid") and isinstance(model.df_resid, float): dof = model.df_resid @@ -69,4 +71,4 @@ def get_z_p_ci(df, model, conf_level, hypothesis_null=0): ) except: pass - return df \ No newline at end of file + return df diff --git a/marginaleffects/utils.py b/marginaleffects/utils.py index 6542286..aadfbc6 100644 --- a/marginaleffects/utils.py +++ b/marginaleffects/utils.py @@ -96,7 +96,6 @@ def upcast(dfs: list) -> list: return tmp - def get_variable_type(variable, newdata): inttypes = [pl.Int32, pl.Int64, pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64] if variable not in newdata.columns: @@ -114,4 +113,3 @@ def get_variable_type(variable, newdata): return "numeric" else: raise ValueError(f"Unknown type for `{variable}`: {newdata[variable].dtype}") - diff --git a/tests/test_by.py b/tests/test_by.py index eba7a1e..ed14ae4 100644 --- a/tests/test_by.py +++ b/tests/test_by.py @@ -1,7 +1,5 @@ import polars as pl -import pytest import statsmodels.formula.api as smf -from pytest import approx from marginaleffects import * diff --git a/tests/test_comparisons.py b/tests/test_comparisons.py index 01fca60..02c5939 100644 --- a/tests/test_comparisons.py +++ b/tests/test_comparisons.py @@ -2,10 +2,8 @@ import numpy as np import polars as pl -import statsmodels.api as sm import statsmodels.formula.api as smf from polars.testing import assert_series_equal -from pytest import approx import marginaleffects from marginaleffects import * diff --git a/tests/test_datagrid.py b/tests/test_datagrid.py index e84a489..38efa3a 100644 --- a/tests/test_datagrid.py +++ b/tests/test_datagrid.py @@ -1,5 +1,4 @@ import polars as pl -import statsmodels.formula.api as smf from marginaleffects import * diff --git a/tests/test_slopes.py b/tests/test_slopes.py index 1fc5585..598666b 100644 --- a/tests/test_slopes.py +++ b/tests/test_slopes.py @@ -1,9 +1,7 @@ import polars as pl import statsmodels.formula.api as smf -from polars.testing import assert_series_equal from marginaleffects import * -from marginaleffects.comparisons import estimands from .utilities import * diff --git a/tests/test_statsmodels_mixedlm.py b/tests/test_statsmodels_mixedlm.py index 58b21e9..b3330ad 100644 --- a/tests/test_statsmodels_mixedlm.py +++ b/tests/test_statsmodels_mixedlm.py @@ -1,6 +1,4 @@ -import numpy as np import polars as pl -import statsmodels.api as sm import statsmodels.formula.api as smf from polars.testing import assert_series_equal diff --git a/tests/test_statsmodels_negativebinomial.py b/tests/test_statsmodels_negativebinomial.py index aaf49cc..3557bc1 100644 --- a/tests/test_statsmodels_negativebinomial.py +++ b/tests/test_statsmodels_negativebinomial.py @@ -1,7 +1,6 @@ import polars as pl import statsmodels.formula.api as smf from pytest import approx -from scipy.stats import pearsonr from marginaleffects import * diff --git a/tests/test_statsmodels_wls.py b/tests/test_statsmodels_wls.py index b406ad0..47c1645 100644 --- a/tests/test_statsmodels_wls.py +++ b/tests/test_statsmodels_wls.py @@ -1,4 +1,3 @@ -import numpy as np import polars as pl import statsmodels.formula.api as smf from polars.testing import assert_series_equal diff --git a/tests/utilities.py b/tests/utilities.py index 4b8f0a9..2988a6d 100644 --- a/tests/utilities.py +++ b/tests/utilities.py @@ -1,7 +1,5 @@ import re -import numpy as np -import polars as pl from marginaleffects import *