Skip to content

Commit

Permalink
feat: OneHotEncoder.inverse_transform now maintains the column orde…
Browse files Browse the repository at this point in the history
…r from the original table (#195)

Closes #109.

### Summary of Changes

`OneHotEncoder.inverse_transform` now maintains the column order from
the original table (#109)
Fixed bug with `OneHotEncoder.inverse_transform` to not work if not all
columns were fitted
New feature columns in `OneHotEncoder` will now be inserted where the
combined columns were in the original table

---------

Co-authored-by: megalinter-bot <[email protected]>
Co-authored-by: Lars Reimann <[email protected]>
  • Loading branch information
3 people authored Apr 18, 2023
1 parent bea976a commit 3ec0041
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 17 deletions.
52 changes: 43 additions & 9 deletions src/safeds/data/tabular/transformation/_one_hot_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class OneHotEncoder(InvertibleTableTransformer):

def __init__(self) -> None:
self._wrapped_transformer: sk_OneHotEncoder | None = None
self._column_names: list[str] | None = None
self._column_names: dict[str, list[str]] | None = None

# noinspection PyProtectedMember
def fit(self, table: Table, column_names: list[str] | None = None) -> OneHotEncoder:
Expand Down Expand Up @@ -49,7 +49,10 @@ def fit(self, table: Table, column_names: list[str] | None = None) -> OneHotEnco

result = OneHotEncoder()
result._wrapped_transformer = wrapped_transformer
result._column_names = column_names
result._column_names = {
column: [f"{column}_{element}" for element in table.get_column(column).get_unique_values()]
for column in column_names
}

return result

Expand Down Expand Up @@ -78,19 +81,33 @@ def transform(self, table: Table) -> Table:
raise TransformerNotFittedError

# Input table does not contain all columns used to fit the transformer
missing_columns = set(self._column_names) - set(table.get_column_names())
missing_columns = set(self._column_names.keys()) - set(table.get_column_names())
if len(missing_columns) > 0:
raise UnknownColumnNameError(list(missing_columns))

original = table._data.copy()
original.columns = table.schema.get_column_names()

one_hot_encoded = pd.DataFrame(self._wrapped_transformer.transform(original[self._column_names]).toarray())
one_hot_encoded = pd.DataFrame(
self._wrapped_transformer.transform(original[self._column_names.keys()]).toarray(),
)
one_hot_encoded.columns = self._wrapped_transformer.get_feature_names_out()

unchanged = original.drop(self._column_names, axis=1)
unchanged = original.drop(self._column_names.keys(), axis=1)

res = Table(pd.concat([unchanged, one_hot_encoded], axis=1))
column_names = []

for name in table.get_column_names():
if name not in self._column_names.keys():
column_names.append(name)
else:
column_names.extend(
[f_name for f_name in self._wrapped_transformer.get_feature_names_out() if f_name.startswith(name)],
)
res = res.sort_columns(lambda col1, col2: column_names.index(col1.name) - column_names.index(col2.name))

return Table(pd.concat([unchanged, one_hot_encoded], axis=1))
return res

# noinspection PyProtectedMember
def inverse_transform(self, transformed_table: Table) -> Table:
Expand Down Expand Up @@ -120,12 +137,29 @@ def inverse_transform(self, transformed_table: Table) -> Table:
data.columns = transformed_table.get_column_names()

decoded = pd.DataFrame(
self._wrapped_transformer.inverse_transform(transformed_table._data),
columns=self._column_names,
self._wrapped_transformer.inverse_transform(
transformed_table.keep_only_columns(self._wrapped_transformer.get_feature_names_out())._data,
),
columns=list(self._column_names.keys()),
)
unchanged = data.drop(self._wrapped_transformer.get_feature_names_out(), axis=1)

return Table(pd.concat([unchanged, decoded], axis=1))
res = Table(pd.concat([unchanged, decoded], axis=1))
column_names = [
name
if name not in [value for value_list in list(self._column_names.values()) for value in value_list]
else list(self._column_names.keys())[
[
list(self._column_names.values()).index(value)
for value in list(self._column_names.values())
if name in value
][0]
]
for name in transformed_table.get_column_names()
]
res = res.sort_columns(lambda col1, col2: column_names.index(col1.name) - column_names.index(col2.name))

return res

def is_fitted(self) -> bool:
"""
Expand Down
98 changes: 90 additions & 8 deletions tests/safeds/data/tabular/transformation/test_one_hot_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,34 @@ class TestFitAndTransform:
["col1"],
Table.from_dict(
{
"col1_a": [1.0, 0.0, 0.0, 0.0],
"col1_b": [0.0, 1.0, 1.0, 0.0],
"col1_c": [0.0, 0.0, 0.0, 1.0],
"col2": ["a", "b", "b", "c"],
},
),
),
(
Table.from_dict(
{
"col1": ["a", "b", "b", "c"],
"col2": ["a", "b", "b", "c"],
},
),
["col1", "col2"],
Table.from_dict(
{
"col1_a": [1.0, 0.0, 0.0, 0.0],
"col1_b": [0.0, 1.0, 1.0, 0.0],
"col1_c": [0.0, 0.0, 0.0, 1.0],
"col2_a": [1.0, 0.0, 0.0, 0.0],
"col2_b": [0.0, 1.0, 1.0, 0.0],
"col2_c": [0.0, 0.0, 0.0, 1.0],
},
),
),
],
ids=["all columns", "one column", "multiple columns"],
)
def test_should_return_transformed_table(
self,
Expand Down Expand Up @@ -144,19 +164,81 @@ def test_should_not_change_original_table(self) -> None:

class TestInverseTransform:
@pytest.mark.parametrize(
"table",
("table_to_fit", "column_names", "table_to_transform"),
[
Table.from_dict(
{
"col1": ["a", "b", "b", "c"],
},
(
Table.from_dict(
{
"a": [1.0, 0.0, 0.0, 0.0],
"b": ["a", "b", "b", "c"],
"c": [0.0, 0.0, 0.0, 1.0],
},
),
["b"],
Table.from_dict(
{
"a": [1.0, 0.0, 0.0, 0.0],
"b": ["a", "b", "b", "c"],
"c": [0.0, 0.0, 0.0, 1.0],
},
),
),
(
Table.from_dict(
{
"a": [1.0, 0.0, 0.0, 0.0],
"b": ["a", "b", "b", "c"],
"c": [0.0, 0.0, 0.0, 1.0],
},
),
["b"],
Table.from_dict(
{
"c": [0.0, 0.0, 0.0, 1.0],
"b": ["a", "b", "b", "c"],
"a": [1.0, 0.0, 0.0, 0.0],
},
),
),
(
Table.from_dict(
{
"a": [1.0, 0.0, 0.0, 0.0],
"b": ["a", "b", "b", "c"],
"bb": ["a", "b", "b", "c"],
},
),
["b", "bb"],
Table.from_dict(
{
"a": [1.0, 0.0, 0.0, 0.0],
"b": ["a", "b", "b", "c"],
"bb": ["a", "b", "b", "c"],
},
),
),
],
ids=[
"same table to fit and transform",
"different tables to fit and transform",
"one column name is a prefix of another column name",
],
)
def test_should_return_original_table(self, table: Table) -> None:
transformer = OneHotEncoder().fit(table)
def test_should_return_original_table(
self,
table_to_fit: Table,
column_names: list[str],
table_to_transform: Table,
) -> None:
transformer = OneHotEncoder().fit(table_to_fit, column_names)

result = transformer.inverse_transform(transformer.transform(table_to_transform))

assert transformer.inverse_transform(transformer.transform(table)) == table
# This checks whether the columns are in the same order
assert result.get_column_names() == table_to_transform.get_column_names()
# This is subsumed by the next assertion, but we get a better error message
assert result.schema == table_to_transform.schema
assert result == table_to_transform

def test_should_not_change_transformed_table(self) -> None:
table = Table.from_dict(
Expand Down

0 comments on commit 3ec0041

Please sign in to comment.