feat: precision metric for classification (#272)
Closes #185.

### Summary of Changes
Added a `precision` method to `Classifier` in `_classifier.py`.

---------

Co-authored-by: Lars Reimann <[email protected]>
Co-authored-by: megalinter-bot <[email protected]>
Co-authored-by: patrikguempel <[email protected]>
Co-authored-by: [email protected]
4 people authored May 5, 2023
1 parent 0d7a998 commit 5adadad
Showing 2 changed files with 90 additions and 3 deletions.
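
For context, here is a minimal usage sketch of the new metric (illustrative only, not part of the commit: `fitted_classifier` is a placeholder for any already-trained classifier, and the column names are invented; the `Table` and `tag_columns` usage follows the tests in this commit):

from safeds.data.tabular.containers import Table

# Hypothetical validation data; "feature" and "expected" are illustrative column names.
validation_set = Table.from_dict(
    {
        "feature": [3, 7, 1, 4],
        "expected": [1, 0, 1, 2],
    },
).tag_columns(target_name="expected")

# `fitted_classifier` stands in for any classifier that has already been fitted.
score = fitted_classifier.precision(validation_set, positive_class=1)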
44 changes: 41 additions & 3 deletions src/safeds/ml/classical/classification/_classifier.py
@@ -94,7 +94,45 @@ def accuracy(self, validation_or_test_set: TaggedTable) -> float:
"""
if not isinstance(validation_or_test_set, TaggedTable) and isinstance(validation_or_test_set, Table):
raise UntaggedTableError
expected = validation_or_test_set.target
predicted = self.predict(validation_or_test_set.features).target

return sk_accuracy_score(expected._data, predicted._data)
expected_values = validation_or_test_set.target
predicted_values = self.predict(validation_or_test_set.features).target

return sk_accuracy_score(expected_values._data, predicted_values._data)

    def precision(self, validation_or_test_set: TaggedTable, positive_class: int = 1) -> float:
        """
        Compute the classifier's precision on the given data.

        Parameters
        ----------
        validation_or_test_set : TaggedTable
            The validation or test set.
        positive_class : int | str
            The class to be considered positive. All other classes are considered negative.

        Returns
        -------
        precision : float
            The calculated precision score, i.e. the ratio of correctly predicted positives to all predicted positives.
            Return 1 if no positive predictions are made.
        """
        if not isinstance(validation_or_test_set, TaggedTable) and isinstance(validation_or_test_set, Table):
            raise UntaggedTableError

        expected_values = validation_or_test_set.target
        predicted_values = self.predict(validation_or_test_set.features).target

        n_true_positives = 0
        n_false_positives = 0

        for expected_value, predicted_value in zip(expected_values, predicted_values, strict=True):
            if predicted_value == positive_class:
                if expected_value == positive_class:
                    n_true_positives += 1
                else:
                    n_false_positives += 1

        if (n_true_positives + n_false_positives) == 0:
            return 1.0
        return n_true_positives / (n_true_positives + n_false_positives)
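
The loop above is the textbook precision computation, TP / (TP + FP), with a fallback of 1.0 when nothing is predicted as positive. As a rough cross-check (an illustrative sketch, not part of the commit), the same value can be reproduced with scikit-learn, which the accuracy metric above already uses, by reducing the labels to "positive class vs. rest":

from sklearn.metrics import precision_score

expected = [1, 0, 1, 2]
predicted = [1, 1, 0, 2]
positive_class = 1

# Binarize against the positive class, then use binary precision.
y_true = [int(v == positive_class) for v in expected]
y_pred = [int(v == positive_class) for v in predicted]

# zero_division=1 mirrors the "return 1.0 when there are no positive predictions" convention.
print(precision_score(y_true, y_pred, zero_division=1))  # 0.5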
49 changes: 49 additions & 0 deletions tests/safeds/ml/classical/classification/test_classifier.py
@@ -235,3 +235,52 @@ def test_with_different_types(self) -> None:
    def test_should_raise_if_table_is_not_tagged(self, table: Table) -> None:
        with pytest.raises(UntaggedTableError):
            DummyClassifier().accuracy(table)  # type: ignore[arg-type]


class TestPrecision:
    def test_should_compare_result(self) -> None:
        table = Table.from_dict(
            {
                "predicted": [1, 1, 0, 2],
                "expected": [1, 0, 1, 2],
            },
        ).tag_columns(target_name="expected")

        assert DummyClassifier().precision(table, 1) == 0.5

    def test_should_compare_result_with_different_types(self) -> None:
        table = Table.from_dict(
            {
                "predicted": [1, "1", "0", "2"],
                "expected": [1, 0, 1, 2],
            },
        ).tag_columns(target_name="expected")

        assert DummyClassifier().precision(table, 1) == 1.0

    def test_should_return_1_if_never_expected_to_be_positive(self) -> None:
        table = Table.from_dict(
            {
                "predicted": ["lol", "1", "0", "2"],
                "expected": [1, 0, 1, 2],
            },
        ).tag_columns(target_name="expected")

        assert DummyClassifier().precision(table, 1) == 1.0

    @pytest.mark.parametrize(
        "table",
        [
            Table.from_dict(
                {
                    "a": [1.0, 0.0, 0.0, 0.0],
                    "b": [0.0, 1.0, 1.0, 0.0],
                    "c": [0.0, 0.0, 0.0, 1.0],
                },
            ),
        ],
        ids=["untagged_table"],
    )
    def test_should_raise_if_table_is_not_tagged(self, table: Table) -> None:
        with pytest.raises(UntaggedTableError):
            DummyClassifier().precision(table)  # type: ignore[arg-type]
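
For reference, the expected values in TestPrecision can be checked by hand. In test_should_compare_result, rows 0 and 1 are predicted as 1 but only row 0 actually is 1, so TP = 1, FP = 1 and precision = 1 / (1 + 1) = 0.5. In the mixed-types test, the string values do not compare equal to the integer positive class 1, so there are no false positives and the score is 1.0. In the last precision test, no prediction equals 1 at all, so TP + FP = 0 and the method falls back to 1.0.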
