Fix bug when assigning "no prediction" false negative flags (#719)
ntlind authored Aug 23, 2024
1 parent c542ce4 commit 759f8fe
Showing 2 changed files with 268 additions and 2 deletions.
266 changes: 266 additions & 0 deletions core/tests/functional-tests/test_classification.py
@@ -1245,3 +1245,269 @@ def test_compute_classification(
)
== 0
)


def test_pr_curves_multiple_predictions_per_groundtruth():
"""Test that we get back the expected results when creating PR curves with multiple predictions per groundtruth."""
groundtruths = [
schemas.GroundTruth(
datum=schemas.Datum(uid="datum0", metadata=None),
annotations=[
schemas.Annotation(
labels=[
schemas.Label(
key="class_label", value="cat", score=None
)
],
)
],
),
schemas.GroundTruth(
datum=schemas.Datum(uid="datum1", metadata=None),
annotations=[
schemas.Annotation(
labels=[
schemas.Label(
key="class_label", value="bee", score=None
)
],
)
],
),
schemas.GroundTruth(
datum=schemas.Datum(uid="datum2", metadata=None),
annotations=[
schemas.Annotation(
labels=[
schemas.Label(
key="class_label", value="cat", score=None
)
],
)
],
),
schemas.GroundTruth(
datum=schemas.Datum(uid="datum3", metadata=None),
annotations=[
schemas.Annotation(
labels=[
schemas.Label(
key="class_label", value="bee", score=None
)
],
)
],
),
schemas.GroundTruth(
datum=schemas.Datum(uid="datum4", metadata=None),
annotations=[
schemas.Annotation(
labels=[
schemas.Label(
key="class_label", value="dog", score=None
)
],
)
],
),
]
predictions = [
schemas.Prediction(
datum=schemas.Datum(uid="datum0", metadata=None),
annotations=[
schemas.Annotation(
labels=[
schemas.Label(
key="class_label",
value="cat",
score=0.44598543489942505,
),
schemas.Label(
key="class_label",
value="dog",
score=0.3255517969601126,
),
schemas.Label(
key="class_label",
value="bee",
score=0.22846276814046224,
),
],
)
],
),
schemas.Prediction(
datum=schemas.Datum(uid="datum1", metadata=None),
annotations=[
schemas.Annotation(
labels=[
schemas.Label(
key="class_label",
value="cat",
score=0.4076893257212283,
),
schemas.Label(
key="class_label",
value="dog",
score=0.14780458563955237,
),
schemas.Label(
key="class_label",
value="bee",
score=0.4445060886392194,
),
],
)
],
),
schemas.Prediction(
datum=schemas.Datum(uid="datum2", metadata=None),
annotations=[
schemas.Annotation(
labels=[
schemas.Label(
key="class_label",
value="cat",
score=0.25060075263871917,
),
schemas.Label(
key="class_label",
value="dog",
score=0.3467428086425673,
),
schemas.Label(
key="class_label",
value="bee",
score=0.4026564387187136,
),
],
)
],
),
schemas.Prediction(
datum=schemas.Datum(uid="datum3", metadata=None),
annotations=[
schemas.Annotation(
labels=[
schemas.Label(
key="class_label",
value="cat",
score=0.2003514145616792,
),
schemas.Label(
key="class_label",
value="dog",
score=0.2485912151889644,
),
schemas.Label(
key="class_label",
value="bee",
score=0.5510573702493565,
),
],
)
],
),
schemas.Prediction(
datum=schemas.Datum(uid="datum4", metadata=None),
annotations=[
schemas.Annotation(
labels=[
schemas.Label(
key="class_label",
value="cat",
score=0.33443897813714385,
),
schemas.Label(
key="class_label",
value="dog",
score=0.5890646197236098,
),
schemas.Label(
key="class_label",
value="bee",
score=0.07649640213924616,
),
],
)
],
),
]

eval_job = evaluate_classification(
groundtruths=groundtruths,
predictions=predictions,
metrics_to_return=[enums.MetricType.PrecisionRecallCurve],
)

output = eval_job.metrics[0]["value"]

# there are two cat, two bee, and one dog groundtruths
# once we raise the score threshold above the maximum score, we expect the tps to become fns and the fps to become tns

# start by testing bee
def _get_specific_keys_from_pr_output(output_dict):
return {
k: v
for k, v in output_dict.items()
if k in ["tp", "fp", "tn", "fn"]
}

assert _get_specific_keys_from_pr_output(output["bee"][0.05]) == {
"tp": 2.0,
"fp": 3.0,
"fn": 0.0,
"tn": 0.0,
}
assert _get_specific_keys_from_pr_output(output["bee"][0.55]) == {
"tp": 1.0,
"fp": 0.0,
"fn": 1.0,
"tn": 3.0,
}
assert _get_specific_keys_from_pr_output(output["bee"][0.95]) == {
"tp": 0.0,
"fp": 0.0,
"fn": 2.0,
"tn": 3.0,
}

# cat
assert _get_specific_keys_from_pr_output(output["cat"][0.05]) == {
"tp": 2.0,
"fp": 3.0,
"fn": 0.0,
"tn": 0.0,
}
assert _get_specific_keys_from_pr_output(output["cat"][0.40]) == {
"tp": 1.0,
"fp": 1.0,
"fn": 1.0,
"tn": 2.0,
}
assert _get_specific_keys_from_pr_output(output["cat"][0.95]) == {
"tp": 0.0,
"fp": 0.0,
"fn": 2.0,
"tn": 3.0,
}

# dog
assert _get_specific_keys_from_pr_output(output["dog"][0.05]) == {
"tp": 1.0,
"fp": 4.0,
"fn": 0.0,
"tn": 0.0,
}
assert _get_specific_keys_from_pr_output(output["dog"][0.55]) == {
"tp": 1.0,
"fp": 0.0,
"fn": 0.0,
"tn": 4.0,
}
assert _get_specific_keys_from_pr_output(output["dog"][0.95]) == {
"tp": 0.0,
"fp": 0.0,
"fn": 1.0,
"tn": 4.0,
}
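
The expected counts asserted above can be reproduced by hand from the prediction scores. Below is a minimal, self-contained recomputation (independent of valor_core) of the confusion counts for the "bee" label at the 0.55 threshold; the confusion_counts helper and the simple "score >= threshold" rule are illustrative only, not the library's implementation.

    groundtruth_labels = {
        "datum0": "cat",
        "datum1": "bee",
        "datum2": "cat",
        "datum3": "bee",
        "datum4": "dog",
    }
    bee_scores = {
        "datum0": 0.22846276814046224,
        "datum1": 0.4445060886392194,
        "datum2": 0.4026564387187136,
        "datum3": 0.5510573702493565,
        "datum4": 0.07649640213924616,
    }

    def confusion_counts(scores, gts, label, threshold):
        """Count tp/fp/fn/tn for one label at a single score threshold (illustrative rule)."""
        counts = {"tp": 0.0, "fp": 0.0, "fn": 0.0, "tn": 0.0}
        for uid, score in scores.items():
            is_label = gts[uid] == label
            predicted = score >= threshold
            if is_label and predicted:
                counts["tp"] += 1
            elif predicted:
                counts["fp"] += 1
            elif is_label:
                counts["fn"] += 1
            else:
                counts["tn"] += 1
        return counts

    # datum3 is the only bee score above 0.55, so the bee groundtruth on datum1
    # becomes a false negative and the three non-bee datums become true negatives.
    assert confusion_counts(bee_scores, groundtruth_labels, "bee", 0.55) == {
        "tp": 1.0, "fp": 0.0, "fn": 1.0, "tn": 3.0
    }

The same rule reproduces the counts asserted for "cat" and "dog" at the other thresholds.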
4 changes: 2 additions & 2 deletions core/valor_core/classification.py
@@ -810,11 +810,11 @@ def _calculate_pr_curves(
confidence_interval_to_misclassification_fn_groundtruth_ids_dict.items()
):
threshold_mask = pr_calc_df["confidence_threshold"] == threshold
- membership_mask = ~pr_calc_df["id_gt"].isin(elements)
+ membership_mask = pr_calc_df["id_gt"].isin(elements)
mask |= threshold_mask & membership_mask

pr_calc_df["no_predictions_false_negative_flag"] = (
- mask & pr_calc_df["false_negative_flag"]
+ ~mask & pr_calc_df["false_negative_flag"]
)

else:
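
For context on the fix: mask marks the groundtruth ids recorded as misclassification false negatives at each confidence threshold, so a row should be flagged as a "no prediction" false negative only when it is a false negative and not already covered by a misclassification. The following is a small pandas sketch of the corrected logic under an assumed, simplified pr_calc_df layout (column names follow the diff; the data and the misclf_fn_ids_by_threshold dict are made up for illustration).

    import pandas as pd

    # Toy pr_calc_df: one row per (groundtruth id, confidence threshold) pair.
    pr_calc_df = pd.DataFrame(
        {
            "id_gt": [1, 1, 2, 2],
            "confidence_threshold": [0.5, 0.9, 0.5, 0.9],
            "false_negative_flag": [False, True, True, True],
        }
    )

    # Stand-in for confidence_interval_to_misclassification_fn_groundtruth_ids_dict:
    # groundtruth ids that are misclassification false negatives at each threshold.
    misclf_fn_ids_by_threshold = {0.5: {2}, 0.9: {1}}

    mask = pd.Series(False, index=pr_calc_df.index)
    for threshold, elements in misclf_fn_ids_by_threshold.items():
        threshold_mask = pr_calc_df["confidence_threshold"] == threshold
        membership_mask = pr_calc_df["id_gt"].isin(elements)  # fixed: no longer negated
        mask |= threshold_mask & membership_mask

    # A "no prediction" false negative is a false negative that is *not* already a
    # misclassification false negative, hence the ~mask in the fixed code.
    pr_calc_df["no_predictions_false_negative_flag"] = (
        ~mask & pr_calc_df["false_negative_flag"]
    )

    print(pr_calc_df["no_predictions_false_negative_flag"].tolist())
    # [False, False, False, True]  -> only id_gt=2 at threshold 0.9 is flagged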
