Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/tinkoff-ai/etna into issu…
Browse files Browse the repository at this point in the history
…e-675
  • Loading branch information
Artem Makhin committed May 25, 2022
2 parents fb6d844 + 91edf5b commit b1ca2d3
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
-
-
- Add xticks parameter for plot_periodogram, clip frequencies to be >= 1 ([#706](https://github.com/tinkoff-ai/etna/pull/706))
-
-
-
Expand Down
11 changes: 11 additions & 0 deletions etna/analysis/plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,6 +1190,7 @@ def plot_periodogram(
amplitude_aggregation_mode: Union[str, Literal["per-segment"]] = AggregationMode.mean,
periodogram_params: Optional[Dict[str, Any]] = None,
segments: Optional[List[str]] = None,
xticks: Optional[List[Any]] = None,
columns_num: int = 2,
figsize: Tuple[int, int] = (10, 5),
):
Expand All @@ -1213,6 +1214,8 @@ def plot_periodogram(
additional keyword arguments for periodogram, :py:func:`scipy.signal.periodogram` is used
segments:
segments to use
xticks:
list of tick locations of the x-axis, useful to highlight specific reference periodicities
columns_num:
if ``amplitude_aggregation_mode="per-segment"`` number of columns in subplots, otherwise the value is ignored
figsize:
Expand Down Expand Up @@ -1247,10 +1250,14 @@ def plot_periodogram(
if segment_df.isna().any():
raise ValueError(f"Periodogram can't be calculated on segment with NaNs inside: {segment}")
frequencies, spectrum = periodogram(x=segment_df, fs=period, **periodogram_params)
spectrum = spectrum[frequencies >= 1]
frequencies = frequencies[frequencies >= 1]
ax[i].step(frequencies, spectrum)
ax[i].set_xscale("log")
ax[i].set_xlabel("Frequency")
ax[i].set_ylabel("Power spectral density")
if xticks is not None:
ax[i].set_xticks(ticks=xticks, labels=xticks)
ax[i].set_title(f"Periodogram: {segment}")
else:
# find length of each segment
Expand All @@ -1276,11 +1283,15 @@ def plot_periodogram(
frequencies = frequencies_segments[0]
amplitude_aggregation_fn = AGGREGATION_FN[AggregationMode(amplitude_aggregation_mode)]
spectrum = amplitude_aggregation_fn(spectrums_segments, axis=0) # type: ignore
spectrum = spectrum[frequencies >= 1]
frequencies = frequencies[frequencies >= 1]
_, ax = plt.subplots(figsize=figsize, constrained_layout=True)
ax.step(frequencies, spectrum) # type: ignore
ax.set_xscale("log") # type: ignore
ax.set_xlabel("Frequency") # type: ignore
ax.set_ylabel("Power spectral density") # type: ignore
if xticks is not None:
ax.set_xticks(ticks=xticks, labels=xticks) # type: ignore
ax.set_title("Periodogram") # type: ignore


Expand Down

0 comments on commit b1ca2d3

Please sign in to comment.