Skip to content

Commit

Permalink
fix passing axes to plot_density with several datasets (#1198)
Browse files Browse the repository at this point in the history
* fix passing axes to plot_density with several datasets

* update changelog
  • Loading branch information
aloctavodia authored May 21, 2020
1 parent b062cbc commit 831a43c
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
* Set `fill_last` argument of `plot_kde` to False by default (#1158)
* plot_ppc animation: improve docs and error handling (#1162)
* Fix import error when wrapped function docstring is empty (#1192)
* Fix passing axes to plot_density with several datasets ([#1198](https://github.com/arviz-devs/arviz/pull/1198))

### Deprecation
* `hpd` function deprecated in favor of `hdi`. `credible_interval` argument replaced by `hdi_prob`throughout with exception of `plot_loo_pit` (#1176)
Expand Down
4 changes: 2 additions & 2 deletions arviz/plots/backends/matplotlib/densityplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def plot_density(

if n_data > 1:
for m_idx, label in enumerate(data_labels):
ax[0].plot([], label=label, c=colors[m_idx], markersize=markersize)
ax[0].legend(fontsize=xt_labelsize)
ax.item(0).plot([], label=label, c=colors[m_idx], markersize=markersize)
ax.item(0).legend(fontsize=xt_labelsize)

if backend_show(show):
plt.show()
Expand Down
6 changes: 5 additions & 1 deletion arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,16 @@ def fig_ax():
{"hdi_markers": ["v"]},
{"shade": 1},
{"transform": lambda x: x + 1},
{"ax": plt.subplots(6, 3)[1]},
],
)
def test_plot_density_float(models, kwargs):
obj = [getattr(models, model_fit) for model_fit in ["model_1", "model_2"]]
axes = plot_density(obj, **kwargs)
assert axes.shape[0] >= 18
if "ax" in kwargs:
assert axes.shape == (6, 3)
else:
assert axes.shape[0] >= 18


def test_plot_density_discrete(discrete_model):
Expand Down

0 comments on commit 831a43c

Please sign in to comment.