diff --git a/xbitinfo/graphics.py b/xbitinfo/graphics.py index da89aacd..0a95e503 100644 --- a/xbitinfo/graphics.py +++ b/xbitinfo/graphics.py @@ -287,8 +287,17 @@ def plot_bitinformation(bitinfo, cmap="turku", crop=None): subfigure_data[d]["nbits"] = (n_sign, n_exp, n_bits, n_mant, nonmantissa_bits) subfigure_data[d]["bits_to_show"] = bits_to_show - total_fig_height = np.sum([d["fig_height"] for d in subfigure_data]) - fig, axs = plt.subplots(len(subfigure_data), 1, figsize=(12, total_fig_height)) + fig_heights = [subfig["fig_height"] for subfig in subfigure_data] + fig = plt.figure(figsize=(12, sum(fig_heights) + 2)) + fig_heights_incl_cax = fig_heights + [2 / (sum(fig_heights) + 1)] + grid = fig.add_gridspec( + len(subfigure_data) + 1, 1, height_ratios=fig_heights_incl_cax + ) + + axs = [] + for i in range(len(subfigure_data) + 1): + ax = fig.add_subplot(grid[i, 0]) + axs.append(ax) if isinstance(axs, plt.Axes): axs = [axs] @@ -332,10 +341,7 @@ def plot_bitinformation(bitinfo, cmap="turku", crop=None): pcm = axs[d].pcolormesh(ICnan, vmin=0, vmax=1, cmap=cmap) if d == len(subfigure_data) - 1: - pos = axs[d].get_position() - cax = fig.add_axes([pos.x0, 0.12, pos.x1 - pos.x0, 0.05]) - lax = fig.add_axes([pos.x0, 0.07, pos.x1 - pos.x0, 0.07]) - lax.axis("off") + cax = axs[len(subfigure_data)] cbar = plt.colorbar(pcm, cax=cax, orientation="horizontal") cbar.set_label("information content [bit]") @@ -475,16 +481,14 @@ def plot_bitinformation(bitinfo, cmap="turku", crop=None): axs[d].text(i + 0.5, nvars + 0.5, m + 1, ha="center", fontsize=7) if d == len(subfigure_data) - 1: - lax.legend( - bbox_to_anchor=(0.5, 0), - loc="center", + fig.legend( + loc="lower center", framealpha=0.6, ncol=3, handles=[l1, l2, l0[0]], ) axs[d].set_xlim(0, bits_to_show) - plt.tight_layout() fig.show() return fig