diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c9c84382..abedf2abd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,7 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - - - -- +- Add xticks parameter for plot_periodogram, clip frequencies to be >= 1 ([#706](https://github.com/tinkoff-ai/etna/pull/706)) - - - diff --git a/etna/analysis/plotters.py b/etna/analysis/plotters.py index d8be95fac..797e03339 100644 --- a/etna/analysis/plotters.py +++ b/etna/analysis/plotters.py @@ -1190,6 +1190,7 @@ def plot_periodogram( amplitude_aggregation_mode: Union[str, Literal["per-segment"]] = AggregationMode.mean, periodogram_params: Optional[Dict[str, Any]] = None, segments: Optional[List[str]] = None, + xticks: Optional[List[Any]] = None, columns_num: int = 2, figsize: Tuple[int, int] = (10, 5), ): @@ -1213,6 +1214,8 @@ def plot_periodogram( additional keyword arguments for periodogram, :py:func:`scipy.signal.periodogram` is used segments: segments to use + xticks: + list of tick locations of the x-axis, useful to highlight specific reference periodicities columns_num: if ``amplitude_aggregation_mode="per-segment"`` number of columns in subplots, otherwise the value is ignored figsize: @@ -1247,10 +1250,14 @@ def plot_periodogram( if segment_df.isna().any(): raise ValueError(f"Periodogram can't be calculated on segment with NaNs inside: {segment}") frequencies, spectrum = periodogram(x=segment_df, fs=period, **periodogram_params) + spectrum = spectrum[frequencies >= 1] + frequencies = frequencies[frequencies >= 1] ax[i].step(frequencies, spectrum) ax[i].set_xscale("log") ax[i].set_xlabel("Frequency") ax[i].set_ylabel("Power spectral density") + if xticks is not None: + ax[i].set_xticks(ticks=xticks, labels=xticks) ax[i].set_title(f"Periodogram: {segment}") else: # find length of each segment @@ -1276,11 +1283,15 @@ def plot_periodogram( frequencies = frequencies_segments[0] amplitude_aggregation_fn = AGGREGATION_FN[AggregationMode(amplitude_aggregation_mode)] spectrum = amplitude_aggregation_fn(spectrums_segments, axis=0) # type: ignore + spectrum = spectrum[frequencies >= 1] + frequencies = frequencies[frequencies >= 1] _, ax = plt.subplots(figsize=figsize, constrained_layout=True) ax.step(frequencies, spectrum) # type: ignore ax.set_xscale("log") # type: ignore ax.set_xlabel("Frequency") # type: ignore ax.set_ylabel("Power spectral density") # type: ignore + if xticks is not None: + ax.set_xticks(ticks=xticks, labels=xticks) # type: ignore ax.set_title("Periodogram") # type: ignore