diff --git a/marginaleffects/plot_common.py b/marginaleffects/plot_common.py index 09d1a84..ff5e2a4 100644 --- a/marginaleffects/plot_common.py +++ b/marginaleffects/plot_common.py @@ -11,15 +11,12 @@ def dt_on_condition(model, condition): model = sanitize_model(model) + # not sure why newdata gets added modeldata = model.modeldata if isinstance(condition, str): condition = [condition] - assert ( - 1 <= len(condition) <= 3 - ), f"Lenght of condition must be inclusively between 1 and 3. Got : {len(condition)}." - to_datagrid = {} first_key = "" # special case when the first element is numeric @@ -37,6 +34,12 @@ def dt_on_condition(model, condition): first_key = next(iter(condition)) to_datagrid = condition + # not sure why `newdata` sometimes gets added + condition.pop("newdata", None) + assert ( + 1 <= len(condition) <= 3 + ), f"Lenght of condition must be inclusively between 1 and 3. Got : {len(condition)}." + for key, value in to_datagrid.items(): variable_type = get_variable_type(key, modeldata) @@ -209,14 +212,13 @@ def plot_common(dt, y_label, var_list): ) else: title = dim_max_j - title += ( - "\n" - + subplot_dt.select(pl.first("term")).item() - + ", " - + dim_min_i - if dim_min_i is not None - else "" - ) + if dim_min_i is not None: + title += ( + "\n" + + subplot_dt.select(pl.first("term")).item() + + ", " + + dim_min_i + ) fig.axes[axe].set_title(title, fontsize=titles_fontsize) diff --git a/pyproject.toml b/pyproject.toml index c81238e..c0a938f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ polars = ">0.18.3" pyarrow = "^14.0.1" scipy = "^1.10.0" matplotlib = "^3.7.2" -pyfixest = ">= 0.11.1" +pyfixest = ">= 0.11.0" [tool.poetry.group.dev.dependencies] pytest = "^7.4.0" diff --git a/tests/test_plot_predictions.py b/tests/test_plot_predictions.py index 60029fc..6772891 100644 --- a/tests/test_plot_predictions.py +++ b/tests/test_plot_predictions.py @@ -63,3 +63,18 @@ def test_issue_57(): fig = plot_predictions(mod, condition={"am": None, "wt": None}) assert assert_image(fig, "issue_57_04", "plot_predictions") is None + + +def issue_62(): + import types + mtcars = pl.read_csv("https://vincentarelbundock.github.io/Rdatasets/csv/datasets/mtcars.csv") + mod = smf.ols("mpg ~ hp * wt * am", data = mtcars).fit() + cond = { + "hp": None, + "wt": [mtcars["wt"].mean() - mtcars["wt"].std(), + mtcars["wt"].mean(), + mtcars["wt"].mean() + mtcars["wt"].std()], + "am": None + } + p = plot_predictions(mod, condition = cond) + assert isinstance(p, types.ModuleType) \ No newline at end of file