Skip to content

Commit

Permalink
quality_control: correct the per-class accuracy formula
Browse files Browse the repository at this point in the history
The current formula used to calculate `ConfusionMatrix.accuracy` is, in fact,
not accuracy, but the Jaccard index. Replace it with the correct formula.

Since the Jaccard index is a useful metric in its own right, calculate it too,
but save it in another attribute of `ConfusionMatrix`.
  • Loading branch information
SpecLad committed Mar 19, 2024
1 parent 53bf350 commit f09cb74
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions cvat/apps/quality_control/quality_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ class ConfusionMatrix(_Serializable):
precision: np.array
recall: np.array
accuracy: np.array
jaccard_index: np.array

@property
def axes(self):
Expand All @@ -240,6 +241,7 @@ def from_dict(cls, d: dict):
precision=np.asarray(d["precision"]),
recall=np.asarray(d["recall"]),
accuracy=np.asarray(d["accuracy"]),
jaccard_index=np.asarray(d["jaccard_index"]),
)


Expand Down Expand Up @@ -1934,17 +1936,23 @@ def _generate_annotations_summary(
matched_ann_counts = np.diag(confusion_matrix)
ds_ann_counts = np.sum(confusion_matrix, axis=1)
gt_ann_counts = np.sum(confusion_matrix, axis=0)
total_annotations_count = np.sum(confusion_matrix)

label_accuracies = _arr_div(
label_jaccard_indices = _arr_div(
matched_ann_counts, ds_ann_counts + gt_ann_counts - matched_ann_counts
)
label_precisions = _arr_div(matched_ann_counts, ds_ann_counts)
label_recalls = _arr_div(matched_ann_counts, gt_ann_counts)
label_accuracies = (
total_annotations_count # TP + TN + FP + FN
- (ds_ann_counts - matched_ann_counts) # - FP
- (gt_ann_counts - matched_ann_counts) # - FN
# ... = TP + TN
) / (total_annotations_count or 1)

valid_annotations_count = np.sum(matched_ann_counts)
missing_annotations_count = np.sum(confusion_matrix[cls._UNMATCHED_IDX, :])
extra_annotations_count = np.sum(confusion_matrix[:, cls._UNMATCHED_IDX])
total_annotations_count = np.sum(confusion_matrix)
ds_annotations_count = np.sum(ds_ann_counts[: cls._UNMATCHED_IDX])
gt_annotations_count = np.sum(gt_ann_counts[: cls._UNMATCHED_IDX])

Expand All @@ -1961,6 +1969,7 @@ def _generate_annotations_summary(
precision=label_precisions,
recall=label_recalls,
accuracy=label_accuracies,
jaccard_index=label_jaccard_indices,
),
)

Expand Down

0 comments on commit f09cb74

Please sign in to comment.