Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add KNearestNeighborsImputer #864

Merged
merged 32 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
54ec649
KNN imputer implemenmted
SamanHushi Jun 21, 2024
ba42201
modified __init__
SamanHushi Jun 21, 2024
e80a651
added tests and change a bit
SamanHushi Jun 21, 2024
36640fd
more and better test
SamanHushi Jun 21, 2024
6450e3d
removed typechecking for init
SamanHushi Jun 21, 2024
83f7d92
end of day
SamanHushi Jun 21, 2024
a8e9fd9
wrote all tests and everything working accordingly
LIEeOoNn Jun 25, 2024
aa69dce
renamed a test and removed a wrong todo
LIEeOoNn Jun 28, 2024
8dfb8c4
style: apply automated linter fixes
megalinter-bot Jun 28, 2024
d4375b7
style: apply automated linter fixes
megalinter-bot Jun 28, 2024
054fba2
how should we test the __hash__ function?
LIEeOoNn Jun 28, 2024
1c1bfa4
style: apply automated linter fixes
megalinter-bot Jun 28, 2024
959a250
style: apply automated linter fixes
megalinter-bot Jun 28, 2024
f866aff
removed unreachable code
SamanHushi Jun 28, 2024
18a8e62
Merge branch 'main' into 743-add-knearestneighborsimputer
SamanHushi Jun 28, 2024
d568a1f
added missing word in Knn discription
SamanHushi Jun 28, 2024
b704c03
adjusted tests
SamanHushi Jun 28, 2024
1412a12
Update src/safeds/data/tabular/transformation/_k_nearest_neighbors_im…
SamanHushi Jun 28, 2024
f6f1974
Update src/safeds/data/tabular/transformation/_k_nearest_neighbors_im…
SamanHushi Jun 28, 2024
9e68c09
Update src/safeds/data/tabular/transformation/_k_nearest_neighbors_im…
SamanHushi Jun 28, 2024
7a5a454
Update src/safeds/data/tabular/transformation/_k_nearest_neighbors_im…
SamanHushi Jun 28, 2024
6f5d4ed
Update tests/safeds/data/tabular/transformation/test_k_nearest_neighb…
SamanHushi Jun 28, 2024
e436937
added '_check_bounds' implementation
SamanHushi Jun 28, 2024
ef72b77
added neighbor_count to all tests
SamanHushi Jun 28, 2024
c4caca7
should have 100% conver now and hashing implemented like in SimpleImp…
LIEeOoNn Jul 1, 2024
21f3d0c
style: apply automated linter fixes
megalinter-bot Jul 1, 2024
64dc4a8
Merge branch 'main' into 743-add-knearestneighborsimputer
lars-reimann Jul 1, 2024
55d4a71
added property value_to_replace changed nan into fit and the import also
LIEeOoNn Jul 2, 2024
3b670ca
removed the import of nan into the if statement
LIEeOoNn Jul 2, 2024
3fc7a62
style: apply automated linter fixes
megalinter-bot Jul 2, 2024
489d329
now using var: value_to_replace for correct usage_
LIEeOoNn Jul 2, 2024
c63833a
style: apply automated linter fixes
megalinter-bot Jul 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/safeds/data/tabular/transformation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
if TYPE_CHECKING:
from ._discretizer import Discretizer
from ._invertible_table_transformer import InvertibleTableTransformer
from ._k_nearest_neighbors_imputer import KNearestNeighborsImputer
from ._label_encoder import LabelEncoder
from ._one_hot_encoder import OneHotEncoder
from ._range_scaler import RangeScaler
Expand All @@ -25,6 +26,7 @@
"SimpleImputer": "._simple_imputer:SimpleImputer",
"StandardScaler": "._standard_scaler:StandardScaler",
"TableTransformer": "._table_transformer:TableTransformer",
"KNearestNeighborsImputer": "._k_nearest_neighbors_imputer:KNearestNeighborsImputer",
},
)

Expand All @@ -37,4 +39,5 @@
"SimpleImputer",
"StandardScaler",
"TableTransformer",
"KNearestNeighborsImputer",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from safeds._utils import _structural_hash
from safeds._validation import _check_columns_exist
from safeds.data.tabular.containers import Table
from safeds.exceptions import TransformerNotFittedError

from ._table_transformer import TableTransformer

if TYPE_CHECKING:
from sklearn.impute import KNNImputer as sk_KNNImputer


class KNearestNeighborsImputer(TableTransformer):
"""
The KNearestNeighborsImputer replaces missing values in a with the mean value of the K-nearest neighbors.
SamanHushi marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
neighbor_count:
The number of neighbors to consider when imputing missing values.
column_names:
The list of columns used to impute missing values. If 'None', all columns are used.
"""

