Skip to content

Commit

Permalink
feat: Added method Table.inverse_transform_table which returns the …
Browse files Browse the repository at this point in the history
…original table (#227)

Closes #111.

### Summary of Changes

Added method `Table.inverse_transform_table`, which takes a fitted
transformer and returns the original table.

---------

Co-authored-by: megalinter-bot <[email protected]>
Co-authored-by: Lars Reimann <[email protected]>
Co-authored-by: sibre28 <[email protected]>
  • Loading branch information
4 people authored Apr 21, 2023
1 parent b3893cc commit 846bf23
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 0 deletions.
42 changes: 42 additions & 0 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
if TYPE_CHECKING:
from collections.abc import Callable, Iterable

from safeds.data.tabular.transformation import InvertibleTableTransformer

from ._tagged_table import TaggedTable


Expand Down Expand Up @@ -991,6 +993,46 @@ def transform_column(self, name: str, transformer: Callable[[Row], Any]) -> Tabl
return self.replace_column(name, result)
raise UnknownColumnNameError([name])

def inverse_transform_table(self, transformer: InvertibleTableTransformer) -> Table:
    """
    Undo the transformation that the given transformer applied to this table.

    The work is delegated to the transformer's own `inverse_transform` method, with this table as input.

    Parameters
    ----------
    transformer : InvertibleTableTransformer
        A transformer that was fitted with columns, which are all present in the table.

    Returns
    -------
    table : Table
        The original table.

    Raises
    ------
    TransformerNotFittedError
        If the transformer has not been fitted yet.

    Examples
    --------
    >>> from safeds.data.tabular.transformation import OneHotEncoder
    >>> from safeds.data.tabular.containers import Table
    >>> transformer = OneHotEncoder()
    >>> table = Table.from_dict({"col1": [1, 2, 1], "col2": [1, 2, 4]})
    >>> transformer = transformer.fit(table, None)
    >>> transformed_table = transformer.transform(table)
    >>> transformed_table.inverse_transform_table(transformer)
       col1  col2
    0     1     1
    1     2     2
    2     1     4
    >>> transformer.inverse_transform(transformed_table)
       col1  col2
    0     1     1
    1     2     2
    2     1     4
    """
    # The transformer owns the inversion logic; this method only offers a table-centric entry point.
    original_table = transformer.inverse_transform(self)
    return original_table

# ------------------------------------------------------------------------------------------------------------------
# Plotting
# ------------------------------------------------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import pytest
from safeds.data.tabular.containers import Table
from safeds.data.tabular.exceptions import TransformerNotFittedError
from safeds.data.tabular.transformation import OneHotEncoder


class TestInverseTransformTableOnOneHotEncoder:
    """Tests for `Table.inverse_transform_table` when the transformer is a `OneHotEncoder`."""

    @pytest.mark.parametrize(
        ("table_to_fit", "column_names", "table_to_transform"),
        [
            pytest.param(
                Table.from_dict(
                    {
                        "a": [1.0, 0.0, 0.0, 0.0],
                        "b": ["a", "b", "b", "c"],
                        "c": [0.0, 0.0, 0.0, 1.0],
                    },
                ),
                ["b"],
                Table.from_dict(
                    {
                        "a": [1.0, 0.0, 0.0, 0.0],
                        "b": ["a", "b", "b", "c"],
                        "c": [0.0, 0.0, 0.0, 1.0],
                    },
                ),
                id="same table to fit and transform",
            ),
            pytest.param(
                Table.from_dict(
                    {
                        "a": [1.0, 0.0, 0.0, 0.0],
                        "b": ["a", "b", "b", "c"],
                        "c": [0.0, 0.0, 0.0, 1.0],
                    },
                ),
                ["b"],
                Table.from_dict(
                    {
                        "c": [0.0, 0.0, 0.0, 1.0],
                        "b": ["a", "b", "b", "c"],
                        "a": [1.0, 0.0, 0.0, 0.0],
                    },
                ),
                id="different tables to fit and transform",
            ),
            pytest.param(
                Table.from_dict(
                    {
                        "a": [1.0, 0.0, 0.0, 0.0],
                        "b": ["a", "b", "b", "c"],
                        "bb": ["a", "b", "b", "c"],
                    },
                ),
                ["b", "bb"],
                Table.from_dict(
                    {
                        "a": [1.0, 0.0, 0.0, 0.0],
                        "b": ["a", "b", "b", "c"],
                        "bb": ["a", "b", "b", "c"],
                    },
                ),
                id="one column name is a prefix of another column name",
            ),
        ],
    )
    def test_should_return_original_table(
        self,
        table_to_fit: Table,
        column_names: list[str],
        table_to_transform: Table,
    ) -> None:
        """Encoding a table and then inverting the encoding must give back the input table."""
        fitted_encoder = OneHotEncoder().fit(table_to_fit, column_names)
        encoded = fitted_encoder.transform(table_to_transform)

        restored = encoded.inverse_transform_table(fitted_encoder)

        # Column order must survive the round trip.
        assert restored.column_names == table_to_transform.column_names
        # Redundant with the equality check below, but a schema mismatch gives a clearer failure message.
        assert restored.schema == table_to_transform.schema
        assert restored == table_to_transform

    def test_should_not_change_transformed_table(self) -> None:
        """Inverting must leave the table it is called on untouched."""
        source = Table.from_dict(
            {
                "col1": ["a", "b", "b", "c"],
            },
        )

        fitted_encoder = OneHotEncoder().fit(source, None)
        encoded = fitted_encoder.transform(source)
        # The result is deliberately discarded; only the state of `encoded` afterwards matters here.
        encoded.inverse_transform_table(fitted_encoder)

        expected = Table.from_dict(
            {
                "col1_a": [1.0, 0.0, 0.0, 0.0],
                "col1_b": [0.0, 1.0, 1.0, 0.0],
                "col1_c": [0.0, 0.0, 0.0, 1.0],
            },
        )

        assert encoded == expected

    def test_should_raise_if_not_fitted(self) -> None:
        """An encoder that was never fitted must be rejected."""
        encoded_like = Table.from_dict(
            {
                "a": [1.0, 0.0, 0.0, 0.0],
                "b": [0.0, 1.0, 1.0, 0.0],
                "c": [0.0, 0.0, 0.0, 1.0],
            },
        )

        unfitted_encoder = OneHotEncoder()

        with pytest.raises(TransformerNotFittedError):
            encoded_like.inverse_transform_table(unfitted_encoder)

0 comments on commit 846bf23

Please sign in to comment.