Skip to content

Commit

Permalink
Fix ess/rhat plot in plot_forest (arviz-devs#1606)
Browse files Browse the repository at this point in the history
* Fix ess/rhat plot in plot_forest

* Aligned ess/rhat dots in forestplot bokeh

* Minor changes

Co-authored-by: Oriol Abril-Pla <[email protected]>
  • Loading branch information
2 people authored and utkarsh-maheshwari committed May 27, 2021
1 parent 2569b59 commit e76d041
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 28 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* Fix `c` argument in `plot_khat` ([1592](https://github.com/arviz-devs/arviz/pull/1592))
* Fix `ax` argument in `plot_elpd` ([1593](https://github.com/arviz-devs/arviz/pull/1593))
* Remove warning in `stats.py` compare function ([1607](https://github.com/arviz-devs/arviz/pull/1607))
* Fix `ess/rhat` plots in `plot_forest` ([1606](https://github.com/arviz-devs/arviz/pull/1606))
* Fix `from_numpyro` crash when importing model with `thinning=x` for `x > 1` ([1619](https://github.com/arviz-devs/arviz/pull/1619))

### Deprecation
Expand Down
53 changes: 30 additions & 23 deletions arviz/plots/backends/bokeh/forestplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,15 @@ def plot_forest(

for i, width_r in zip(range(ncols), width_ratios):
backend_kwargs_i = backend_kwargs.copy()
backend_kwargs_i.setdefault("width", int(figsize[0] * dpi))
backend_kwargs_i.setdefault("height", int(figsize[1] * dpi))
backend_kwargs_i.setdefault(
"height", int(figsize[1] * (width_r / sum(width_ratios)) * dpi * 1.25)
"width", int(figsize[0] * (width_r / sum(width_ratios)) * dpi * 1.25)
)
if i == 0:
ax = bkp.figure(
**backend_kwargs_i,
)
backend_kwargs_i.setdefault("y_range", ax.y_range)
backend_kwargs.setdefault("y_range", ax.y_range)
else:
ax = bkp.figure(**backend_kwargs_i)
axes.append(ax)
Expand Down Expand Up @@ -172,6 +172,11 @@ def plot_forest(
plot_handler.legend(axes[0, idx], plotted_r_hat)
idx += 1

all_plotters = list(plot_handler.plotters.values())
y_max = plot_handler.y_max() - all_plotters[-1].group_offset
if kind == "ridgeplot": # space at the top
y_max += ridgeplot_overlap

for i, ax_ in enumerate(axes.ravel()):
if kind == "ridgeplot":
ax_.xgrid.grid_line_color = None
Expand All @@ -186,24 +191,17 @@ def plot_forest(
ax_.x_range = DataRange1d(bounds=backend_config["bounds_x_range"], min_interval=1)
ax_.y_range = DataRange1d(bounds=backend_config["bounds_y_range"], min_interval=2)

ax_.y_range._property_values["start"] = -all_plotters[ # pylint: disable=protected-access
0
].group_offset
ax_.y_range._property_values["end"] = y_max # pylint: disable=protected-access

labels, ticks = plot_handler.labels_and_ticks()
ticks = [int(tick) if (tick).is_integer() else tick for tick in ticks]

axes[0, 0].yaxis.ticker = FixedTicker(ticks=ticks)
axes[0, 0].yaxis.major_label_overrides = dict(zip(map(str, ticks), map(str, labels)))

all_plotters = list(plot_handler.plotters.values())
y_max = plot_handler.y_max() - all_plotters[-1].group_offset
if kind == "ridgeplot": # space at the top
y_max += ridgeplot_overlap

axes[0, 0].y_range._property_values[
"start"
] = -all_plotters[ # pylint: disable=protected-access
0
].group_offset
axes[0, 0].y_range._property_values["end"] = y_max # pylint: disable=protected-access

if legend:
plot_handler.legend(axes[0, 0], plotted)
show_layout(axes, show)
Expand Down Expand Up @@ -283,6 +281,14 @@ def label_idxs():
labels, idxs = [], []
for plotter in val:
sub_labels, sub_idxs, _, _, _ = plotter.labels_ticks_and_vals()
labels_to_idxs = defaultdict(list)
for label, idx in zip(sub_labels, sub_idxs):
labels_to_idxs[label].append(idx)
sub_idxs = []
sub_labels = []
for label, all_idx in labels_to_idxs.items():
sub_labels.append(label)
sub_idxs.append(np.mean([j for j in all_idx]))
labels.append(sub_labels)
idxs.append(sub_idxs)
return np.concatenate(labels), np.concatenate(idxs)
Expand All @@ -295,8 +301,8 @@ def legend(self, ax, plotted):
for (model_name, glyphs) in plotted.items():
legend_it.append((model_name, glyphs))

legend = Legend(items=legend_it)
ax.add_layout(legend, "right")
legend = Legend(items=legend_it, orientation="vertical", location="top_left")
ax.add_layout(legend, "above")
ax.legend.click_policy = "hide"

def display_multiple_ropes(
Expand Down Expand Up @@ -675,12 +681,13 @@ def labels_ticks_and_vals(self):
for y, label, model_name, _, _, vals, color in self.iterator():
y_ticks[label].append((y, vals, color, model_name))
labels, ticks, vals, colors, model_names = [], [], [], [], []
for label, data in y_ticks.items():
labels.append(label)
ticks.append(np.mean([j[0] for j in data]))
vals.append(np.vstack([j[1] for j in data]))
model_names.append(data[0][3])
colors.append(data[0][2]) # the colors are all the same
for label, all_data in y_ticks.items():
for data in all_data:
labels.append(label)
ticks.append(data[0])
vals.append(np.array(data[1]))
model_names.append(data[3])
colors.append(data[2]) # the colors are all the same
return labels, ticks, vals, colors, model_names

def treeplot(self, qlist, hdi_prob):
Expand Down
19 changes: 14 additions & 5 deletions arviz/plots/backends/matplotlib/forestplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,14 @@ def label_idxs():
labels, idxs = [], []
for plotter in val:
sub_labels, sub_idxs, _, _ = plotter.labels_ticks_and_vals()
labels_to_idxs = defaultdict(list)
for label, idx in zip(sub_labels, sub_idxs):
labels_to_idxs[label].append(idx)
sub_idxs = []
sub_labels = []
for label, all_idx in labels_to_idxs.items():
sub_labels.append(label)
sub_idxs.append(np.mean([j for j in all_idx]))
labels.append(sub_labels)
idxs.append(sub_idxs)
return np.concatenate(labels), np.concatenate(idxs)
Expand Down Expand Up @@ -567,11 +575,12 @@ def labels_ticks_and_vals(self):
for y, label, _, _, vals, color in self.iterator():
y_ticks[label].append((y, vals, color))
labels, ticks, vals, colors = [], [], [], []
for label, data in y_ticks.items():
labels.append(label)
ticks.append(np.mean([j[0] for j in data]))
vals.append(np.vstack([j[1] for j in data]))
colors.append(data[0][2]) # the colors are all the same
for label, all_data in y_ticks.items():
for data in all_data:
labels.append(label)
ticks.append(data[0])
vals.append(np.array(data[1]))
colors.append(data[2]) # the colors are all the same
return labels, ticks, vals, colors

def treeplot(self, qlist, hdi_prob):
Expand Down

0 comments on commit e76d041

Please sign in to comment.