Skip to content

Commit

Permalink
Issue Get started bug #62
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentarelbundock committed Dec 26, 2023
1 parent 969fcda commit 2f505ea
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 13 deletions.
26 changes: 14 additions & 12 deletions marginaleffects/plot_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,12 @@
def dt_on_condition(model, condition):
model = sanitize_model(model)

# not sure why newdata gets added
modeldata = model.modeldata

if isinstance(condition, str):
condition = [condition]

assert (
1 <= len(condition) <= 3
), f"Lenght of condition must be inclusively between 1 and 3. Got : {len(condition)}."

to_datagrid = {}
first_key = "" # special case when the first element is numeric

Expand All @@ -37,6 +34,12 @@ def dt_on_condition(model, condition):
first_key = next(iter(condition))
to_datagrid = condition

# not sure why `newdata` sometimes gets added
condition.pop("newdata", None)
assert (
1 <= len(condition) <= 3
), f"Lenght of condition must be inclusively between 1 and 3. Got : {len(condition)}."

for key, value in to_datagrid.items():
variable_type = get_variable_type(key, modeldata)

Expand Down Expand Up @@ -209,14 +212,13 @@ def plot_common(dt, y_label, var_list):
)
else:
title = dim_max_j
title += (
"\n"
+ subplot_dt.select(pl.first("term")).item()
+ ", "
+ dim_min_i
if dim_min_i is not None
else ""
)
if dim_min_i is not None:
title += (
"\n"
+ subplot_dt.select(pl.first("term")).item()
+ ", "
+ dim_min_i
)

fig.axes[axe].set_title(title, fontsize=titles_fontsize)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ polars = ">0.18.3"
pyarrow = "^14.0.1"
scipy = "^1.10.0"
matplotlib = "^3.7.2"
pyfixest = ">= 0.11.1"
pyfixest = ">= 0.11.0"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
Expand Down
15 changes: 15 additions & 0 deletions tests/test_plot_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,18 @@ def test_issue_57():
fig = plot_predictions(mod, condition={"am": None, "wt": None})
assert assert_image(fig, "issue_57_04", "plot_predictions") is None



def issue_62():
import types
mtcars = pl.read_csv("https://vincentarelbundock.github.io/Rdatasets/csv/datasets/mtcars.csv")
mod = smf.ols("mpg ~ hp * wt * am", data = mtcars).fit()
cond = {
"hp": None,
"wt": [mtcars["wt"].mean() - mtcars["wt"].std(),
mtcars["wt"].mean(),
mtcars["wt"].mean() + mtcars["wt"].std()],
"am": None
}
p = plot_predictions(mod, condition = cond)
assert isinstance(p, types.ModuleType)

0 comments on commit 2f505ea

Please sign in to comment.