Skip to content

Commit

Permalink
feat: discretize table (#327)
Browse files Browse the repository at this point in the history
Closes #143.

### Summary of Changes

* Added a class `Discretizer` in `safeds.data.tabular.transformation`
that wraps the [`KBinsDiscretizer` of
`scikit-learn`](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.KBinsDiscretizer.html)
* Made the class a subclass of `TableTransformer`
* The `__init__` for now only has a parameter `number_of_bins` to
control how many bins are created
* If `number_of_bins` is less than 2, it raises a `ValueError`
  • Loading branch information
robmeth authored Jul 7, 2023
1 parent 388ab2d commit 5e3da8d
Show file tree
Hide file tree
Showing 3 changed files with 508 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/safeds/data/tabular/transformation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Classes for transforming tabular data."""

from ._discretizer import Discretizer
from ._imputer import Imputer
from ._label_encoder import LabelEncoder
from ._one_hot_encoder import OneHotEncoder
Expand All @@ -14,5 +15,6 @@
"InvertibleTableTransformer",
"TableTransformer",
"RangeScaler",
"Discretizer",
"StandardScaler",
]
205 changes: 205 additions & 0 deletions src/safeds/data/tabular/transformation/_discretizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
from __future__ import annotations

from sklearn.preprocessing import KBinsDiscretizer as sk_KBinsDiscretizer

from safeds.data.tabular.containers import Table
from safeds.data.tabular.transformation._table_transformer import TableTransformer
from safeds.exceptions import NonNumericColumnError, TransformerNotFittedError, UnknownColumnNameError


class Discretizer(TableTransformer):
"""
The Discretizer bins continuous data into intervals.
Parameters
----------
number_of_bins: float
The number of bins to be created.
Raises
------
ValueError
If the given number_of_bins is less than 2.
"""

def __init__(self, number_of_bins: float = 5):
self._column_names: list[str] | None = None
self._wrapped_transformer: sk_KBinsDiscretizer | None = None

if number_of_bins < 2:
raise ValueError("Parameter 'number_of_bins' must be >= 2.")
self._number_of_bins = number_of_bins

def fit(self, table: Table, column_names: list[str] | None) -> Discretizer:
"""
Learn a transformation for a set of columns in a table.
This transformer is not modified.
Parameters
----------
table : Table
The table used to fit the transformer.
column_names : list[str] | None
The list of columns from the table used to fit the transformer. If `None`, all columns are used.
Returns
-------
fitted_transformer : TableTransformer
The fitted transformer.
Raises
------
ValueError
If the table is empty.
NonNumericColumnError
If one of the columns, that should be fitted is non-numeric.
UnknownColumnNameError
If one of the columns, that should be fitted is not in the table.
"""
if table.number_of_rows == 0:
raise ValueError("The Discretizer cannot be fitted because the table contains 0 rows")

if column_names is None:
column_names = table.column_names
else:
missing_columns = set(column_names) - set(table.column_names)
if len(missing_columns) > 0:
raise UnknownColumnNameError(
sorted(
missing_columns,
key={val: ix for ix, val in enumerate(column_names)}.__getitem__,
),
)

for column in column_names:
if not table.get_column(column).type.is_numeric():
raise NonNumericColumnError(f"{column} is of type {table.get_column(column).type}.")

wrapped_transformer = sk_KBinsDiscretizer(n_bins=self._number_of_bins, encode="ordinal")
wrapped_transformer.fit(table._data[column_names])

result = Discretizer(self._number_of_bins)
result._wrapped_transformer = wrapped_transformer
result._column_names = column_names

return result

def transform(self, table: Table) -> Table:
"""
Apply the learned transformation to a table.
The table is not modified.
Parameters
----------
table : Table
The table to which the learned transformation is applied.
Returns
-------
transformed_table : Table
The transformed table.
Raises
------
TransformerNotFittedError
If the transformer has not been fitted yet.
ValueError
If the table is empty.
UnknownColumnNameError
If one of the columns, that should be transformed is not in the table.
NonNumericColumnError
If one of the columns, that should be fitted is non-numeric.
"""
# Transformer has not been fitted yet
if self._wrapped_transformer is None or self._column_names is None:
raise TransformerNotFittedError

if table.number_of_rows == 0:
raise ValueError("The table cannot be transformed because it contains 0 rows")

# Input table does not contain all columns used to fit the transformer
missing_columns = set(self._column_names) - set(table.column_names)
if len(missing_columns) > 0:
raise UnknownColumnNameError(
sorted(
missing_columns,
key={val: ix for ix, val in enumerate(self._column_names)}.__getitem__,
),
)

for column in self._column_names:
if not table.get_column(column).type.is_numeric():
raise NonNumericColumnError(f"{column} is of type {table.get_column(column).type}.")

data = table._data.copy()
data.columns = table.column_names
data[self._column_names] = self._wrapped_transformer.transform(data[self._column_names])
return Table._from_pandas_dataframe(data)

def is_fitted(self) -> bool:
"""
Check if the transformer is fitted.
Returns
-------
is_fitted : bool
Whether the transformer is fitted.
"""
return self._wrapped_transformer is not None

def get_names_of_added_columns(self) -> list[str]:
"""
Get the names of all new columns that have been added by the Discretizer.
Returns
-------
added_columns : list[str]
A list of names of the added columns, ordered as they will appear in the table.
Raises
------
TransformerNotFittedError
If the transformer has not been fitted yet.
"""
if not self.is_fitted():
raise TransformerNotFittedError
return []

# (Must implement abstract method, cannot instantiate class otherwise.)
def get_names_of_changed_columns(self) -> list[str]:
"""
Get the names of all columns that may have been changed by the Discretizer.
Returns
-------
changed_columns : list[str]
The list of (potentially) changed column names, as passed to fit.
Raises
------
TransformerNotFittedError
If the transformer has not been fitted yet.
"""
if self._column_names is None:
raise TransformerNotFittedError
return self._column_names

def get_names_of_removed_columns(self) -> list[str]:
"""
Get the names of all columns that have been removed by the Discretizer.
Returns
-------
removed_columns : list[str]
A list of names of the removed columns, ordered as they appear in the table the Discretizer was fitted on.
Raises
------
TransformerNotFittedError
If the transformer has not been fitted yet.
"""
if not self.is_fitted():
raise TransformerNotFittedError
return []
Loading

0 comments on commit 5e3da8d

Please sign in to comment.