# ------------------------------------------------------------------------------------------------------------------
# Dunder methods
# ------------------------------------------------------------------------------------------------------------------

def __init__(
self,
neighbor_count: int = 5,
SamanHushi marked this conversation as resolved.
Show resolved Hide resolved
*,
column_names: str | list[str] | None = None,
value_to_replace: float | str | None = None,
) -> None:
super().__init__(column_names)

if neighbor_count <= 0:
raise ValueError('Parameter "neighbor_count" must be greater than 0.')
SamanHushi marked this conversation as resolved.
Show resolved Hide resolved

# parameter
self._neighbor_count: int = neighbor_count
self._value_to_replace: float | str | None = value_to_replace

# attributes
self._wrapped_transformer: sk_KNNImputer | None = None

def __hash__(self) -> int:
return _structural_hash(self)

Check warning on line 52 in src/safeds/data/tabular/transformation/_k_nearest_neighbors_imputer.py

View check run for this annotation

Codecov / codecov/patch

src/safeds/data/tabular/transformation/_k_nearest_neighbors_imputer.py#L52

Added line #L52 was not covered by tests

# ------------------------------------------------------------------------------------------------------------------
# Properties
# ------------------------------------------------------------------------------------------------------------------

@property
def is_fitted(self) -> bool:
"""Whether the transformer is fitted."""
return self._wrapped_transformer is not None

@property
def neighbor_count(self) -> int:
"""The number of neighbors to consider when imputing missing values."""
return self._neighbor_count

LIEeOoNn marked this conversation as resolved.
Show resolved Hide resolved
# ------------------------------------------------------------------------------------------------------------------
# Learning and transformation
# ------------------------------------------------------------------------------------------------------------------

def fit(self, table: Table) -> KNearestNeighborsImputer:
"""
Learn a trandformation for a set of columns in a table.
SamanHushi marked this conversation as resolved.
Show resolved Hide resolved

This transformer is not modified.
SamanHushi marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
table:
The table used to fit the transformer.

Returns
-------
fitted_transformer:
The fitted transformer.

Raises
------
ColumnNotFoundError
If one of the columns, that should be fitted is not in the table.
"""
from sklearn.impute import KNNImputer as sk_KNNImputer

if table.row_count == 0:
raise ValueError("The KNearestNeighborsImputer cannot be fitted because the table contains 0 rows.")

if self._column_names is None:
column_names = table.column_names
else:
column_names = self._column_names
_check_columns_exist(table, column_names)

wrapped_transformer = sk_KNNImputer(n_neighbors=self._neighbor_count, missing_values=self._value_to_replace)
wrapped_transformer.set_output(transform="polars")
wrapped_transformer.fit(
table.remove_columns_except(column_names)._data_frame,
)

result = KNearestNeighborsImputer(self._neighbor_count, column_names=column_names)
result._wrapped_transformer = wrapped_transformer

return result

def transform(self, table: Table) -> Table:
"""
Apply the learned transformation to a table.

The Table is not modified.
SamanHushi marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
table:
The table to wich the learned transformation is applied.

Returns
-------
transformed_table:
The transformed table.

Raises
------
TransformerNotFittedError
If the transformer is not fitted.
ColumnNotFoundError
If one of the columns, that should be transformed is not in the table.
"""
if self._column_names is None or self._wrapped_transformer is None:
raise TransformerNotFittedError

_check_columns_exist(table, self._column_names)

new_data = self._wrapped_transformer.transform(
table.remove_columns_except(self._column_names)._data_frame,
)

