diff --git a/src/safeds/data/tabular/plotting/_table_plotter.py b/src/safeds/data/tabular/plotting/_table_plotter.py index a58b3c8f2..6be3bc095 100644 --- a/src/safeds/data/tabular/plotting/_table_plotter.py +++ b/src/safeds/data/tabular/plotting/_table_plotter.py @@ -5,6 +5,7 @@ from safeds._utils import _figure_to_image from safeds._validation import _check_columns_exist +from safeds._validation._check_columns_are_numeric import _check_columns_are_numeric from safeds.exceptions import NonNumericColumnError if TYPE_CHECKING: @@ -322,12 +323,7 @@ def scatter_plot(self, x_name: str, y_name: str) -> Image: >>> image = table.plot.scatter_plot("a", "b") """ _check_columns_exist(self._table, [x_name, y_name]) - - # TODO: pass list of columns names + extract validation - if not self._table.get_column(x_name).is_numeric: - raise NonNumericColumnError(x_name) - if not self._table.get_column(y_name).is_numeric: - raise NonNumericColumnError(y_name) + _check_columns_are_numeric(self._table, [x_name, y_name]) import matplotlib.pyplot as plt @@ -335,6 +331,9 @@ def scatter_plot(self, x_name: str, y_name: str) -> Image: ax.scatter( x=self._table.get_column(x_name)._series, y=self._table.get_column(y_name)._series, + s=64, # marker size + linewidth=1, + edgecolor="white", ) ax.set( xlabel=x_name, diff --git a/tests/safeds/data/tabular/containers/_table/__snapshots__/test_plot_scatterplot/test_should_match_snapshot.png b/tests/safeds/data/tabular/containers/_table/__snapshots__/test_plot_scatterplot/test_should_match_snapshot.png deleted file mode 100644 index c2dcdd67d..000000000 Binary files a/tests/safeds/data/tabular/containers/_table/__snapshots__/test_plot_scatterplot/test_should_match_snapshot.png and /dev/null differ diff --git a/tests/safeds/data/tabular/containers/_table/__snapshots__/test_plot_scatterplot/test_should_match_snapshot[functional].png b/tests/safeds/data/tabular/containers/_table/__snapshots__/test_plot_scatterplot/test_should_match_snapshot[functional].png new file mode 100644 index 000000000..29e6d049e Binary files /dev/null and b/tests/safeds/data/tabular/containers/_table/__snapshots__/test_plot_scatterplot/test_should_match_snapshot[functional].png differ diff --git a/tests/safeds/data/tabular/containers/_table/__snapshots__/test_plot_scatterplot/test_should_match_snapshot[overlapping].png b/tests/safeds/data/tabular/containers/_table/__snapshots__/test_plot_scatterplot/test_should_match_snapshot[overlapping].png new file mode 100644 index 000000000..251a375e3 Binary files /dev/null and b/tests/safeds/data/tabular/containers/_table/__snapshots__/test_plot_scatterplot/test_should_match_snapshot[overlapping].png differ diff --git a/tests/safeds/data/tabular/containers/_table/test_plot_scatterplot.py b/tests/safeds/data/tabular/containers/_table/test_plot_scatterplot.py index 457f0b368..c0b46fa79 100644 --- a/tests/safeds/data/tabular/containers/_table/test_plot_scatterplot.py +++ b/tests/safeds/data/tabular/containers/_table/test_plot_scatterplot.py @@ -1,12 +1,31 @@ import pytest from safeds.data.tabular.containers import Table -from safeds.exceptions import ColumnNotFoundError +from safeds.exceptions import ColumnNotFoundError, ColumnTypeError from syrupy import SnapshotAssertion -def test_should_match_snapshot(snapshot_png_image: SnapshotAssertion) -> None: - table = Table({"A": [1, 2, 3], "B": [2, 4, 7]}) - scatterplot = table.plot.scatter_plot("A", "B") +@pytest.mark.parametrize( + ("table", "x_name", "y_name"), + [ + (Table({"A": [1, 2, 3], "B": [2, 4, 7]}), "A", "B"), + ( + Table( + { + "A": [1, 0.99, 0.99, 2], + "B": [1, 0.99, 1.01, 2], + }, + ), + "A", + "B", + ), + ], + ids=[ + "functional", + "overlapping", + ], +) +def test_should_match_snapshot(table: Table, x_name: str, y_name: str, snapshot_png_image: SnapshotAssertion) -> None: + scatterplot = table.plot.scatter_plot(x_name, y_name) assert scatterplot == snapshot_png_image @@ -23,3 +42,15 @@ def test_should_match_snapshot(snapshot_png_image: SnapshotAssertion) -> None: def test_should_raise_if_column_does_not_exist(table: Table, col1: str, col2: str) -> None: with pytest.raises(ColumnNotFoundError): table.plot.scatter_plot(col1, col2) + + +@pytest.mark.parametrize( + ("table", "x_name", "y_name"), + [ + (Table({"A": ["a", "b", "c"], "B": [2, 4, 7]}), "A", "B"), + (Table({"A": [1, 2, 3], "B": ["a", "b", "c"]}), "A", "B"), + ], +) +def test_should_raise_if_columns_are_not_numeric(table: Table, x_name: str, y_name: str) -> None: + with pytest.raises(ColumnTypeError): + table.plot.scatter_plot(x_name, y_name)