From c542ce4fb660eb6e9a60c9d737d7f2d1c8172e46 Mon Sep 17 00:00:00 2001
From: Nick L
Date: Fri, 23 Aug 2024 08:14:24 -0600
Subject: [PATCH] Add test_evaluate_detection_fp to `valor_core` (#718)

---
 core/tests/functional-tests/test_detection.py | 156 ++++++++++++++++++
 1 file changed, 156 insertions(+)

diff --git a/core/tests/functional-tests/test_detection.py b/core/tests/functional-tests/test_detection.py
index 7f89fd4cb..d05316607 100644
--- a/core/tests/functional-tests/test_detection.py
+++ b/core/tests/functional-tests/test_detection.py
@@ -4427,3 +4427,159 @@ def test_two_groundtruths_one_datum(
     assert result_dict["meta"]["labels"] == 2
     assert result_dict["meta"]["annotations"] == 5
     assert result_dict["meta"]["duration"] <= 5
+
+
+def test_evaluate_detection_pr_fp(img1, img2):
+    gts = [
+        schemas.GroundTruth(
+            datum=img1,
+            annotations=[
+                schemas.Annotation(
+                    is_instance=True,
+                    labels=[schemas.Label(key="k1", value="v1")],
+                    bounding_box=schemas.Box.from_extrema(
+                        xmin=0, xmax=5, ymin=0, ymax=5
+                    ),
+                )
+            ],
+        ),
+        schemas.GroundTruth(
+            datum=img2,
+            annotations=[
+                schemas.Annotation(
+                    is_instance=True,
+                    labels=[schemas.Label(key="k1", value="v1")],
+                    bounding_box=schemas.Box.from_extrema(
+                        xmin=0, xmax=5, ymin=0, ymax=5
+                    ),
+                )
+            ],
+        ),
+    ]
+    preds = [
+        schemas.Prediction(
+            datum=img1,
+            annotations=[
+                schemas.Annotation(
+                    is_instance=True,
+                    labels=[schemas.Label(key="k1", value="v1", score=0.8)],
+                    bounding_box=schemas.Box.from_extrema(
+                        xmin=0, xmax=5, ymin=0, ymax=5
+                    ),
+                )
+            ],
+        ),
+        schemas.Prediction(
+            datum=img2,
+            annotations=[
+                schemas.Annotation(
+                    is_instance=True,
+                    labels=[schemas.Label(key="k1", value="v1", score=0.8)],
+                    bounding_box=schemas.Box.from_extrema(
+                        xmin=10, xmax=20, ymin=10, ymax=20
+                    ),
+                )
+            ],
+        ),
+    ]
+
+    eval_job = evaluate_detection(
+        groundtruths=gts,
+        predictions=preds,
+        metrics_to_return=[
+            enums.MetricType.PrecisionRecallCurve,
+        ],
+    )
+
+    metrics = eval_job.metrics
+    assert metrics[0]["value"]["v1"][0.5] == {
+        "fn": 1,  # img2
+        "fp": 1,  # img2
+        "tn": None,
+        "tp": 1,  # img1
+        "recall": 0.5,
+        "accuracy": None,
+        "f1_score": 0.5,
+        "precision": 0.5,
+    }
+
+    # the score threshold is now higher than the prediction scores, so the predictions drop out and we're left with 2 fns (one for each image)
+    assert metrics[0]["value"]["v1"][0.85] == {
+        "tp": 0,
+        "fp": 0,
+        "fn": 2,
+        "tn": None,
+        "precision": 0.0,
+        "recall": 0.0,
+        "accuracy": None,
+        "f1_score": 0.0,
+    }
+
+    # test the DetailedPrecisionRecallCurve version
+    eval_job = evaluate_detection(
+        groundtruths=gts,
+        predictions=preds,
+        metrics_to_return=[
+            enums.MetricType.DetailedPrecisionRecallCurve,
+        ],
+    )
+
+    metrics = eval_job.metrics
+
+    score_threshold = 0.5
+    assert metrics[0]["value"]["v1"][score_threshold]["tp"]["total"] == 1
+    assert "tn" not in metrics[0]["value"]["v1"][score_threshold]
+    assert (
+        metrics[0]["value"]["v1"][score_threshold]["fp"]["observations"][
+            "hallucinations"
+        ]["count"]
+        == 1
+    )
+    assert (
+        metrics[0]["value"]["v1"][score_threshold]["fp"]["observations"][
+            "misclassifications"
+        ]["count"]
+        == 0
+    )
+    assert (
+        metrics[0]["value"]["v1"][score_threshold]["fn"]["observations"][
+            "no_predictions"
+        ]["count"]
+        == 1
+    )
+    assert metrics[0]["value"]["v1"][score_threshold]["tp"]["total"] == 1
+    assert (
+        metrics[0]["value"]["v1"][score_threshold]["fn"]["observations"][
+            "misclassifications"
+        ]["count"]
+        == 0
+    )
+
+    # the score threshold is now higher than the prediction scores, so the predictions drop out and we're left with 2 fns (one for each image)
+    score_threshold = 0.85
+    assert metrics[0]["value"]["v1"][score_threshold]["tp"]["total"] == 0
+    assert "tn" not in metrics[0]["value"]["v1"][score_threshold]
+    assert (
+        metrics[0]["value"]["v1"][score_threshold]["fp"]["observations"][
+            "hallucinations"
+        ]["count"]
+        == 0
+    )
+    assert (
+        metrics[0]["value"]["v1"][score_threshold]["fp"]["observations"][
+            "misclassifications"
+        ]["count"]
+        == 0
+    )
+    assert (
+        metrics[0]["value"]["v1"][score_threshold]["fn"]["observations"][
+            "no_predictions"
+        ]["count"]
+        == 2
+    )
+    assert (
+        metrics[0]["value"]["v1"][score_threshold]["fn"]["observations"][
+            "misclassifications"
+        ]["count"]
+        == 0
+    )