Skip to content

Commit

Permalink
feat: new method is_fitted to check whether a transformer is fitted (
Browse files Browse the repository at this point in the history
…#131)

### Summary of Changes

Add a new method `is_fitted` to `TableTransformer`s to easily check
whether they have been fitted already.

---------

Co-authored-by: lars-reimann <[email protected]>
  • Loading branch information
lars-reimann and lars-reimann authored Mar 30, 2023
1 parent 8e1c3ea commit e20954f
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/safeds/data/tabular/transformation/_imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,14 @@ def transform(self, table: Table) -> Table:
indices = [table.schema._get_column_index_by_name(name) for name in self._column_names]
data[indices] = pd.DataFrame(self._wrapped_transformer.transform(data[indices]), columns=indices)
return Table(data, table.schema)

def is_fitted(self) -> bool:
"""
Check if the transformer is fitted.
Returns
-------
is_fitted : bool
Whether the transformer is fitted.
"""
return self._wrapped_transformer is not None
11 changes: 11 additions & 0 deletions src/safeds/data/tabular/transformation/_label_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,14 @@ def inverse_transform(self, transformed_table: Table) -> Table:
data.columns = transformed_table.get_column_names()
data[self._column_names] = self._wrapped_transformer.inverse_transform(data[self._column_names])
return Table(data)

def is_fitted(self) -> bool:
"""
Check if the transformer is fitted.
Returns
-------
is_fitted : bool
Whether the transformer is fitted.
"""
return self._wrapped_transformer is not None
11 changes: 11 additions & 0 deletions src/safeds/data/tabular/transformation/_one_hot_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,14 @@ def inverse_transform(self, transformed_table: Table) -> Table:
unchanged = data.drop(self._wrapped_transformer.get_feature_names_out(), axis=1)

return Table(pd.concat([unchanged, decoded], axis=1))

def is_fitted(self) -> bool:
"""
Check if the transformer is fitted.
Returns
-------
is_fitted : bool
Whether the transformer is fitted.
"""
return self._wrapped_transformer is not None
11 changes: 11 additions & 0 deletions src/safeds/data/tabular/transformation/_table_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ def transform(self, table: Table) -> Table:
If the transformer has not been fitted yet.
"""

@abstractmethod
def is_fitted(self) -> bool:
"""
Check if the transformer is fitted.
Returns
-------
is_fitted : bool
Whether the transformer is fitted.
"""

def fit_and_transform(self, table: Table, column_names: Optional[list[str]] = None) -> Table:
"""
Learn a transformation for a set of columns in a table and apply the learned transformation to the same table.
Expand Down
17 changes: 17 additions & 0 deletions tests/safeds/data/tabular/transformation/test_imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,23 @@ def test_should_raise_if_not_fitted(self) -> None:
transformer.transform(table)


class TestIsFitted:
def test_should_return_false_before_fitting(self) -> None:
transformer = Imputer(Imputer.Strategy.Mean())
assert not transformer.is_fitted()

def test_should_return_true_after_fitting(self) -> None:
table = Table.from_columns(
[
Column("a", [1, 3, None]),
]
)

transformer = Imputer(Imputer.Strategy.Mean())
fitted_transformer = transformer.fit(table)
assert fitted_transformer.is_fitted()


class TestFitAndTransform:
@pytest.mark.parametrize(
("table", "column_names", "strategy", "expected"),
Expand Down
17 changes: 17 additions & 0 deletions tests/safeds/data/tabular/transformation/test_label_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,23 @@ def test_should_raise_if_not_fitted(self) -> None:
transformer.transform(table)


class TestIsFitted:
def test_should_return_false_before_fitting(self) -> None:
transformer = LabelEncoder()
assert not transformer.is_fitted()

def test_should_return_true_after_fitting(self) -> None:
table = Table.from_columns(
[
Column("col1", ["a", "b", "c"]),
]
)

transformer = LabelEncoder()
fitted_transformer = transformer.fit(table)
assert fitted_transformer.is_fitted()


class TestFitAndTransform:
@pytest.mark.parametrize(
("table", "column_names", "expected"),
Expand Down
17 changes: 17 additions & 0 deletions tests/safeds/data/tabular/transformation/test_one_hot_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,23 @@ def test_should_raise_if_not_fitted(self) -> None:
transformer.transform(table)


class TestIsFitted:
def test_should_return_false_before_fitting(self) -> None:
transformer = OneHotEncoder()
assert not transformer.is_fitted()

def test_should_return_true_after_fitting(self) -> None:
table = Table.from_columns(
[
Column("col1", ["a", "b", "c"]),
]
)

transformer = OneHotEncoder()
fitted_transformer = transformer.fit(table)
assert fitted_transformer.is_fitted()


class TestFitAndTransform:
@pytest.mark.parametrize(
("table", "column_names", "expected"),
Expand Down

0 comments on commit e20954f

Please sign in to comment.