Skip to content

Commit

Permalink
feat: improve error handling for predict (#145)
Browse files Browse the repository at this point in the history
Closes #9.

### Summary of Changes

* Use appropriate exception types.
* Improve messages of raised exception.
* Improve documentation.
  • Loading branch information
lars-reimann authored Apr 2, 2023
1 parent a311727 commit a5ff11c
Show file tree
Hide file tree
Showing 31 changed files with 287 additions and 96 deletions.
6 changes: 3 additions & 3 deletions src/safeds/data/tabular/transformation/_imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from safeds.data.tabular.containers import Table
from safeds.data.tabular.transformation._table_transformer import TableTransformer
from safeds.exceptions import NotFittedError, UnknownColumnNameError
from safeds.exceptions import TransformerNotFittedError, UnknownColumnNameError


class ImputerStrategy(ABC):
Expand Down Expand Up @@ -139,12 +139,12 @@ def transform(self, table: Table) -> Table:
Raises
------
NotFittedError
TransformerNotFittedError
If the transformer has not been fitted yet.
"""
# Transformer has not been fitted yet
if self._wrapped_transformer is None or self._column_names is None:
raise NotFittedError
raise TransformerNotFittedError

# Input table does not contain all columns used to fit the transformer
missing_columns = set(self._column_names) - set(table.get_column_names())
Expand Down
10 changes: 5 additions & 5 deletions src/safeds/data/tabular/transformation/_label_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from safeds.data.tabular.transformation._table_transformer import (
InvertibleTableTransformer,
)
from safeds.exceptions import NotFittedError, UnknownColumnNameError
from safeds.exceptions import TransformerNotFittedError, UnknownColumnNameError


def warn(*_: Any, **__: Any) -> None:
Expand Down Expand Up @@ -77,12 +77,12 @@ def transform(self, table: Table) -> Table:
Raises
------
NotFittedError
TransformerNotFittedError
If the transformer has not been fitted yet.
"""
# Transformer has not been fitted yet
if self._wrapped_transformer is None or self._column_names is None:
raise NotFittedError
raise TransformerNotFittedError

# Input table does not contain all columns used to fit the transformer
missing_columns = set(self._column_names) - set(table.get_column_names())
Expand Down Expand Up @@ -110,12 +110,12 @@ def inverse_transform(self, transformed_table: Table) -> Table:
Raises
------
NotFittedError
TransformerNotFittedError
If the transformer has not been fitted yet.
"""
# Transformer has not been fitted yet
if self._wrapped_transformer is None or self._column_names is None:
raise NotFittedError
raise TransformerNotFittedError

data = transformed_table._data.copy()
data.columns = transformed_table.get_column_names()
Expand Down
12 changes: 6 additions & 6 deletions src/safeds/data/tabular/transformation/_one_hot_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from safeds.data.tabular.transformation._table_transformer import (
InvertibleTableTransformer,
)
from safeds.exceptions import NotFittedError, UnknownColumnNameError
from safeds.exceptions import TransformerNotFittedError, UnknownColumnNameError


class OneHotEncoder(InvertibleTableTransformer):
"""The OneHotEncoder encodes categorical columns to numerical features [0,1] that represent the existence for each value."""
"""Encodes categorical columns to numerical features [0,1] that represent the existence for each value."""

def __init__(self) -> None:
self._wrapped_transformer: sk_OneHotEncoder | None = None
Expand Down Expand Up @@ -70,12 +70,12 @@ def transform(self, table: Table) -> Table:
Raises
------
NotFittedError
TransformerNotFittedError
If the transformer has not been fitted yet.
"""
# Transformer has not been fitted yet
if self._wrapped_transformer is None or self._column_names is None:
raise NotFittedError
raise TransformerNotFittedError

# Input table does not contain all columns used to fit the transformer
missing_columns = set(self._column_names) - set(table.get_column_names())
Expand Down Expand Up @@ -109,12 +109,12 @@ def inverse_transform(self, transformed_table: Table) -> Table:
Raises
------
NotFittedError
TransformerNotFittedError
If the transformer has not been fitted yet.
"""
# Transformer has not been fitted yet
if self._wrapped_transformer is None or self._column_names is None:
raise NotFittedError
raise TransformerNotFittedError

data = transformed_table._data.copy()
data.columns = transformed_table.get_column_names()
Expand Down
4 changes: 2 additions & 2 deletions src/safeds/data/tabular/transformation/_table_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def transform(self, table: Table) -> Table:
Raises
------
NotFittedError
TransformerNotFittedError
If the transformer has not been fitted yet.
"""

Expand Down Expand Up @@ -101,6 +101,6 @@ def inverse_transform(self, transformed_table: Table) -> Table:
Raises
------
NotFittedError
TransformerNotFittedError
If the transformer has not been fitted yet.
"""
14 changes: 12 additions & 2 deletions src/safeds/exceptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,31 @@
MissingSchemaError,
NonNumericColumnError,
SchemaMismatchError,
TransformerNotFittedError,
UnknownColumnNameError,
)
from ._learning_exceptions import LearningError, NotFittedError, PredictionError
from ._ml_exceptions import (
DatasetContainsTargetError,
DatasetMissesFeaturesError,
LearningError,
ModelNotFittedError,
PredictionError,
)

__all__ = [
"ColumnLengthMismatchError",
"ColumnSizeError",
"DatasetContainsTargetError",
"DatasetMissesFeaturesError",
"DuplicateColumnNameError",
"IndexOutOfBoundsError",
"LearningError",
"MissingDataError",
"MissingSchemaError",
"ModelNotFittedError",
"NonNumericColumnError",
"NotFittedError",
"PredictionError",
"SchemaMismatchError",
"TransformerNotFittedError",
"UnknownColumnNameError",
]
7 changes: 7 additions & 0 deletions src/safeds/exceptions/_data_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,10 @@ class MissingDataError(Exception):

def __init__(self, missing_data_info: str):
super().__init__(f"The function is missing data: \n{missing_data_info}")


class TransformerNotFittedError(Exception):
"""Raised when a transformer is used before fitting it."""

def __init__(self) -> None:
super().__init__("The transformer has not been fitted yet.")
39 changes: 0 additions & 39 deletions src/safeds/exceptions/_learning_exceptions.py

This file was deleted.

67 changes: 67 additions & 0 deletions src/safeds/exceptions/_ml_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
class DatasetContainsTargetError(ValueError):
"""
Raised when a dataset contains the target column already.
Parameters
----------
target_name: str
The name of the target column.
"""

def __init__(self, target_name: str):
super().__init__(f"Dataset already contains the target column '{target_name}'.")


class DatasetMissesFeaturesError(ValueError):
"""
Raised when a dataset misses feature columns.
Parameters
----------
missing_feature_names: list[str]
The names of the missing feature columns.
"""

def __init__(self, missing_feature_names: list[str]):
super().__init__(f"Dataset misses the feature columns '{missing_feature_names}'.")


class LearningError(Exception):
"""
Raised when an error occurred while training a model.
Parameters
----------
reason: str | None
The reason for the error.
"""

def __init__(self, reason: str | None):
if reason is None:
super().__init__("Error occurred while learning")
else:
super().__init__(f"Error occurred while learning: {reason}")


class ModelNotFittedError(Exception):
"""Raised when a model is used before fitting it."""

def __init__(self) -> None:
super().__init__("The model has not been fitted yet.")


class PredictionError(Exception):
"""
Raised when an error occurred while prediction a target vector using a model.
Parameters
----------
reason: str | None
The reason for the error.
"""

def __init__(self, reason: str | None):
if reason is None:
super().__init__("Error occurred while predicting")
else:
super().__init__(f"Error occurred while predicting: {reason}")
26 changes: 19 additions & 7 deletions src/safeds/ml/_util_sklearn.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from typing import Any

from sklearn.exceptions import NotFittedError

from safeds.data.tabular.containers import Table, TaggedTable
from safeds.exceptions import LearningError, PredictionError
from safeds.exceptions import (
DatasetContainsTargetError,
DatasetMissesFeaturesError,
LearningError,
ModelNotFittedError,
PredictionError,
)


# noinspection PyProtectedMember
Expand Down Expand Up @@ -55,13 +59,23 @@ def predict(model: Any, dataset: Table, feature_names: list[str] | None, target_
Raises
------
ModelNotFittedError
If the model has not been fitted yet.
DatasetContainsTargetError
If the dataset contains the target column already.
DatasetMissesFeaturesError
If the dataset misses feature columns.
PredictionError
If predicting with the given dataset failed.
"""
# Validation
if model is None or target_name is None or feature_names is None:
raise PredictionError("The model has not been trained yet.")
raise ModelNotFittedError
if dataset.has_column(target_name):
raise ValueError(f"Dataset already contains the target column '{target_name}'.")
raise DatasetContainsTargetError(target_name)
missing_feature_names = [feature_name for feature_name in feature_names if not dataset.has_column(feature_name)]
if missing_feature_names:
raise DatasetMissesFeaturesError(missing_feature_names)

dataset_df = dataset.keep_only_columns(feature_names)._data
dataset_df.columns = feature_names
Expand All @@ -73,7 +87,5 @@ def predict(model: Any, dataset: Table, feature_names: list[str] | None, target_
predicted_target_vector = model.predict(dataset_df.values)
result_set[target_name] = predicted_target_vector
return Table(result_set).tag_columns(target_name=target_name, feature_names=feature_names)
except NotFittedError as exception:
raise PredictionError("The model was not trained") from exception
except ValueError as exception:
raise PredictionError(str(exception)) from exception
8 changes: 7 additions & 1 deletion src/safeds/ml/classification/_ada_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,14 @@ def predict(self, dataset: Table) -> TaggedTable:
Raises
------
ModelNotFittedError
If the model has not been fitted yet.
DatasetContainsTargetError
If the dataset contains the target column already.
DatasetMissesFeaturesError
If the dataset misses feature columns.
PredictionError
If prediction with the given dataset failed.
If predicting with the given dataset failed.
"""
return predict(self._wrapped_classifier, dataset, self._feature_names, self._target_name)

Expand Down
8 changes: 7 additions & 1 deletion src/safeds/ml/classification/_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,14 @@ def predict(self, dataset: Table) -> TaggedTable:
Raises
------
ModelNotFittedError
If the model has not been fitted yet.
DatasetContainsTargetError
If the dataset contains the target column already.
DatasetMissesFeaturesError
If the dataset misses feature columns.
PredictionError
If prediction with the given dataset failed.
If predicting with the given dataset failed.
"""

@abstractmethod
Expand Down
8 changes: 7 additions & 1 deletion src/safeds/ml/classification/_decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,14 @@ def predict(self, dataset: Table) -> TaggedTable:
Raises
------
ModelNotFittedError
If the model has not been fitted yet.
DatasetContainsTargetError
If the dataset contains the target column already.
DatasetMissesFeaturesError
If the dataset misses feature columns.
PredictionError
If prediction with the given dataset failed.
If predicting with the given dataset failed.
"""
return predict(self._wrapped_classifier, dataset, self._feature_names, self._target_name)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def fit(self, training_set: TaggedTable) -> GradientBoosting:

return result

# noinspection PyProtectedMember
def predict(self, dataset: Table) -> TaggedTable:
"""
Predict a target vector using a dataset containing feature vectors. The model has to be trained first.
Expand All @@ -68,8 +67,14 @@ def predict(self, dataset: Table) -> TaggedTable:
Raises
------
ModelNotFittedError
If the model has not been fitted yet.
DatasetContainsTargetError
If the dataset contains the target column already.
DatasetMissesFeaturesError
If the dataset misses feature columns.
PredictionError
If prediction with the given dataset failed.
If predicting with the given dataset failed.
"""
return predict(self._wrapped_classifier, dataset, self._feature_names, self._target_name)

Expand Down
Loading

0 comments on commit a5ff11c

Please sign in to comment.