return Table._from_polars_lazy_frame(
table._lazy_frame.with_columns(new_data),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
import numpy as np
import pytest
from safeds.data.tabular.containers import Table
from safeds.data.tabular.transformation import KNearestNeighborsImputer
from safeds.exceptions import ColumnNotFoundError, TransformerNotFittedError


class TestInit:
def test_should_raise_value_error(self) -> None:
with pytest.raises(ValueError, match='Parameter "neighbor_count" must be greater than 0.'):
_ = KNearestNeighborsImputer(neighbor_count=0)
SamanHushi marked this conversation as resolved.
Show resolved Hide resolved

def test_neighbor_count(self) -> None:
knn = KNearestNeighborsImputer(neighbor_count=5)
assert knn.neighbor_count == 5


class TestFit:
def test_should_raise_if_column_not_found(self) -> None:
table = Table(
{
"col1": [0.0, 5.0, 10.0],
},
)

with pytest.raises(ColumnNotFoundError):
KNearestNeighborsImputer(column_names=["col2", "col3"]).fit(table)

def test_should_raise_if_table_contains_no_rows(self) -> None:
with pytest.raises(
ValueError,
match=r"The KNearestNeighborsImputer cannot be fitted because the table contains 0 rows",
):
KNearestNeighborsImputer().fit(Table({"col1": []}))

def test_should_not_change_original_transformer(self) -> None:
table = Table(
{
"col1": [0.0, 5.0, 10.0],
},
)

transformer = KNearestNeighborsImputer()
transformer.fit(table)

assert transformer._column_names is None
assert transformer._wrapped_transformer is None


class TestTransform:
def test_should_raise_if_column_not_found(self) -> None:
table_to_fit = Table(
{
"col1": [0.0, 5.0, 10.0],
"col2": [5.0, 50.0, 100.0],
},
)

transformer = KNearestNeighborsImputer()

table_to_transform = Table(
{
"col3": ["a", "b", "c"],
},
)

with pytest.raises(ColumnNotFoundError):
transformer.fit(table_to_fit).transform(table_to_transform)

def test_should_raise_if_not_fitted(self) -> None:
table = Table(
{
"col1": [0.0, 5.0, 10.0],
},
)

transformer = KNearestNeighborsImputer()

with pytest.raises(TransformerNotFittedError):
transformer.transform(table)


class TestIsFitted:
def test_should_return_false_before_fitting(self) -> None:
transformer = KNearestNeighborsImputer()
assert not transformer.is_fitted

def test_should_return_true_after_fitting(self) -> None:
table = Table(
{
"col1": [0.0, 5.0, 10.0],
},
)

transformer = KNearestNeighborsImputer()
fitted_transformer = transformer.fit(table)
assert fitted_transformer.is_fitted


class TestFitAndTransform:
@pytest.mark.parametrize(
("table", "column_names", "expected"),
[
(
Table(
{
"col1": [1, 2, np.nan],
"col2": [1, 2, 3],
},
),
["col1"],
Table(
{
"col1": [1, 2, 2], # Assuming k=1, the nearest neighbor for the missing value is 2.
"col2": [1, 2, 3],
},
),
),
(
Table(
{
"col1": [1, 2, np.nan, 4],
"col2": [1, 2, 3, 4],
},
),
["col1"],
Table(
{
"col1": [1, 2, 2, 4], # Assuming k=1, the nearest neighbor for the missing value is 2.
"col2": [1, 2, 3, 4],
},
),
),
],
ids=["one_column", "two_columns"],
)
def test_should_return_fitted_transformer_and_transformed_table(
self,
table: Table,
column_names: list[str] | None, # noqa: ARG002
expected: Table,
) -> None:
fitted_transformer, transformed_table = KNearestNeighborsImputer(
neighbor_count=1,
column_names=None,
value_to_replace=np.nan,
).fit_and_transform(table)
assert fitted_transformer.is_fitted
assert transformed_table == expected

@pytest.mark.parametrize(
("table", "column_names", "expected"),
[
(
Table(
{
"col1": [1, 2, np.nan],
"col2": [1, 2, 3],
},
),
["col1"],
Table(
{
"col1": [1, 2, 3 / 2], # Assuming k=1, the nearest neighbor for the missing value is 1.5
"col2": [1, 2, 3],
},
),
),
(
Table(
{
"col1": [1, 2, np.nan, 4],
"col2": [1, np.nan, 3, 4],
},
),
["col1"],
Table(
{
"col1": [1, 2, 7 / 3, 4], # Assuming k=1, the nearest neighbor for the missing value is 2.
"col2": [1, 8 / 3, 3, 4],
},
),
),
],
ids=["one_column", "two_columns"],
)
def test_should_return_fitted_transformer_and_transformed_table_with_correct_values(
SamanHushi marked this conversation as resolved.
Show resolved Hide resolved
self,
table: Table,
column_names: list[str] | None, # noqa: ARG002
expected: Table,
) -> None:
fitted_transformer, transformed_table = KNearestNeighborsImputer(
neighbor_count=3,
value_to_replace=np.nan,
).fit_and_transform(table)
assert fitted_transformer.is_fitted
assert transformed_table == expected

def test_should_not_change_original_table(self) -> None:
table = Table(
{
"col1": [0.0, 5.0, 10.0],
},
)

KNearestNeighborsImputer().fit_and_transform(table)

expected = Table(
{
"col1": [0.0, 5.0, 10.0],
},
)

assert table == expected