From 2c4a430087d2b7a48c780a568747c8bda4d1e41a Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 4 Sep 2024 17:41:45 -0600 Subject: [PATCH 1/3] add additional checks to make sure we aren't returning any examples when count == 0 --- core/tests/conftest_inputs.py | 46 +++++++++++++++++++ core/tests/conftest_outputs.py | 15 ++++++ .../functional-tests/test_classification.py | 27 +++++++++++ core/valor_core/classification.py | 42 ++++++++++++----- 4 files changed, 119 insertions(+), 11 deletions(-) diff --git a/core/tests/conftest_inputs.py b/core/tests/conftest_inputs.py index 0a4c7cd11..04e6f19de 100644 --- a/core/tests/conftest_inputs.py +++ b/core/tests/conftest_inputs.py @@ -3238,6 +3238,52 @@ def multiclass_pr_curve_predictions(): ] +@pytest.fixture +def multiclass_pr_curve_check_zero_count_examples_groundtruths(): + return [ + schemas.GroundTruth( + datum=schemas.Datum( + uid="uid0", + metadata={ + "height": 900, + "width": 300, + }, + ), + annotations=[ + schemas.Annotation( + labels=[ + schemas.Label(key="k", value="ant"), + ], + ), + ], + ), + ] + + +@pytest.fixture +def multiclass_pr_curve_check_zero_count_examples_predictions(): + return [ + schemas.Prediction( + datum=schemas.Datum( + uid="uid0", + metadata={ + "height": 900, + "width": 300, + }, + ), + annotations=[ + schemas.Annotation( + labels=[ + schemas.Label(key="k", value="ant", score=0.15), + schemas.Label(key="k", value="bee", score=0.48), + schemas.Label(key="k", value="cat", score=0.37), + ], + ) + ], + ), + ] + + @pytest.fixture def evaluate_detection_false_negatives_single_image_baseline_inputs(): groundtruths = [ diff --git a/core/tests/conftest_outputs.py b/core/tests/conftest_outputs.py index 5c1d259fa..aefc31cdc 100644 --- a/core/tests/conftest_outputs.py +++ b/core/tests/conftest_outputs.py @@ -2892,6 +2892,21 @@ def detailed_curve_examples_output(): ("datum2",), ("datum0",), }, + # check cases where we shouldn't have any examples since the count is zero + ("bee", 0.3, "fn", "misclassifications"): set(), + ("dog", 0.1, "tn", "all"): set(), + } + + return expected_outputs + + +@pytest.fixture +def detailed_curve_examples_check_zero_count_examples_output(): + expected_outputs = { + ("ant", 0.05, "fp", "misclassifications"): 0, + ("ant", 0.95, "tn", "all"): 0, + ("bee", 0.2, "fn", "misclassifications"): 0, + ("cat", 0.2, "fn", "misclassifications"): 0, } return expected_outputs diff --git a/core/tests/functional-tests/test_classification.py b/core/tests/functional-tests/test_classification.py index 87d497067..1bb5db11e 100644 --- a/core/tests/functional-tests/test_classification.py +++ b/core/tests/functional-tests/test_classification.py @@ -601,8 +601,11 @@ def _get_specific_keys_from_pr_output(output_dict): def test_detailed_curve_examples( multiclass_pr_curve_groundtruths: list, + multiclass_pr_curve_check_zero_count_examples_groundtruths: list, multiclass_pr_curve_predictions: list, + multiclass_pr_curve_check_zero_count_examples_predictions: list, detailed_curve_examples_output: dict, + detailed_curve_examples_check_zero_count_examples_output: dict, ): """Test that we get back the right examples in DetailedPRCurves.""" @@ -624,3 +627,27 @@ def test_detailed_curve_examples( ) == expected ) + + # test additional cases to make sure that we aren't returning examples where count == 0 + eval_job = evaluate_classification( + groundtruths=multiclass_pr_curve_check_zero_count_examples_groundtruths, + predictions=multiclass_pr_curve_check_zero_count_examples_predictions, + 
metrics_to_return=[enums.MetricType.DetailedPrecisionRecallCurve], + ) + output = eval_job.metrics[0]["value"] + + for ( + key, + expected, + ) in detailed_curve_examples_check_zero_count_examples_output.items(): + assert ( + len( + output[key[0]][key[1]][key[2]]["observations"][key[3]][ + "examples" + ] + ) + == expected + ) + assert ( + output[key[0]][key[1]][key[2]]["observations"][key[3]]["count"] + ) == 0 diff --git a/core/valor_core/classification.py b/core/valor_core/classification.py index 9b944cb83..4028a2a4f 100644 --- a/core/valor_core/classification.py +++ b/core/valor_core/classification.py @@ -1005,7 +1005,11 @@ def _calculate_pr_curves( "observations": { "all": { "count": row["true_positives"], - "examples": row["true_positive_flag_samples"], + "examples": ( + row["true_positive_flag_samples"] + if row["true_positives"] + else [] + ), } }, }, @@ -1014,7 +1018,11 @@ def _calculate_pr_curves( "observations": { "all": { "count": row["true_negatives"], - "examples": row["true_negative_flag_samples"], + "examples": ( + row["true_negative_flag_samples"] + if row["true_negatives"] + else [] + ), } }, }, @@ -1023,15 +1031,23 @@ def _calculate_pr_curves( "observations": { "misclassifications": { "count": row["misclassification_false_negatives"], - "examples": row[ - "misclassification_false_negative_flag_samples" - ], + "examples": ( + row[ + "misclassification_false_negative_flag_samples" + ] + if row["misclassification_false_negatives"] + else [] + ), }, "no_predictions": { "count": row["no_predictions_false_negatives"], - "examples": row[ - "no_predictions_false_negative_flag_samples" - ], + "examples": ( + row[ + "no_predictions_false_negative_flag_samples" + ] + if row["no_predictions_false_negatives"] + else [] + ), }, }, }, @@ -1040,9 +1056,13 @@ def _calculate_pr_curves( "observations": { "misclassifications": { "count": row["misclassification_false_positives"], - "examples": row[ - "misclassification_false_positive_flag_samples" - ], + "examples": ( + row[ + "misclassification_false_positive_flag_samples" + ] + if row["misclassification_false_positives"] + else [] + ), }, }, }, From c33ae51a3ed64451154338f2012f789c10f28393 Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 5 Sep 2024 01:13:37 -0600 Subject: [PATCH 2/3] fix edge case where some true negative examples weren't included in output --- core/tests/conftest_inputs.py | 84 +++++++++++++++++++ core/tests/conftest_outputs.py | 18 ++++ .../functional-tests/test_classification.py | 28 +++++++ core/valor_core/classification.py | 71 ++++++++++++++-- 4 files changed, 194 insertions(+), 7 deletions(-) diff --git a/core/tests/conftest_inputs.py b/core/tests/conftest_inputs.py index 04e6f19de..453e25652 100644 --- a/core/tests/conftest_inputs.py +++ b/core/tests/conftest_inputs.py @@ -3284,6 +3284,90 @@ def multiclass_pr_curve_check_zero_count_examples_predictions(): ] +@pytest.fixture +def multiclass_pr_curve_check_true_negatives_groundtruths(): + return [ + schemas.GroundTruth( + datum=schemas.Datum( + uid="uid0", + metadata={ + "height": 900, + "width": 300, + }, + ), + annotations=[ + schemas.Annotation( + labels=[ + schemas.Label(key="dataset1", value="ant"), + ], + ), + ], + ), + schemas.GroundTruth( + datum=schemas.Datum( + uid="uid1", + metadata={ + "height": 900, + "width": 300, + }, + ), + annotations=[ + schemas.Annotation( + labels=[ + schemas.Label(key="dataset2", value="egg"), + ], + ), + ], + ), + ] + + +@pytest.fixture +def multiclass_pr_curve_check_true_negatives_predictions(): + return [ + schemas.Prediction( + 
datum=schemas.Datum( + uid="uid0", + metadata={ + "height": 900, + "width": 300, + }, + ), + annotations=[ + schemas.Annotation( + labels=[ + schemas.Label(key="dataset1", value="ant", score=0.15), + schemas.Label(key="dataset1", value="bee", score=0.48), + schemas.Label(key="dataset1", value="cat", score=0.37), + ], + ) + ], + ), + schemas.Prediction( + datum=schemas.Datum( + uid="uid1", + metadata={ + "height": 900, + "width": 300, + }, + ), + annotations=[ + schemas.Annotation( + labels=[ + schemas.Label(key="dataset2", value="egg", score=0.15), + schemas.Label( + key="dataset2", value="milk", score=0.48 + ), + schemas.Label( + key="dataset2", value="flour", score=0.37 + ), + ], + ) + ], + ), + ] + + @pytest.fixture def evaluate_detection_false_negatives_single_image_baseline_inputs(): groundtruths = [ diff --git a/core/tests/conftest_outputs.py b/core/tests/conftest_outputs.py index aefc31cdc..fca463a76 100644 --- a/core/tests/conftest_outputs.py +++ b/core/tests/conftest_outputs.py @@ -2910,3 +2910,21 @@ def detailed_curve_examples_check_zero_count_examples_output(): } return expected_outputs + + +@pytest.fixture +def detailed_curve_examples_check_true_negatives_output(): + expected_outputs = { + ("bee", 0.05, "tn", "all"): { + ("uid1",), + }, + ("bee", 0.15, "tn", "all"): { + ("uid1",), + }, + ("bee", 0.95, "tn", "all"): { + ("uid1",), + ("uid0",), + }, + } + + return expected_outputs diff --git a/core/tests/functional-tests/test_classification.py b/core/tests/functional-tests/test_classification.py index 1bb5db11e..4ccc683be 100644 --- a/core/tests/functional-tests/test_classification.py +++ b/core/tests/functional-tests/test_classification.py @@ -602,10 +602,13 @@ def _get_specific_keys_from_pr_output(output_dict): def test_detailed_curve_examples( multiclass_pr_curve_groundtruths: list, multiclass_pr_curve_check_zero_count_examples_groundtruths: list, + multiclass_pr_curve_check_true_negatives_groundtruths: list, multiclass_pr_curve_predictions: list, multiclass_pr_curve_check_zero_count_examples_predictions: list, + multiclass_pr_curve_check_true_negatives_predictions: list, detailed_curve_examples_output: dict, detailed_curve_examples_check_zero_count_examples_output: dict, + detailed_curve_examples_check_true_negatives_output: dict, ): """Test that we get back the right examples in DetailedPRCurves.""" @@ -651,3 +654,28 @@ def test_detailed_curve_examples( assert ( output[key[0]][key[1]][key[2]]["observations"][key[3]]["count"] ) == 0 + + # test additional cases to make sure that we're getting back enough true negative examples + eval_job = evaluate_classification( + groundtruths=multiclass_pr_curve_check_true_negatives_groundtruths, + predictions=multiclass_pr_curve_check_true_negatives_predictions, + metrics_to_return=[enums.MetricType.DetailedPrecisionRecallCurve], + pr_curve_max_examples=5, + ) + output = eval_job.metrics[0]["value"] + + for ( + key, + expected, + ) in detailed_curve_examples_check_true_negatives_output.items(): + import pdb + + pdb.set_trace() + assert ( + set( + output[key[0]][key[1]][key[2]]["observations"][key[3]][ + "examples" + ] + ) + == expected + ) diff --git a/core/valor_core/classification.py b/core/valor_core/classification.py index 4028a2a4f..1cbfeec3b 100644 --- a/core/valor_core/classification.py +++ b/core/valor_core/classification.py @@ -619,10 +619,11 @@ def _add_samples_to_dataframe( pr_calc_df: pd.DataFrame, max_examples: int, flag_column: str, + true_negative_datum_uids: pd.DataFrame, ) -> pd.DataFrame: """Efficienctly gather samples 
for a given flag.""" - if flag_column in ["no_predictions_false_negative_flag"]: + if flag_column == "no_predictions_false_negative_flag": sample_df = ( pr_calc_df[pr_calc_df[flag_column]] .groupby( @@ -686,12 +687,43 @@ def _add_samples_to_dataframe( pr_curve_counts_df[f"{flag_column}_samples"] = pr_curve_counts_df[ f"{flag_column}_samples" ].apply(lambda x: list(x) if isinstance(x, set) else list()) - else: pr_curve_counts_df[f"{flag_column}_samples"] = [ list() for _ in range(len(pr_curve_counts_df)) ] + # for true negative examples, we also need to consider examples where a label key doesn't exist on a datum (so there won't be any rows in pr_calc_df for that datum) + if flag_column == "true_negative_flag": + true_negative_datum_uids.columns = [ + "label_key", + "confidence_threshold", + "true_negative_flag_samples", + ] + pr_curve_counts_df = pr_curve_counts_df.merge( + true_negative_datum_uids, + on=[ + "label_key", + "confidence_threshold", + ], + suffixes=("", "_temp"), + ) + pr_curve_counts_df[ + "true_negative_flag_samples" + ] = pr_curve_counts_df.apply( + lambda row: ( + [ + x + for x in row["true_negative_flag_samples"] + + row["true_negative_flag_samples_temp"] + if len(x) > 0 + ] + )[:max_examples], + axis=1, + ) + del pr_curve_counts_df["true_negative_flag_samples_temp"] + + return pr_curve_counts_df + return pr_curve_counts_df @@ -813,6 +845,33 @@ def _calculate_pr_curves( "false_negative_flag" ] + # find all unique datums for use when identifying true negatives + unique_datum_uids = set(pr_calc_df["datum_uid"].unique()) + + true_negative_datum_uids: pd.DataFrame = ( + pr_calc_df[ + pr_calc_df["true_positive_flag"] + | pr_calc_df["misclassification_false_negative_flag"] + | pr_calc_df["no_predictions_false_negative_flag"] + | pr_calc_df["misclassification_false_positive_flag"] + ] + .groupby(["label_key", "confidence_threshold"], as_index=False)[ + "datum_uid" + ] + .apply(set) + ) # type: ignore - pyright thinks this output is a Series, when really it's a dataframe + + true_negative_datum_uids["datum_uid"] = ( + unique_datum_uids - true_negative_datum_uids["datum_uid"] + ).apply( # type: ignore - pandas can handle subtracting a pd.Series from a set + lambda x: [tuple(x)][:pr_curve_max_examples] + ) + true_negative_datum_uids.columns = [ + "label_key", + "confidence_threshold", + "true_negative_datum_uids", + ] + pr_calc_df["true_negative_flag"] = ( ~pr_calc_df["is_label_match"] & ~pr_calc_df["misclassification_false_positive_flag"] @@ -928,9 +987,6 @@ def _calculate_pr_curves( # we're doing an outer join, so any nulls should be zeroes pr_curve_counts_df.fillna(0, inplace=True) - # find all unique datums for use when identifying true negatives - unique_datum_ids = set(pr_calc_df["datum_id"].unique()) - # calculate additional metrics pr_curve_counts_df["false_positives"] = pr_curve_counts_df[ "misclassification_false_positives" @@ -939,7 +995,7 @@ def _calculate_pr_curves( pr_curve_counts_df["misclassification_false_negatives"] + pr_curve_counts_df["no_predictions_false_negatives"] ) - pr_curve_counts_df["true_negatives"] = len(unique_datum_ids) - ( + pr_curve_counts_df["true_negatives"] = len(unique_datum_uids) - ( pr_curve_counts_df["true_positives"] + pr_curve_counts_df["false_positives"] + pr_curve_counts_df["false_negatives"] @@ -955,7 +1011,7 @@ def _calculate_pr_curves( pr_curve_counts_df["accuracy"] = ( pr_curve_counts_df["true_positives"] + pr_curve_counts_df["true_negatives"] - ) / len(unique_datum_ids) + ) / len(unique_datum_uids) pr_curve_counts_df["f1_score"] = 
( 2 * pr_curve_counts_df["precision"] * pr_curve_counts_df["recall"] ) / (pr_curve_counts_df["precision"] + pr_curve_counts_df["recall"]) @@ -980,6 +1036,7 @@ def _calculate_pr_curves( pr_calc_df=pr_calc_df, max_examples=pr_curve_max_examples, flag_column=flag, + true_negative_datum_uids=true_negative_datum_uids, ) for _, row in pr_curve_counts_df.iterrows(): From 2f84210f57202a3e1a001260f15d7e72934fe39b Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 5 Sep 2024 01:15:37 -0600 Subject: [PATCH 3/3] remove pdb call --- core/tests/functional-tests/test_classification.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/core/tests/functional-tests/test_classification.py b/core/tests/functional-tests/test_classification.py index 4ccc683be..79360d2d7 100644 --- a/core/tests/functional-tests/test_classification.py +++ b/core/tests/functional-tests/test_classification.py @@ -668,9 +668,6 @@ def test_detailed_curve_examples( key, expected, ) in detailed_curve_examples_check_true_negatives_output.items(): - import pdb - - pdb.set_trace() assert ( set( output[key[0]][key[1]][key[2]]["observations"][key[3]][
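
Reviewer note (not part of the patches above): the two behaviours these commits enforce inside `_calculate_pr_curves` / `_add_samples_to_dataframe` are (1) never emit example samples for an observation whose count is zero, and (2) treat datums that have no rows at all for a label key as true negatives by taking a set difference against all datum UIDs, capped at `pr_curve_max_examples`. The snippet below is a minimal plain-Python sketch of those two ideas only; the helper names `build_observation` and `true_negative_examples` and the toy data are hypothetical, and the real code operates on pandas DataFrames rather than sets and lists.

    # Illustrative sketch only -- not part of the patch series.
    # Assumes hypothetical helper names; mirrors the logic added in
    # _calculate_pr_curves / _add_samples_to_dataframe in plain Python.

    def build_observation(count, samples, max_examples):
        """Patch 1 behaviour: when the count is zero, return no examples."""
        return {
            "count": count,
            "examples": samples[:max_examples] if count else [],
        }

    def true_negative_examples(all_datum_uids, flagged_datum_uids, max_examples):
        """Patch 2 behaviour: a datum with no rows for a label key is still a
        true negative, so derive TN examples by set difference between all
        datum UIDs and those flagged as TP/FN/FP, capped at max_examples."""
        return [(uid,) for uid in sorted(all_datum_uids - flagged_datum_uids)][
            :max_examples
        ]

    if __name__ == "__main__":
        # Zero-count observations carry no examples.
        assert build_observation(0, [("uid0",)], max_examples=5)["examples"] == []
        # "uid1" never appears for this label key, so it surfaces as a TN example.
        assert true_negative_examples({"uid0", "uid1"}, {"uid0"}, 5) == [("uid1",)]

The tests added in patches 1 and 2 check exactly these properties against the DetailedPrecisionRecallCurve output: zero-count observations with empty example lists, and true-negative examples that include datums missing the label key entirely.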