Skip to content

Commit

Permalink
feat: new method is_fitted to check whether a model is fitted (#130)
Browse files Browse the repository at this point in the history
### Summary of Changes

Add a new method `is_fitted` to `Classifier`s and `Regressor`s to easily
check whether they have been fitted already.

---------

Co-authored-by: lars-reimann <[email protected]>
  • Loading branch information
lars-reimann and lars-reimann authored Mar 30, 2023
1 parent fa04186 commit 8e1c3ea
Show file tree
Hide file tree
Showing 19 changed files with 222 additions and 3 deletions.
11 changes: 11 additions & 0 deletions src/safeds/ml/classification/_ada_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,14 @@ def predict(self, dataset: Table) -> TaggedTable:
If prediction with the given dataset failed.
"""
return predict(self._wrapped_classifier, dataset, self._feature_names, self._target_name)

def is_fitted(self) -> bool:
"""
Checks if the classifier is fitted.
Returns
-------
is_fitted : bool
Whether the classifier is fitted.
"""
return self._wrapped_classifier is not None
11 changes: 11 additions & 0 deletions src/safeds/ml/classification/_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,17 @@ def predict(self, dataset: Table) -> TaggedTable:
If prediction with the given dataset failed.
"""

@abstractmethod
def is_fitted(self) -> bool:
"""
Checks if the classifier is fitted.
Returns
-------
is_fitted : bool
Whether the classifier is fitted.
"""

# noinspection PyProtectedMember
def accuracy(self, validation_or_test_set: TaggedTable) -> float:
"""
Expand Down
11 changes: 11 additions & 0 deletions src/safeds/ml/classification/_decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,14 @@ def predict(self, dataset: Table) -> TaggedTable:
If prediction with the given dataset failed.
"""
return predict(self._wrapped_classifier, dataset, self._feature_names, self._target_name)

def is_fitted(self) -> bool:
"""
Checks if the classifier is fitted.
Returns
-------
is_fitted : bool
Whether the classifier is fitted.
"""
return self._wrapped_classifier is not None
11 changes: 11 additions & 0 deletions src/safeds/ml/classification/_gradient_boosting_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,14 @@ def predict(self, dataset: Table) -> TaggedTable:
If prediction with the given dataset failed.
"""
return predict(self._wrapped_classifier, dataset, self._feature_names, self._target_name)

def is_fitted(self) -> bool:
"""
Checks if the classifier is fitted.
Returns
-------
is_fitted : bool
Whether the classifier is fitted.
"""
return self._wrapped_classifier is not None
11 changes: 11 additions & 0 deletions src/safeds/ml/classification/_k_nearest_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,14 @@ def predict(self, dataset: Table) -> TaggedTable:
If prediction with the given dataset failed.
"""
return predict(self._wrapped_classifier, dataset, self._feature_names, self._target_name)

def is_fitted(self) -> bool:
"""
Checks if the classifier is fitted.
Returns
-------
is_fitted : bool
Whether the classifier is fitted.
"""
return self._wrapped_classifier is not None
11 changes: 11 additions & 0 deletions src/safeds/ml/classification/_logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,14 @@ def predict(self, dataset: Table) -> TaggedTable:
If prediction with the given dataset failed.
"""
return predict(self._wrapped_classifier, dataset, self._feature_names, self._target_name)

def is_fitted(self) -> bool:
"""
Checks if the classifier is fitted.
Returns
-------
is_fitted : bool
Whether the classifier is fitted.
"""
return self._wrapped_classifier is not None
11 changes: 11 additions & 0 deletions src/safeds/ml/classification/_random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,14 @@ def predict(self, dataset: Table) -> TaggedTable:
If prediction with the given dataset failed.
"""
return predict(self._wrapped_classifier, dataset, self._feature_names, self._target_name)

def is_fitted(self) -> bool:
"""
Checks if the classifier is fitted.
Returns
-------
is_fitted : bool
Whether the classifier is fitted.
"""
return self._wrapped_classifier is not None
11 changes: 11 additions & 0 deletions src/safeds/ml/regression/_ada_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,14 @@ def predict(self, dataset: Table) -> TaggedTable:
If prediction with the given dataset failed.
"""
return predict(self._wrapped_regressor, dataset, self._feature_names, self._target_name)

def is_fitted(self) -> bool:
"""
Check if the regressor is fitted.
Returns
-------
is_fitted : bool
Whether the regressor is fitted.
"""
return self._wrapped_regressor is not None
11 changes: 11 additions & 0 deletions src/safeds/ml/regression/_decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,14 @@ def predict(self, dataset: Table) -> TaggedTable:
If prediction with the given dataset failed.
"""
return predict(self._wrapped_regressor, dataset, self._feature_names, self._target_name)

def is_fitted(self) -> bool:
"""
Check if the regressor is fitted.
Returns
-------
is_fitted : bool
Whether the regressor is fitted.
"""
return self._wrapped_regressor is not None
11 changes: 11 additions & 0 deletions src/safeds/ml/regression/_elastic_net_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,14 @@ def predict(self, dataset: Table) -> TaggedTable:
If prediction with the given dataset failed
"""
return predict(self._wrapped_regressor, dataset, self._feature_names, self._target_name)

def is_fitted(self) -> bool:
"""
Check if the regressor is fitted.
Returns
-------
is_fitted : bool
Whether the regressor is fitted.
"""
return self._wrapped_regressor is not None
11 changes: 11 additions & 0 deletions src/safeds/ml/regression/_gradient_boosting_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,14 @@ def predict(self, dataset: Table) -> TaggedTable:
If prediction with the given dataset failed.
"""
return predict(self._wrapped_regressor, dataset, self._feature_names, self._target_name)

def is_fitted(self) -> bool:
"""
Check if the regressor is fitted.
Returns
-------
is_fitted : bool
Whether the regressor is fitted.
"""
return self._wrapped_regressor is not None
11 changes: 11 additions & 0 deletions src/safeds/ml/regression/_k_nearest_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,14 @@ def predict(self, dataset: Table) -> TaggedTable:
If prediction with the given dataset failed.
"""
return predict(self._wrapped_regressor, dataset, self._feature_names, self._target_name)

def is_fitted(self) -> bool:
"""
Check if the regressor is fitted.
Returns
-------
is_fitted : bool
Whether the regressor is fitted.
"""
return self._wrapped_regressor is not None
11 changes: 11 additions & 0 deletions src/safeds/ml/regression/_lasso_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,14 @@ def predict(self, dataset: Table) -> TaggedTable:
If prediction with the given dataset failed.
"""
return predict(self._wrapped_regressor, dataset, self._feature_names, self._target_name)

def is_fitted(self) -> bool:
"""
Check if the regressor is fitted.
Returns
-------
is_fitted : bool
Whether the regressor is fitted.
"""
return self._wrapped_regressor is not None
11 changes: 11 additions & 0 deletions src/safeds/ml/regression/_linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,14 @@ def predict(self, dataset: Table) -> TaggedTable:
If prediction with the given dataset failed.
"""
return predict(self._wrapped_regressor, dataset, self._feature_names, self._target_name)

def is_fitted(self) -> bool:
"""
Check if the regressor is fitted.
Returns
-------
is_fitted : bool
Whether the regressor is fitted.
"""
return self._wrapped_regressor is not None
11 changes: 11 additions & 0 deletions src/safeds/ml/regression/_random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,14 @@ def predict(self, dataset: Table) -> TaggedTable:
If prediction with the given dataset failed.
"""
return predict(self._wrapped_regressor, dataset, self._feature_names, self._target_name)

def is_fitted(self) -> bool:
"""
Check if the regressor is fitted.
Returns
-------
is_fitted : bool
Whether the regressor is fitted.
"""
return self._wrapped_regressor is not None
15 changes: 12 additions & 3 deletions src/safeds/ml/regression/_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ def predict(self, dataset: Table) -> TaggedTable:
If prediction with the given dataset failed.
"""

@abstractmethod
def is_fitted(self) -> bool:
"""
Checks if the regressor is fitted.
Returns
-------
is_fitted : bool
Whether the regressor is fitted.
"""

# noinspection PyProtectedMember
def mean_squared_error(self, validation_or_test_set: TaggedTable) -> float:
"""
Expand Down Expand Up @@ -110,7 +121,5 @@ def _check_metrics_preconditions(actual: Column, expected: Column) -> None:

if actual._data.size != expected._data.size:
raise ColumnLengthMismatchError(
"\n".join(
[f"{column.name}: {column._data.size}" for column in [actual, expected]]
)
"\n".join([f"{column.name}: {column._data.size}" for column in [actual, expected]])
)
11 changes: 11 additions & 0 deletions src/safeds/ml/regression/_ridge_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,14 @@ def predict(self, dataset: Table) -> TaggedTable:
If prediction with the given dataset failed.
"""
return predict(self._wrapped_regressor, dataset, self._feature_names, self._target_name)

def is_fitted(self) -> bool:
"""
Check if the regressor is fitted.
Returns
-------
is_fitted : bool
Whether the regressor is fitted.
"""
return self._wrapped_regressor is not None
17 changes: 17 additions & 0 deletions tests/safeds/ml/classification/test_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def test_should_succeed_on_valid_data(self, classifier: Classifier, valid_data:
classifier.fit(valid_data)
assert True # This asserts that the fit method succeeds

def test_should_not_change_input_classifier(self, classifier: Classifier, valid_data: TaggedTable) -> None:
classifier.fit(valid_data)
assert not classifier.is_fitted()

def test_should_not_change_input_table(self, classifier: Classifier, request: FixtureRequest) -> None:
valid_data = request.getfixturevalue("valid_data")
valid_data_copy = request.getfixturevalue("valid_data")
Expand Down Expand Up @@ -109,6 +113,16 @@ def test_should_raise_on_invalid_data(
fitted_classifier.predict(invalid_data.features)


@pytest.mark.parametrize("classifier", classifiers(), ids=lambda x: x.__class__.__name__)
class TestIsFitted:
def test_should_return_false_before_fitting(self, classifier: Classifier) -> None:
assert not classifier.is_fitted()

def test_should_return_true_after_fitting(self, classifier: Classifier, valid_data: TaggedTable) -> None:
fitted_classifier = classifier.fit(valid_data)
assert fitted_classifier.is_fitted()


class DummyClassifier(Classifier):
"""
Dummy classifier to test metrics.
Expand All @@ -133,6 +147,9 @@ def predict(self, dataset: Table) -> TaggedTable:

return dataset.tag_columns(target_name="predicted")

def is_fitted(self) -> bool:
return True


class TestAccuracy:
def test_with_same_type(self) -> None:
Expand Down
17 changes: 17 additions & 0 deletions tests/safeds/ml/regression/test_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ def test_should_succeed_on_valid_data(self, regressor: Regressor, valid_data: Ta
regressor.fit(valid_data)
assert True # This asserts that the fit method succeeds

def test_should_not_change_input_regressor(self, regressor: Regressor, valid_data: TaggedTable) -> None:
regressor.fit(valid_data)
assert not regressor.is_fitted()

def test_should_not_change_input_table(self, regressor: Regressor, request: FixtureRequest) -> None:
valid_data = request.getfixturevalue("valid_data")
valid_data_copy = request.getfixturevalue("valid_data")
Expand Down Expand Up @@ -125,6 +129,16 @@ def test_should_raise_on_invalid_data(
fitted_regressor.predict(invalid_data.features)


@pytest.mark.parametrize("regressor", regressors(), ids=lambda x: x.__class__.__name__)
class TestIsFitted:
def test_should_return_false_before_fitting(self, regressor: Regressor) -> None:
assert not regressor.is_fitted()

def test_should_return_true_after_fitting(self, regressor: Regressor, valid_data: TaggedTable) -> None:
fitted_regressor = regressor.fit(valid_data)
assert fitted_regressor.is_fitted()


class DummyRegressor(Regressor):
"""
Dummy regressor to test metrics.
Expand All @@ -149,6 +163,9 @@ def predict(self, dataset: Table) -> TaggedTable:

return dataset.tag_columns(target_name="predicted")

def is_fitted(self) -> bool:
return True


class TestMeanAbsoluteError:
@pytest.mark.parametrize(
Expand Down

0 comments on commit 8e1c3ea

Please sign in to comment.