-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: metrics as methods of models (#77)
Closes #64. ### Summary of Changes Metrics are now methods of classifiers and regressors. They also take a validation or test set as input now instead of two columns representing predicated and expected values. --------- Co-authored-by: lars-reimann <[email protected]>
- Loading branch information
1 parent
ec539eb
commit bc63693
Showing
17 changed files
with
204 additions
and
138 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
21 changes: 0 additions & 21 deletions
21
src/safeds/ml/classification/metrics/_module_level_functions.py
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
58 changes: 0 additions & 58 deletions
58
src/safeds/ml/regression/metrics/_module_level_functions.py
This file was deleted.
Oops, something went wrong.
File renamed without changes.
26 changes: 26 additions & 0 deletions
26
tests/safeds/ml/classification/_classifier/_dummy_classifier.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from safeds.data.tabular.containers import Table, TaggedTable | ||
from safeds.ml.classification import Classifier | ||
|
||
|
||
class DummyClassifier(Classifier): | ||
""" | ||
Dummy classifier to test metrics. | ||
Metrics methods expect a `TaggedTable` as input with two columns: | ||
- `predicted`: The predicted targets. | ||
- `expected`: The correct targets. | ||
`target_name` must be set to `"expected"`. | ||
""" | ||
|
||
def fit(self, training_set: TaggedTable) -> None: | ||
pass | ||
|
||
def predict(self, dataset: Table) -> TaggedTable: | ||
# Needed until https://github.com/Safe-DS/Stdlib/issues/75 is fixed | ||
predicted = dataset.get_column("predicted") | ||
feature = predicted.rename("feature") | ||
dataset = Table.from_columns([feature, predicted]) | ||
|
||
return TaggedTable(dataset, target_name="predicted") |
20 changes: 20 additions & 0 deletions
20
tests/safeds/ml/classification/_classifier/test_accuracy.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import pandas as pd | ||
from safeds.data.tabular.containers import Column, Table, TaggedTable | ||
|
||
from ._dummy_classifier import DummyClassifier | ||
|
||
|
||
def test_accuracy() -> None: | ||
c1 = Column(pd.Series(data=[1, 2, 3, 4]), "predicted") | ||
c2 = Column(pd.Series(data=[1, 2, 3, 3]), "expected") | ||
table = TaggedTable(Table.from_columns([c1, c2]), target_name="expected") | ||
|
||
assert DummyClassifier().accuracy(table) == 0.75 | ||
|
||
|
||
def test_accuracy_different_types() -> None: | ||
c1 = Column(pd.Series(data=["1", "2", "3", "4"]), "predicted") | ||
c2 = Column(pd.Series(data=[1, 2, 3, 3]), "expected") | ||
table = TaggedTable(Table.from_columns([c1, c2]), target_name="expected") | ||
|
||
assert DummyClassifier().accuracy(table) == 0.0 |
15 changes: 0 additions & 15 deletions
15
tests/safeds/ml/classification/metrics/_accuracy/test_accuracy.py
This file was deleted.
Oops, something went wrong.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from safeds.data.tabular.containers import Table, TaggedTable | ||
from safeds.ml.regression import Regressor | ||
|
||
|
||
class DummyRegressor(Regressor): | ||
""" | ||
Dummy regressor to test metrics. | ||
Metrics methods expect a `TaggedTable` as input with two columns: | ||
- `predicted`: The predicted targets. | ||
- `expected`: The correct targets. | ||
`target_name` must be set to `"expected"`. | ||
""" | ||
|
||
def fit(self, training_set: TaggedTable) -> None: | ||
pass | ||
|
||
def predict(self, dataset: Table) -> TaggedTable: | ||
# Needed until https://github.com/Safe-DS/Stdlib/issues/75 is fixed | ||
predicted = dataset.get_column("predicted") | ||
feature = predicted.rename("feature") | ||
dataset = Table.from_columns([feature, predicted]) | ||
|
||
return TaggedTable(dataset, target_name="predicted") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
26 changes: 26 additions & 0 deletions
26
tests/safeds/ml/regression/_regressor/test_mean_absolute_error.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import pytest | ||
from safeds.data.tabular.containers import Column, Table, TaggedTable | ||
|
||
from ._dummy_regressor import DummyRegressor | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"predicted, expected, result", | ||
[ | ||
([1, 2], [1, 2], 0), | ||
([0, 0], [1, 1], 1), | ||
([1, 1, 1], [2, 2, 11], 4), | ||
([0, 0, 0], [10, 2, 18], 10), | ||
([0.5, 0.5], [1.5, 1.5], 1), | ||
], | ||
) | ||
def test_mean_absolute_error_valid( | ||
predicted: list[float], expected: list[float], result: float | ||
) -> None: | ||
predicted_column = Column(predicted, "predicted") | ||
expected_column = Column(expected, "expected") | ||
table = TaggedTable( | ||
Table.from_columns([predicted_column, expected_column]), target_name="expected" | ||
) | ||
|
||
assert DummyRegressor().mean_absolute_error(table) == result |
20 changes: 20 additions & 0 deletions
20
tests/safeds/ml/regression/_regressor/test_mean_squared_error.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import pytest | ||
from safeds.data.tabular.containers import Column, Table, TaggedTable | ||
|
||
from ._dummy_regressor import DummyRegressor | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"predicted, expected, result", | ||
[([1, 2], [1, 2], 0), ([0, 0], [1, 1], 1), ([1, 1, 1], [2, 2, 11], 34)], | ||
) | ||
def test_mean_squared_error_valid( | ||
predicted: list[float], expected: list[float], result: float | ||
) -> None: | ||
predicted_column = Column(predicted, "predicted") | ||
expected_column = Column(expected, "expected") | ||
table = TaggedTable( | ||
Table.from_columns([predicted_column, expected_column]), target_name="expected" | ||
) | ||
|
||
assert DummyRegressor().mean_squared_error(table) == result |
22 changes: 0 additions & 22 deletions
22
tests/safeds/ml/regression/metrics/test_mean_absolute_error.py
This file was deleted.
Oops, something went wrong.
16 changes: 0 additions & 16 deletions
16
tests/safeds/ml/regression/metrics/test_mean_squared_error.py
This file was deleted.
Oops, something went wrong.