diff --git a/src/safeds/data/tabular/containers/_column.py b/src/safeds/data/tabular/containers/_column.py index e776d13c9..04b998294 100644 --- a/src/safeds/data/tabular/containers/_column.py +++ b/src/safeds/data/tabular/containers/_column.py @@ -1,7 +1,7 @@ from __future__ import annotations -import sys import io +import sys from collections.abc import Sequence from numbers import Number from typing import TYPE_CHECKING, Any, TypeVar, overload diff --git a/src/safeds/data/tabular/containers/_time_series.py b/src/safeds/data/tabular/containers/_time_series.py index 491b7d68b..fee39fd94 100644 --- a/src/safeds/data/tabular/containers/_time_series.py +++ b/src/safeds/data/tabular/containers/_time_series.py @@ -1,13 +1,19 @@ from __future__ import annotations +import io import sys from typing import TYPE_CHECKING +import matplotlib.pyplot as plt +import pandas as pd + +from safeds.data.image.containers import Image from safeds.data.tabular.containers import Column, Row, Table, TaggedTable from safeds.exceptions import ( ColumnIsTargetError, ColumnIsTimeError, IllegalSchemaModificationError, + NonNumericColumnError, UnknownColumnNameError, ) @@ -839,6 +845,13 @@ def transform_column(self, name: str, transformer: Callable[[Row], Any]) -> Time The original time series is not modified. + Parameters + ---------- + name: + The name of the column to be transformed. + transformer: + The transformer to the given column + Returns ------- result : TimeSeries @@ -857,3 +870,39 @@ def transform_column(self, name: str, transformer: Callable[[Row], Any]) -> Time ), time_name=self.time.name, ) + + def plot_lagplot(self, lag: int) -> Image: + """ + Plot a lagplot for the target column. + + Parameters + ---------- + lag: + The amount of lag used to plot + + Returns + ------- + plot: + The plot as an image. + + Raises + ------ + NonNumericColumnError + If the time series targets contains non-numerical values. + + Examples + -------- + >>> from safeds.data.tabular.containers import TimeSeries + >>> table = TimeSeries({"time":[1, 2], "target": [3, 4], "feature":[2,2]}, target_name= "target", time_name="time", feature_names=["feature"], ) + >>> image = table.plot_lagplot(lag = 1) + + """ + if not self.target.type.is_numeric(): + raise NonNumericColumnError("This time series target contains non-numerical columns.") + ax = pd.plotting.lag_plot(self.target._data, lag=lag) + fig = ax.figure + buffer = io.BytesIO() + fig.savefig(buffer, format="png") + plt.close() # Prevents the figure from being displayed directly + buffer.seek(0) + return Image.from_bytes(buffer.read()) diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_lag/test_should_return_table.png b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_lag/test_should_return_table.png new file mode 100644 index 000000000..0f17b4726 Binary files /dev/null and b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_lag/test_should_return_table.png differ diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/test_plot_lag.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/test_plot_lag.py new file mode 100644 index 000000000..cb6c94809 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/test_plot_lag.py @@ -0,0 +1,41 @@ +import pytest +from safeds.data.tabular.containers import TimeSeries +from safeds.exceptions import NonNumericColumnError +from syrupy import SnapshotAssertion + + +def test_should_return_table(snapshot_png: SnapshotAssertion) -> None: + table = TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "target": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + }, + target_name="target", + time_name="time", + feature_names=None, + ) + lag_plot = table.plot_lagplot(lag=1) + assert lag_plot == snapshot_png + + +def test_should_raise_if_column_contains_non_numerical_values() -> None: + table = TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "target": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + }, + target_name="target", + time_name="time", + feature_names=None, + ) + with pytest.raises( + NonNumericColumnError, + match=( + r"Tried to do a numerical operation on one or multiple non-numerical columns: \nThis time series target" + r" contains" + r" non-numerical columns." + ), + ): + table.plot_lagplot(2)