Skip to content

Commit

Permalink
test: improve ml tests (#129)
Browse files Browse the repository at this point in the history
Closes #116.
Closes #117.

### Summary of Changes

* Reduce code duplication
* Test that `fit` doesn't change the input table
* Test that `predict` doesn't change the input table

---------

Co-authored-by: lars-reimann <[email protected]>
  • Loading branch information
lars-reimann and lars-reimann authored Mar 30, 2023
1 parent ddd3f59 commit fa04186
Show file tree
Hide file tree
Showing 17 changed files with 222 additions and 1,068 deletions.
71 changes: 0 additions & 71 deletions tests/safeds/ml/classification/test_ada_boost.py

This file was deleted.

106 changes: 105 additions & 1 deletion tests/safeds/ml/classification/test_classifier.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,112 @@
from __future__ import annotations

import pandas as pd
import pytest
from _pytest.fixtures import FixtureRequest
from safeds.data.tabular.containers import Column, Table, TaggedTable
from safeds.ml.classification import Classifier
from safeds.exceptions import LearningError, PredictionError
from safeds.ml.classification import (
AdaBoost,
Classifier,
DecisionTree,
GradientBoosting,
KNearestNeighbors,
LogisticRegression,
RandomForest,
)


def classifiers() -> list[Classifier]:
    """
    Return the classifiers to test.

    After implementing a new classifier, add an instance here so its `fit` and `predict`
    methods are covered by the shared tests below. Tests of behavior specific to a single
    classifier belong in a separate test file.

    Returns
    -------
    classifiers : list[Classifier]
        The list of classifiers to test.
    """
    return [
        AdaBoost(),
        DecisionTree(),
        GradientBoosting(),
        KNearestNeighbors(2),
        LogisticRegression(),
        RandomForest(),
    ]


@pytest.fixture()
def valid_data() -> TaggedTable:
    """Provide a small all-numeric table tagged with a target and two feature columns."""
    columns = [
        Column("id", [1, 4]),
        Column("feat1", [2, 5]),
        Column("feat2", [3, 6]),
        Column("target", [0, 1]),
    ]
    return Table.from_columns(columns).tag_columns(target_name="target", feature_names=["feat1", "feat2"])


@pytest.fixture()
def invalid_data() -> TaggedTable:
    """Provide a tagged table whose `feat1` column mixes a string with a number, making it unfit for training."""
    columns = [
        Column("id", [1, 4]),
        Column("feat1", ["a", 5]),
        Column("feat2", [3, 6]),
        Column("target", [0, 1]),
    ]
    return Table.from_columns(columns).tag_columns(target_name="target", feature_names=["feat1", "feat2"])


@pytest.mark.parametrize("classifier", classifiers(), ids=lambda x: x.__class__.__name__)
class TestFit:
    """Shared `fit` tests that every classifier must pass."""

    def test_should_succeed_on_valid_data(self, classifier: Classifier, valid_data: TaggedTable) -> None:
        # Fitting on well-formed numeric data must not raise.
        classifier.fit(valid_data)
        assert True  # This asserts that the fit method succeeds

    def test_should_not_change_input_table(self, classifier: Classifier, request: FixtureRequest) -> None:
        # Pull the fixture twice to obtain two independent, initially-equal tables.
        table = request.getfixturevalue("valid_data")
        untouched_copy = request.getfixturevalue("valid_data")
        classifier.fit(table)
        # `fit` must not mutate its input.
        assert table == untouched_copy

    def test_should_raise_on_invalid_data(self, classifier: Classifier, invalid_data: TaggedTable) -> None:
        # Non-numeric feature values must surface as a LearningError.
        with pytest.raises(LearningError):
            classifier.fit(invalid_data)


@pytest.mark.parametrize("classifier", classifiers(), ids=lambda x: x.__class__.__name__)
class TestPredict:
    """Shared `predict` tests that every classifier must pass."""

    def test_should_include_features_of_input_table(self, classifier: Classifier, valid_data: TaggedTable) -> None:
        fitted_classifier = classifier.fit(valid_data)
        prediction = fitted_classifier.predict(valid_data.features)
        # The prediction keeps the feature columns it was given.
        assert prediction.features == valid_data.features

    def test_should_include_complete_input_table(self, classifier: Classifier, valid_data: TaggedTable) -> None:
        # NOTE: local was `fitted_regressor` (copy-paste from the regressor test suite);
        # renamed to `fitted_classifier` for consistency with the sibling tests.
        fitted_classifier = classifier.fit(valid_data)
        prediction = fitted_classifier.predict(valid_data.remove_columns(["target"]))
        # Everything except the (re)predicted target column passes through unchanged.
        assert prediction.remove_columns(["target"]) == valid_data.remove_columns(["target"])

    def test_should_set_correct_target_name(self, classifier: Classifier, valid_data: TaggedTable) -> None:
        fitted_classifier = classifier.fit(valid_data)
        prediction = fitted_classifier.predict(valid_data.features)
        # The prediction's target column reuses the name the model was trained on.
        assert prediction.target.name == "target"

    def test_should_not_change_input_table(self, classifier: Classifier, request: FixtureRequest) -> None:
        # Pull the fixture twice to obtain two independent, initially-equal tables.
        valid_data = request.getfixturevalue("valid_data")
        valid_data_copy = request.getfixturevalue("valid_data")
        fitted_classifier = classifier.fit(valid_data)
        fitted_classifier.predict(valid_data.features)
        # `predict` must not mutate its input.
        assert valid_data == valid_data_copy

    def test_should_raise_when_not_fitted(self, classifier: Classifier, valid_data: TaggedTable) -> None:
        # Calling `predict` before `fit` must fail loudly.
        with pytest.raises(PredictionError):
            classifier.predict(valid_data.features)

    def test_should_raise_on_invalid_data(
        self, classifier: Classifier, valid_data: TaggedTable, invalid_data: TaggedTable
    ) -> None:
        fitted_classifier = classifier.fit(valid_data)
        # Non-numeric feature values at prediction time must surface as a PredictionError.
        with pytest.raises(PredictionError):
            fitted_classifier.predict(invalid_data.features)


class DummyClassifier(Classifier):
Expand Down
71 changes: 0 additions & 71 deletions tests/safeds/ml/classification/test_decision_tree.py

This file was deleted.

71 changes: 0 additions & 71 deletions tests/safeds/ml/classification/test_gradient_boosting.py

This file was deleted.

Loading

0 comments on commit fa04186

Please sign in to comment.