Lite Classification Accuracy Fix (#798)
czaloom authored Oct 15, 2024
1 parent 95f44b3 commit 491bbb6
Showing 4 changed files with 46 additions and 115 deletions.
97 changes: 7 additions & 90 deletions lite/tests/classification/test_accuracy.py
@@ -53,18 +53,12 @@ def test_accuracy_computation():
     )
 
     # score threshold, label, count metric
-    assert accuracy.shape == (2, 4)
+    assert accuracy.shape == (2,)
 
     # score >= 0.25
-    assert accuracy[0][0] == 2 / 3
-    assert accuracy[0][1] == 1.0
-    assert accuracy[0][2] == 2 / 3
-    assert accuracy[0][3] == 1.0
+    assert accuracy[0] == 2 / 3
     # score >= 0.75
-    assert accuracy[1][0] == 2 / 3
-    assert accuracy[1][1] == 1.0
-    assert accuracy[1][2] == 2 / 3
-    assert accuracy[1][3] == 2 / 3
+    assert accuracy[1] == 1 / 3
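
The new expectations follow from redefining accuracy as a single multiclass
metric per score threshold instead of a per-label array. A minimal sketch of
that semantics, using a hypothetical three-datum fixture (not the actual test
data) and assuming a prediction only counts as correct when its score clears
the threshold:

    # Hypothetical fixture: per datum, the ground truth, the hardmax
    # (highest-scoring) prediction, and that prediction's score.
    ground_truth = ["0", "0", "3"]
    prediction = ["0", "3", "3"]
    score = [0.9, 0.3, 0.5]

    for threshold in (0.25, 0.75):
        correct = sum(
            gt == pred and s >= threshold
            for gt, pred, s in zip(ground_truth, prediction, score)
        )
        print(f"score >= {threshold}: {correct}/{len(ground_truth)}")
    # score >= 0.25: 2/3
    # score >= 0.75: 1/3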


@@ -87,20 +81,10 @@ def test_accuracy_basic(basic_classifications: list[Classification]):
     expected_metrics = [
         {
             "type": "Accuracy",
-            "value": [2 / 3, 2 / 3],
+            "value": [2 / 3, 1 / 3],
             "parameters": {
                 "score_thresholds": [0.25, 0.75],
                 "hardmax": True,
-                "label": "0",
             },
         },
-        {
-            "type": "Accuracy",
-            "value": [1.0, 2 / 3],
-            "parameters": {
-                "score_thresholds": [0.25, 0.75],
-                "hardmax": True,
-                "label": "3",
-            },
-        },
     ]
@@ -124,29 +108,10 @@ def test_accuracy_with_animal_example(
     expected_metrics = [
         {
             "type": "Accuracy",
-            "value": [2.0 / 3.0],
-            "parameters": {
-                "score_thresholds": [0.5],
-                "hardmax": True,
-                "label": "bird",
-            },
-        },
-        {
-            "type": "Accuracy",
-            "value": [0.5],
+            "value": [2.0 / 6.0],
             "parameters": {
                 "score_thresholds": [0.5],
                 "hardmax": True,
-                "label": "dog",
             },
         },
-        {
-            "type": "Accuracy",
-            "value": [2 / 3],
-            "parameters": {
-                "score_thresholds": [0.5],
-                "hardmax": True,
-                "label": "cat",
-            },
-        },
     ]
@@ -170,38 +135,10 @@ def test_accuracy_color_example(
     expected_metrics = [
         {
             "type": "Accuracy",
-            "value": [2 / 3],
-            "parameters": {
-                "score_thresholds": [0.5],
-                "hardmax": True,
-                "label": "white",
-            },
-        },
-        {
-            "type": "Accuracy",
-            "value": [2 / 3],
+            "value": [2 / 6],
             "parameters": {
                 "score_thresholds": [0.5],
                 "hardmax": True,
-                "label": "red",
             },
         },
-        {
-            "type": "Accuracy",
-            "value": [2 / 3],
-            "parameters": {
-                "score_thresholds": [0.5],
-                "hardmax": True,
-                "label": "blue",
-            },
-        },
-        {
-            "type": "Accuracy",
-            "value": [5 / 6],
-            "parameters": {
-                "score_thresholds": [0.5],
-                "hardmax": True,
-                "label": "black",
-            },
-        },
     ]
@@ -237,7 +174,6 @@ def test_accuracy_with_image_example(
             "parameters": {
                 "score_thresholds": [0.0],
                 "hardmax": True,
-                "label": "v4",
             },
         },
     ]
@@ -269,29 +205,10 @@ def test_accuracy_with_tabular_example(
     expected_metrics = [
         {
             "type": "Accuracy",
-            "value": [0.7],
-            "parameters": {
-                "score_thresholds": [0.0],
-                "hardmax": True,
-                "label": "0",
-            },
-        },
-        {
-            "type": "Accuracy",
-            "value": [0.5],
-            "parameters": {
-                "score_thresholds": [0.0],
-                "hardmax": True,
-                "label": "1",
-            },
-        },
-        {
-            "type": "Accuracy",
-            "value": [0.8],
+            "value": [5 / 10],
             "parameters": {
                 "score_thresholds": [0.0],
                 "hardmax": True,
-                "label": "2",
             },
         },
     ]
4 changes: 2 additions & 2 deletions lite/valor_lite/classification/computation.py
@@ -182,9 +182,9 @@ def compute_metrics(
         out=precision,
     )
 
-    accuracy = np.zeros_like(recall)
+    accuracy = np.zeros(n_scores, dtype=np.float64)
     np.divide(
-        (counts[:, :, 0] + counts[:, :, 3]),
+        counts[:, :, 0].sum(axis=1),
         float(n_datums),
         out=accuracy,
     )
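
The old line computed a per-label binary accuracy, (TP + TN) / n_datums, which
counts a datum as correct for a label even when its prediction is wrong overall
(any datum whose ground truth and prediction both differ from that label is a
true negative), so the values overshot the true multiclass accuracy. The new
line sums true positives across labels, yielding one value per score threshold.
A toy sketch under the layout the diff implies (counts has shape
(n_scores, n_labels, 4) with true positives at index 0 — an assumption inferred
from the diff, not from library docs):

    import numpy as np

    # Assumed layout: tp[score_idx, label_idx] = counts[score_idx, label_idx, 0].
    n_datums = 3
    tp = np.array([[1, 1],   # score >= 0.25: one TP for each of two labels
                   [1, 0]])  # score >= 0.75: the low-score prediction drops out

    accuracy = tp.sum(axis=1) / float(n_datums)
    print(accuracy)  # [0.6667 0.3333] -- one value per threshold, not per label
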
14 changes: 8 additions & 6 deletions lite/valor_lite/classification/manager.py
@@ -367,6 +367,14 @@ def compute_precision_recall(
             )
         ]
 
+        metrics[MetricType.Accuracy] = [
+            Accuracy(
+                value=accuracy.astype(float).tolist(),
+                score_thresholds=score_thresholds,
+                hardmax=hardmax,
+            )
+        ]
+
         for label_idx, label in self.index_to_label.items():
 
             kwargs = {
@@ -401,12 +409,6 @@
                     **kwargs,
                 )
             )
-            metrics[MetricType.Accuracy].append(
-                Accuracy(
-                    value=accuracy[:, label_idx].astype(float).tolist(),
-                    **kwargs,
-                )
-            )
             metrics[MetricType.F1].append(
                 F1(
                     value=f1_score[:, label_idx].astype(float).tolist(),
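
Accuracy is now added once, before the per-label loop, because the reworked
Accuracy class no longer carries a label and therefore cannot accept the
per-label kwargs that Precision, Recall, and F1 still use. A toy sketch of that
constraint, using a stand-in dataclass rather than the real import:

    from dataclasses import dataclass

    # Stand-in mirroring the reworked class: note there is no `label` field.
    @dataclass
    class Accuracy:
        value: list[float]
        score_thresholds: list[float]
        hardmax: bool

    kwargs = {"score_thresholds": [0.25, 0.75], "hardmax": True, "label": "0"}
    try:
        Accuracy(value=[0.5, 0.5], **kwargs)
    except TypeError as err:
        print(err)  # got an unexpected keyword argument 'label'
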
46 changes: 29 additions & 17 deletions lite/valor_lite/classification/metric.py
@@ -158,24 +158,23 @@ class Recall(_ThresholdValue):
     pass
 
 
-class Accuracy(_ThresholdValue):
+class F1(_ThresholdValue):
     """
-    Accuracy metric for a specific class label.
+    F1 score for a specific class label.
 
-    This class calculates the accuracy at various score thresholds for a binary
-    classification task. Accuracy is defined as the ratio of the sum of true positives and
-    true negatives over all predictions.
+    This class calculates the F1 score at various score thresholds for a binary
+    classification task.
 
     Attributes
     ----------
     value : list[float]
-        Accuracy values computed at each score threshold.
+        F1 scores computed at each score threshold.
     score_thresholds : list[float]
-        Score thresholds at which the accuracy values are computed.
+        Score thresholds at which the F1 scores are computed.
     hardmax : bool
         Indicates whether hardmax thresholding was used.
     label : str
-        The class label for which the accuracy is computed.
+        The class label for which the F1 score is computed.
 
     Methods
     -------
@@ -188,23 +187,21 @@ class F1(_ThresholdValue):
     pass
 
 
-class F1(_ThresholdValue):
+@dataclass
+class Accuracy:
     """
-    F1 score for a specific class label.
+    Multiclass accuracy metric.
 
-    This class calculates the F1 score at various score thresholds for a binary
-    classification task.
+    This class calculates the accuracy at various score thresholds.
 
     Attributes
     ----------
     value : list[float]
-        F1 scores computed at each score threshold.
+        Accuracy values computed at each score threshold.
     score_thresholds : list[float]
-        Score thresholds at which the F1 scores are computed.
+        Score thresholds at which the accuracy values are computed.
     hardmax : bool
         Indicates whether hardmax thresholding was used.
-    label : str
-        The class label for which the F1 score is computed.
 
     Methods
     -------
@@ -214,7 +211,22 @@ class F1(_ThresholdValue):
         Converts the instance to a dictionary representation.
     """
 
-    pass
+    value: list[float]
+    score_thresholds: list[float]
+    hardmax: bool
+
+    def to_metric(self) -> Metric:
+        return Metric(
+            type=type(self).__name__,
+            value=self.value,
+            parameters={
+                "score_thresholds": self.score_thresholds,
+                "hardmax": self.hardmax,
+            },
+        )
+
+    def to_dict(self) -> dict:
+        return self.to_metric().to_dict()
 
 
 @dataclass
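
A short usage sketch of the reworked Accuracy dataclass. The import path is
assumed from the repository layout, and the printed shape mirrors the
expected_metrics entries in the tests above; neither is shown verbatim in this
commit. Note that parameters no longer include a label:

    from valor_lite.classification.metric import Accuracy  # path assumed

    acc = Accuracy(
        value=[2 / 3, 1 / 3],
        score_thresholds=[0.25, 0.75],
        hardmax=True,
    )
    print(acc.to_dict())
    # Mirrors the test fixtures:
    # {"type": "Accuracy",
    #  "value": [0.666..., 0.333...],
    #  "parameters": {"score_thresholds": [0.25, 0.75], "hardmax": True}}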
