diff --git a/CHANGELOG.md b/CHANGELOG.md index 0531275980..f37b93b0df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ ### New features ### Maintenance and fixes - +* Fixed ovelapping titles and repeating warnings on circular traceplot ([1517](https://github.com/arviz-devs/arviz/pull/1517)) ### Deprecation ### Documentation diff --git a/arviz/plots/backends/matplotlib/kdeplot.py b/arviz/plots/backends/matplotlib/kdeplot.py index 46991417b9..baf54fb1a2 100644 --- a/arviz/plots/backends/matplotlib/kdeplot.py +++ b/arviz/plots/backends/matplotlib/kdeplot.py @@ -2,6 +2,7 @@ import numpy as np from matplotlib import pyplot as plt from matplotlib import _pylab_helpers +import matplotlib.ticker as mticker from ...plot_utils import _scale_fig_size @@ -101,6 +102,8 @@ def plot_kde( f"{-np.pi/4:.2f}", ] + ticks_loc = ax.get_xticks() + ax.xaxis.set_major_locator(mticker.FixedLocator(ticks_loc)) ax.set_xticklabels(labels) x = np.linspace(-np.pi, np.pi, len(density)) diff --git a/arviz/plots/backends/matplotlib/traceplot.py b/arviz/plots/backends/matplotlib/traceplot.py index e491899f03..3f25209f43 100644 --- a/arviz/plots/backends/matplotlib/traceplot.py +++ b/arviz/plots/backends/matplotlib/traceplot.py @@ -6,6 +6,7 @@ import matplotlib.pyplot as plt import numpy as np from matplotlib.lines import Line2D +import matplotlib.ticker as mticker from ....stats.density_utils import get_bins from ...distplot import plot_dist @@ -318,9 +319,14 @@ def plot_trace( if value[0].dtype.kind == "i" and idy == 0: xticks = get_bins(value) ax.set_xticks(xticks[:-1]) + y = 1 / textsize if not idy: ax.set_yticks([]) - ax.set_title(make_label(var_name, selection), fontsize=titlesize, wrap=True, y=1) + if circular: + y = 0.13 if selection else 0.12 + ax.set_title( + make_label(var_name, selection), fontsize=titlesize, wrap=True, y=textsize * y + ) ax.tick_params(labelsize=xt_labelsize) xlims = ax.get_xlim() @@ -471,6 +477,7 @@ def _plot_chains_mpl( if circ_units_trace == "degrees": y_tick_locs = axes.get_yticks() y_tick_labels = [i + 2 * 180 if i < 0 else i for i in np.rad2deg(y_tick_locs)] + axes.yaxis.set_major_locator(mticker.FixedLocator(y_tick_locs)) axes.set_yticklabels([f"{i:.0f}°" for i in y_tick_labels]) if not combined: