From 0a9ce72ba2101f99fea43dcd43b1f498dbb8e558 Mon Sep 17 00:00:00 2001 From: Simon Breuer <86068340+sibre28@users.noreply.github.com> Date: Fri, 21 Apr 2023 20:07:58 +0200 Subject: [PATCH] feat: Added `Table.transform_table` method which returns the transformed Table (#229) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #110. ### Summary of Changes Added the `Table.transform_table` method, which returns a `Table` transformed with the given `TableTransformer`. Co-authored-by: Marsmaennchen221 <47296670+Marsmaennchen221@users.noreply.github.com> --------- Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com> Co-authored-by: Alexander Gréus Co-authored-by: Lars Reimann --- src/safeds/data/tabular/containers/_table.py | 38 +++++- .../_table/test_inverse_transform_table.py | 2 +- .../containers/_table/test_transform_table.py | 120 ++++++++++++++++++ 3 files changed, 157 insertions(+), 3 deletions(-) create mode 100644 tests/safeds/data/tabular/containers/_table/test_transform_table.py diff --git a/src/safeds/data/tabular/containers/_table.py b/src/safeds/data/tabular/containers/_table.py index a37f4c7ae..cf3a4693a 100644 --- a/src/safeds/data/tabular/containers/_table.py +++ b/src/safeds/data/tabular/containers/_table.py @@ -34,7 +34,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Iterable - from safeds.data.tabular.transformation import InvertibleTableTransformer + from safeds.data.tabular.transformation import InvertibleTableTransformer, TableTransformer from ._tagged_table import TaggedTable @@ -993,6 +993,40 @@ def transform_column(self, name: str, transformer: Callable[[Row], Any]) -> Tabl return self.replace_column(name, result) raise UnknownColumnNameError([name]) + def transform_table(self, transformer: TableTransformer) -> Table: + """ + Apply a learned transformation onto this table. 
+ + Parameters + ---------- + transformer : TableTransformer + The transformer which transforms the given table. + + Returns + ------- + transformed_table : Table + The transformed table. + + Raises + ------ + TransformerNotFittedError + If the transformer has not been fitted yet. + + Examples + -------- + >>> from safeds.data.tabular.transformation import OneHotEncoder + >>> from safeds.data.tabular.containers import Table + >>> transformer = OneHotEncoder() + >>> table = Table.from_dict({"col1": [1, 2, 1], "col2": [1, 2, 4]}) + >>> transformer = transformer.fit(table, None) + >>> table.transform_table(transformer) + col1_1 col1_2 col2_1 col2_2 col2_4 + 0 1.0 0.0 1.0 0.0 0.0 + 1 0.0 1.0 0.0 1.0 0.0 + 2 1.0 0.0 0.0 0.0 1.0 + """ + return transformer.transform(self) + def inverse_transform_table(self, transformer: InvertibleTableTransformer) -> Table: """ Invert the transformation applied by the given transformer. @@ -1005,7 +1039,7 @@ def inverse_transform_table(self, transformer: InvertibleTableTransformer) -> Ta Returns ------- table : Table - The original table + The original table. 
Raises ------ diff --git a/tests/safeds/data/tabular/containers/_table/test_inverse_transform_table.py b/tests/safeds/data/tabular/containers/_table/test_inverse_transform_table.py index 84ee05527..8d49a0474 100644 --- a/tests/safeds/data/tabular/containers/_table/test_inverse_transform_table.py +++ b/tests/safeds/data/tabular/containers/_table/test_inverse_transform_table.py @@ -4,7 +4,7 @@ from safeds.data.tabular.transformation import OneHotEncoder -class TestInverseTransformTableOnOneHotEncoder: +class TestInverseTransformTable: @pytest.mark.parametrize( ("table_to_fit", "column_names", "table_to_transform"), [ diff --git a/tests/safeds/data/tabular/containers/_table/test_transform_table.py b/tests/safeds/data/tabular/containers/_table/test_transform_table.py new file mode 100644 index 000000000..59ffe90d3 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/test_transform_table.py @@ -0,0 +1,120 @@ +import pytest +from safeds.data.tabular.containers import Table +from safeds.data.tabular.exceptions import TransformerNotFittedError, UnknownColumnNameError +from safeds.data.tabular.transformation import OneHotEncoder + + +class TestTransform: + @pytest.mark.parametrize( + ("table", "column_names", "expected"), + [ + ( + Table.from_dict( + { + "col1": ["a", "b", "b", "c"], + }, + ), + None, + Table.from_dict( + { + "col1_a": [1.0, 0.0, 0.0, 0.0], + "col1_b": [0.0, 1.0, 1.0, 0.0], + "col1_c": [0.0, 0.0, 0.0, 1.0], + }, + ), + ), + ( + Table.from_dict( + { + "col1": ["a", "b", "b", "c"], + "col2": ["a", "b", "b", "c"], + }, + ), + ["col1"], + Table.from_dict( + { + "col1_a": [1.0, 0.0, 0.0, 0.0], + "col1_b": [0.0, 1.0, 1.0, 0.0], + "col1_c": [0.0, 0.0, 0.0, 1.0], + "col2": ["a", "b", "b", "c"], + }, + ), + ), + ( + Table.from_dict( + { + "col1": ["a", "b", "b", "c"], + "col2": ["a", "b", "b", "c"], + }, + ), + ["col1", "col2"], + Table.from_dict( + { + "col1_a": [1.0, 0.0, 0.0, 0.0], + "col1_b": [0.0, 1.0, 1.0, 0.0], + "col1_c": [0.0, 0.0, 0.0, 1.0], + 
"col2_a": [1.0, 0.0, 0.0, 0.0], + "col2_b": [0.0, 1.0, 1.0, 0.0], + "col2_c": [0.0, 0.0, 0.0, 1.0], + }, + ), + ), + ], + ids=["all columns", "one column", "multiple columns"], + ) + def test_should_return_transformed_table( + self, + table: Table, + column_names: list[str] | None, + expected: Table, + ) -> None: + transformer = OneHotEncoder().fit(table, column_names) + assert table.transform_table(transformer) == expected + + def test_should_not_change_original_table(self) -> None: + table = Table.from_dict( + { + "col1": ["a", "b", "c"], + }, + ) + + transformer = OneHotEncoder().fit(table, None) + table.transform_table(transformer) + + expected = Table.from_dict( + { + "col1": ["a", "b", "c"], + }, + ) + + assert table == expected + + def test_should_raise_if_column_not_found(self) -> None: + table_to_fit = Table.from_dict( + { + "col1": ["a", "b", "c"], + }, + ) + + transformer = OneHotEncoder().fit(table_to_fit, None) + + table_to_transform = Table.from_dict( + { + "col2": ["a", "b", "c"], + }, + ) + + with pytest.raises(UnknownColumnNameError): + table_to_transform.transform_table(transformer) + + def test_should_raise_if_not_fitted(self) -> None: + table = Table.from_dict( + { + "col1": ["a", "b", "c"], + }, + ) + + transformer = OneHotEncoder() + + with pytest.raises(TransformerNotFittedError): + table.transform_table(transformer)