Skip to content

Commit

Permalink
Fixed Posterior plot errors with boolean array. (#1707)
Browse files Browse the repository at this point in the history
* Fixed Posterior plot errors with boolean array for matplotlib

* remove plt.draw()

* Update test

* Add a bad type test

* Fix whitespace

* Update bokeh code

* Update test_plots_bokeh.py

* Update CHANGELOG.md

* Update CHANGELOG.md

* Fix lint

Co-authored-by: Ari Hartikainen <[email protected]>
  • Loading branch information
utkarsh-maheshwari and ahartikainen authored Mar 23, 2022
1 parent 479f2a7 commit aeb25cb
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
* Fix R2 implementation ([1666](https://github.com/arviz-devs/arviz/pull/1666))
* Added warning message in `plot_dist_comparison()` in case subplots go over the limit ([1688](https://github.com/arviz-devs/arviz/pull/1688))
* Fix coord value ignoring for default dims ([2001](https://github.com/arviz-devs/arviz/pull/2001))
* Fixed plot_posterior with boolean data ([1707](https://github.com/arviz-devs/arviz/pull/1707))

### Deprecation

Expand Down
22 changes: 17 additions & 5 deletions arviz/plots/backends/bokeh/posteriorplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,18 +286,30 @@ def format_axes():
show=False,
)
_, hist, edges = histogram(values, bins="auto")
else:
elif values.dtype.kind == "i" or (values.dtype.kind == "f" and kind == "hist"):
if bins is None:
if values.dtype.kind == "i":
bins = get_bins(values)
else:
bins = "auto"
bins = get_bins(values)
kwargs.setdefault("align", "left")
kwargs.setdefault("color", "blue")
_, hist, edges = histogram(values, bins=bins)
ax.quad(
top=hist, bottom=0, left=edges[:-1], right=edges[1:], fill_alpha=0.35, line_alpha=0.35
)
elif values.dtype.kind == "b":
if bins is None:
bins = "auto"
kwargs.setdefault("color", "blue")

hist = np.array([(~values).sum(), values.sum()])
edges = np.array([-0.5, 0.5, 1.5])
ax.quad(
top=hist, bottom=0, left=edges[:-1], right=edges[1:], fill_alpha=0.35, line_alpha=0.35
)
hdi_prob = "hide"
ax.xaxis.ticker = [0, 1]
ax.xaxis.major_label_overrides = {0: "False", 1: "True"}
else:
raise TypeError("Values must be float, integer or boolean")

format_axes()
max_data = hist.max()
Expand Down
21 changes: 13 additions & 8 deletions arviz/plots/backends/matplotlib/posteriorplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,18 +319,23 @@ def format_axes():
rug=False,
show=False,
)
else:
elif values.dtype.kind == "i" or (values.dtype.kind == "f" and kind == "hist"):
if bins is None:
if values.dtype.kind == "i":
xmin = values.min()
xmax = values.max()
bins = get_bins(values)
ax.set_xlim(xmin - 0.5, xmax + 0.5)
else:
bins = "auto"
xmin = values.min()
xmax = values.max()
bins = get_bins(values)
ax.set_xlim(xmin - 0.5, xmax + 0.5)
kwargs.setdefault("align", "left")
kwargs.setdefault("color", "C0")
ax.hist(values, bins=bins, alpha=0.35, **kwargs)
elif values.dtype.kind == "b":
if bins is None:
bins = "auto"
kwargs.setdefault("color", "C0")
ax.bar(["False", "True"], [(~values).sum(), values.sum()], alpha=0.35, **kwargs)
hdi_prob = "hide"
else:
raise TypeError("Values must be float, integer or boolean")

plot_height = ax.get_ylim()[1]

Expand Down
11 changes: 11 additions & 0 deletions arviz/tests/base_tests/test_plots_bokeh.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,17 @@ def test_plot_posterior_discrete(discrete_model, kwargs):
assert axes.shape


def test_plot_posterior_boolean():
data = np.random.choice(a=[False, True], size=(4, 100))
axes = plot_posterior(data, backend="bokeh", show=False)
assert axes.shape


def test_plot_posterior_bad_type():
with pytest.raises(TypeError):
plot_posterior(np.array(["a", "b", "c"]), backend="bokeh", show=False)


def test_plot_posterior_bad(models):
with pytest.raises(ValueError):
plot_posterior(models.model_1, backend="bokeh", show=False, rope="bad_value")
Expand Down
14 changes: 14 additions & 0 deletions arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,12 +1025,26 @@ def test_plot_posterior(models, kwargs):
assert axes.shape


def test_plot_posterior_boolean():
data = np.random.choice(a=[False, True], size=(4, 100))
axes = plot_posterior(data)
assert axes
plt.draw()
labels = [label.get_text() for label in axes.get_xticklabels()]
assert all(item in labels for item in ("True", "False"))


@pytest.mark.parametrize("kwargs", [{}, {"point_estimate": "mode"}, {"bins": None, "kind": "hist"}])
def test_plot_posterior_discrete(discrete_model, kwargs):
axes = plot_posterior(discrete_model, **kwargs)
assert axes.shape


def test_plot_posterior_bad_type():
with pytest.raises(TypeError):
plot_posterior(np.array(["a", "b", "c"]))


def test_plot_posterior_bad(models):
with pytest.raises(ValueError):
plot_posterior(models.model_1, rope="bad_value")
Expand Down

0 comments on commit aeb25cb

Please sign in to comment.