diff --git a/api/tests/functional-tests/backend/metrics/test_classification.py b/api/tests/functional-tests/backend/metrics/test_classification.py
index 778bec3f7..c5b64db94 100644
--- a/api/tests/functional-tests/backend/metrics/test_classification.py
+++ b/api/tests/functional-tests/backend/metrics/test_classification.py
@@ -219,12 +219,12 @@ def test_compute_confusion_matrix_at_grouper_key(
         models.Annotation.datum_id.label("datum_id"),
         filters=gFilter,
         label_source=models.GroundTruth,
-    ).alias()
+    ).cte()
 
     predictions = generate_select(
         models.Prediction,
         filters=pFilter,
         label_source=models.Prediction,
-    ).alias()
+    ).cte()
 
     cm = _compute_confusion_matrix_at_grouper_key(
         db=db,
@@ -294,12 +294,12 @@ def test_compute_confusion_matrix_at_grouper_key(
         models.Annotation.datum_id.label("datum_id"),
         filters=gFilter,
         label_source=models.GroundTruth,
-    ).alias()
+    ).cte()
 
     predictions = generate_select(
         models.Prediction,
         filters=pFilter,
         label_source=models.Prediction,
-    ).alias()
+    ).cte()
 
     cm = _compute_confusion_matrix_at_grouper_key(
         db=db,
@@ -445,12 +445,12 @@ def test_compute_confusion_matrix_at_grouper_key_and_filter(
         models.Annotation.datum_id.label("datum_id"),
         filters=gFilter,
         label_source=models.GroundTruth,
-    ).alias()
+    ).cte()
 
     predictions = generate_select(
         models.Prediction,
         filters=pFilter,
         label_source=models.Prediction,
-    ).alias()
+    ).cte()
 
     cm = _compute_confusion_matrix_at_grouper_key(
         db,
@@ -595,12 +595,12 @@ def test_compute_confusion_matrix_at_grouper_key_using_label_map(
         models.Annotation.datum_id.label("datum_id"),
         filters=gFilter,
         label_source=models.GroundTruth,
-    ).alias()
+    ).cte()
 
     predictions = generate_select(
         models.Prediction,
         filters=pFilter,
         label_source=models.Prediction,
-    ).alias()
+    ).cte()
 
     cm = _compute_confusion_matrix_at_grouper_key(
         db,
@@ -1256,18 +1256,20 @@ def test__compute_curves(
         models.Dataset.name.label("dataset_name"),
         filters=gFilter,
         label_source=models.GroundTruth,
-    ).alias()
+    ).cte()
 
     predictions = generate_select(
         models.Prediction,
+        models.Annotation.datum_id.label("datum_id"),
         models.Dataset.name.label("dataset_name"),
         filters=pFilter,
         label_source=models.Prediction,
-    ).alias()
+    ).cte()
 
     # calculate the number of unique datums
     # used to determine the number of true negatives
     gt_datums = generate_query(
+        models.Datum.id,
         models.Dataset.name,
         models.Datum.uid,
         db=db,
@@ -1275,13 +1277,23 @@
         label_source=models.GroundTruth,
     ).all()
     pd_datums = generate_query(
+        models.Datum.id,
         models.Dataset.name,
         models.Datum.uid,
         db=db,
         filters=prediction_filter,
         label_source=models.Prediction,
     ).all()
-    unique_datums = set(pd_datums + gt_datums)
+    unique_datums = {
+        datum_id: (dataset_name, datum_uid)
+        for datum_id, dataset_name, datum_uid in gt_datums
+    }
+    unique_datums.update(
+        {
+            datum_id: (dataset_name, datum_uid)
+            for datum_id, dataset_name, datum_uid in pd_datums
+        }
+    )
 
     curves = _compute_curves(
         db=db,
diff --git a/api/valor_api/backend/core/prediction.py b/api/valor_api/backend/core/prediction.py
index e85caa398..bdcad440f 100644
--- a/api/valor_api/backend/core/prediction.py
+++ b/api/valor_api/backend/core/prediction.py
@@ -9,20 +9,30 @@ def _check_if_datum_has_prediction(
     db: Session, datum: schemas.Datum, model_name: str, dataset_name: str
 ) -> None:
+    """Check whether the datum already has annotations for the given model and dataset."""
     if db.query(
         select(models.Annotation.id)
-        .join(models.Model)
-        .join(models.Datum)
-        .join(models.Dataset)
-        .where(
+        .select_from(models.Annotation)
+        .join(
+            models.Model,
             and_(
-                models.Dataset.name == dataset_name,
-                models.Datum.dataset_id == models.Dataset.id,
-                models.Datum.uid == datum.uid,
+                models.Model.id == models.Annotation.model_id,
                 models.Model.name == model_name,
-                models.Annotation.datum_id == models.Datum.id,
-                models.Annotation.model_id == models.Model.id,
-            )
+            ),
+        )
+        .join(
+            models.Datum,
+            and_(
+                models.Datum.id == models.Annotation.datum_id,
+                models.Datum.uid == datum.uid,
+            ),
+        )
+        .join(
+            models.Dataset,
+            and_(
+                models.Dataset.id == models.Datum.dataset_id,
+                models.Dataset.name == dataset_name,
+            ),
         )
         .subquery()
     ).all():
@@ -119,15 +129,14 @@ def create_predictions(
         for i, annotation in enumerate(prediction.annotations):
             for label in annotation.labels:
                 prediction_mappings.append(
-                    {
-                        "annotation_id": annotation_ids_per_prediction[i],
-                        "label_id": label_dict[(label.key, label.value)],
-                        "score": label.score,
-                    }
+                    models.Prediction(
+                        annotation_id=annotation_ids_per_prediction[i],
+                        label_id=label_dict[(label.key, label.value)],
+                        score=label.score,
+                    )
                 )
-
     try:
-        db.bulk_insert_mappings(models.Prediction, prediction_mappings)
+        db.add_all(prediction_mappings)
         db.commit()
     except IntegrityError as e:
         db.rollback()
diff --git a/api/valor_api/backend/metrics/classification.py b/api/valor_api/backend/metrics/classification.py
index 74ac14ffc..b91287a82 100644
--- a/api/valor_api/backend/metrics/classification.py
+++ b/api/valor_api/backend/metrics/classification.py
@@ -3,10 +3,9 @@
 from typing import Sequence
 
 import numpy as np
-from sqlalchemy import Integer, Subquery
+from sqlalchemy import CTE, Integer, alias
 from sqlalchemy.orm import Bundle, Session
-from sqlalchemy.sql import and_, case, func, select
-from sqlalchemy.sql.selectable import NamedFromClause
+from sqlalchemy.sql import and_, case, func, or_, select
 
 from valor_api import enums, schemas
 from valor_api.backend import core, models
@@ -26,11 +25,11 @@
 
 def _compute_curves(
     db: Session,
-    predictions: Subquery | NamedFromClause,
-    groundtruths: Subquery | NamedFromClause,
+    predictions: CTE,
+    groundtruths: CTE,
     grouper_key: str,
     grouper_mappings: dict[str, dict[str, dict]],
-    unique_datums: set[tuple[str, str]],
+    unique_datums: dict[int, tuple[str, str]],
     pr_curve_max_examples: int,
     metrics_to_return: list[enums.MetricType],
 ) -> list[schemas.PrecisionRecallCurve | schemas.DetailedPrecisionRecallCurve]:
@@ -41,15 +40,15 @@ def _compute_curves(
     ----------
     db: Session
         The database Session to query against.
-    prediction_filter: schemas.Filter
-        The filter to be used to query predictions.
-    groundtruth_filter: schemas.Filter
-        The filter to be used to query groundtruths.
+    predictions: CTE
+        A CTE defining a set of predictions.
+    groundtruths: CTE
+        A CTE defining a set of ground truths.
     grouper_key: str
         The key of the grouper used to calculate the PR curves.
     grouper_mappings: dict[str, dict[str, dict]]
         A dictionary of mappings that connect groupers to their related labels.
-    unique_datums: list[tuple[str, str]]
+    unique_datums: dict[int, tuple[str, str]]
         All of the unique datums associated with the ground truth and prediction filters.
     pr_curve_max_examples: int
         The maximum number of datum examples to store per true positive, false negative, etc.
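[Reviewer note — illustrative sketch, not part of the patch. With `.alias()` replaced by `.cte()`, the reworked `_compute_curves` is driven as below. `gFilter`/`pFilter`, `grouper_mappings`, and `unique_datums` are assumed to be built as in `test__compute_curves` above; the "animal" grouper key and the single-example cap are made-up fixture values.

    groundtruths = generate_select(
        models.GroundTruth,
        models.Annotation.datum_id.label("datum_id"),
        models.Dataset.name.label("dataset_name"),
        filters=gFilter,
        label_source=models.GroundTruth,
    ).cte()
    predictions = generate_select(
        models.Prediction,
        models.Annotation.datum_id.label("datum_id"),
        models.Dataset.name.label("dataset_name"),
        filters=pFilter,
        label_source=models.Prediction,
    ).cte()

    curves = _compute_curves(
        db=db,
        predictions=predictions,
        groundtruths=groundtruths,
        grouper_key="animal",  # assumed fixture value
        grouper_mappings=grouper_mappings,
        unique_datums=unique_datums,  # Datum.id -> (dataset_name, datum_uid)
        pr_curve_max_examples=1,
        metrics_to_return=[enums.MetricType.DetailedPrecisionRecallCurve],
    )

End note.]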
@@ -65,166 +64,151 @@ def _compute_curves(
     pr_output = defaultdict(lambda: defaultdict(dict))
     detailed_pr_output = defaultdict(lambda: defaultdict(dict))
 
-    for threshold in [x / 100 for x in range(5, 100, 5)]:
-        # get predictions that are above the confidence threshold
-        predictions_that_meet_criteria = (
-            select(
-                models.Label.value.label("pd_label_value"),
-                models.Annotation.datum_id.label("datum_id"),
-                models.Datum.uid.label("datum_uid"),
-                predictions.c.dataset_name,
-                predictions.c.score,
-            )
-            .select_from(predictions)
-            .join(
-                models.Annotation,
-                models.Annotation.id == predictions.c.annotation_id,
-            )
+    label_keys = grouper_mappings["grouper_key_to_label_keys_mapping"][
+        grouper_key
+    ]
+
+    # create sets of all datums for which there is a prediction / groundtruth
+    # used when separating hallucinations/misclassifications/missed_detections
+    gt_datum_ids = {
+        datum_id
+        for datum_id in db.scalars(
+            select(groundtruths.c.datum_id)
+            .select_from(groundtruths)
+            .join(models.Datum, models.Datum.id == groundtruths.c.datum_id)
             .join(
                 models.Label,
-                models.Label.id == predictions.c.label_id,
+                and_(
+                    models.Label.id == groundtruths.c.label_id,
+                    models.Label.key.in_(label_keys),
+                ),
             )
+            .distinct()
+        ).all()
+    }
+
+    pd_datum_ids_to_high_score = {
+        datum_id: high_score
+        for datum_id, high_score in db.query(
+            select(predictions.c.datum_id, func.max(predictions.c.score))
+            .select_from(predictions)
+            .join(models.Datum, models.Datum.id == predictions.c.datum_id)
             .join(
-                models.Datum,
-                models.Datum.id == models.Annotation.datum_id,
+                models.Label,
+                and_(
+                    models.Label.id == predictions.c.label_id,
+                    models.Label.key.in_(label_keys),
+                ),
             )
-            .where(predictions.c.score >= threshold)
-            .alias()
-        )
+            .group_by(predictions.c.datum_id)
+            .subquery()
+        ).all()
+    }
+
+    groundtruth_labels = alias(models.Label)
+    prediction_labels = alias(models.Label)
 
-        b = Bundle(
-            "cols",
+    total_query = (
+        select(
             case(
                 grouper_mappings["label_value_to_grouper_value"],
-                value=predictions_that_meet_criteria.c.pd_label_value,
-            ),
+                value=groundtruth_labels.c.value,
+                else_=None,
+            ).label("gt_label_value"),
+            groundtruths.c.datum_id,
             case(
                 grouper_mappings["label_value_to_grouper_value"],
-                value=models.Label.value,
-            ),
+                value=prediction_labels.c.value,
+                else_=None,
+            ).label("pd_label_value"),
+            predictions.c.datum_id,
+            groundtruths.c.dataset_name,
+            models.Datum.uid.label("datum_uid"),
+            predictions.c.score,
         )
-
-        total_query = (
-            select(
-                b,
-                predictions_that_meet_criteria.c.datum_id,
-                predictions_that_meet_criteria.c.datum_uid,
-                predictions_that_meet_criteria.c.dataset_name,
-                groundtruths.c.datum_id,
-                models.Datum.uid,
-                groundtruths.c.dataset_name,
-            )
-            .select_from(groundtruths)
-            .join(
-                predictions_that_meet_criteria,
-                groundtruths.c.datum_id
-                == predictions_that_meet_criteria.c.datum_id,
-                isouter=True,
-            )
-            .join(
-                models.Label,
-                models.Label.id == groundtruths.c.label_id,
-            )
-            .join(
-                models.Datum,
+        .select_from(groundtruths)
+        .join(
+            predictions,
+            predictions.c.datum_id == groundtruths.c.datum_id,
+            full=True,
+        )
+        .join(
+            models.Datum,
+            or_(
                 models.Datum.id == groundtruths.c.datum_id,
-            )
-            .group_by(
-                b,  # type: ignore - SQLAlchemy Bundle typing issue
-                predictions_that_meet_criteria.c.datum_id,
-                predictions_that_meet_criteria.c.datum_uid,
-                predictions_that_meet_criteria.c.dataset_name,
-                groundtruths.c.datum_id,
-                models.Datum.uid,
-                groundtruths.c.dataset_name,
-            )
+                models.Datum.id == predictions.c.datum_id,
+            ),
+        )
+        .join(
+            groundtruth_labels,
+            and_(
+                groundtruth_labels.c.id == groundtruths.c.label_id,
+                groundtruth_labels.c.key.in_(label_keys),
+            ),
         )
-        res = list(db.execute(total_query).all())
-        # handle edge case where there were multiple prediction labels for a single datum
-        # first we sort, then we only increment fn below if the datum_id wasn't counted as a tp or fp
-        res.sort(
-            key=lambda x: ((x[1] is None, x[0][0] != x[0][1], x[1], x[2]))
+        .join(
+            prediction_labels,
+            and_(
+                prediction_labels.c.id == predictions.c.label_id,
+                prediction_labels.c.key.in_(label_keys),
+            ),
         )
+        .subquery()
+    )
 
-        # create sets of all datums for which there is a prediction / groundtruth
-        # used when separating hallucinations/misclassifications/missed_detections
-        gt_datums = set()
-        pd_datums = set()
-
-        for row in res:
-            (pd_datum_uid, pd_dataset_name, gt_datum_uid, gt_dataset_name,) = (
-                row[2],
-                row[3],
-                row[5],
-                row[6],
-            )
-            gt_datums.add((gt_dataset_name, gt_datum_uid))
-            pd_datums.add((pd_dataset_name, pd_datum_uid))
+    sorted_query = select(total_query).order_by(
+        total_query.c.gt_label_value != total_query.c.pd_label_value,
+        -total_query.c.score,
+    )
+    res = db.query(sorted_query.subquery()).all()
+
+    for threshold in [x / 100 for x in range(5, 100, 5)]:
         for grouper_value in grouper_mappings["grouper_key_to_labels_mapping"][
             grouper_key
         ].keys():
-            tp, tn, fp, fn = [], [], defaultdict(list), defaultdict(list)
-            seen_datums = set()
+            tp, tn, fp, fn = set(), set(), defaultdict(set), defaultdict(set)
+            seen_datum_ids = set()
 
             for row in res:
                 (
+                    groundtruth_label,
+                    gt_datum_id,
                     predicted_label,
-                    actual_label,
-                    pd_datum_uid,
-                    pd_dataset_name,
-                    gt_datum_uid,
-                    gt_dataset_name,
-                ) = (
-                    row[0][0],
-                    row[0][1],
-                    row[2],
-                    row[3],
-                    row[5],
-                    row[6],
-                )
-
-                if predicted_label == grouper_value == actual_label:
-                    tp += [(pd_dataset_name, pd_datum_uid)]
-                    seen_datums.add(gt_datum_uid)
-                elif predicted_label == grouper_value:
+                    pd_datum_id,
+                    score,
+                ) = (row[0], row[1], row[2], row[3], row[6])
+
+                if (
+                    groundtruth_label == grouper_value
+                    and predicted_label == grouper_value
+                    and score >= threshold
+                ):
+                    tp.add(pd_datum_id)
+                    seen_datum_ids.add(pd_datum_id)
+                elif predicted_label == grouper_value and score >= threshold:
                     # if there was a groundtruth for a given datum, then it was a misclassification
-                    if (pd_dataset_name, pd_datum_uid) in gt_datums:
-                        fp["misclassifications"].append(
-                            (pd_dataset_name, pd_datum_uid)
-                        )
+                    if pd_datum_id in gt_datum_ids:
+                        fp["misclassifications"].add(pd_datum_id)
                     else:
-                        fp["hallucinations"].append(
-                            (pd_dataset_name, pd_datum_uid)
-                        )
-                    seen_datums.add(gt_datum_uid)
+                        fp["hallucinations"].add(pd_datum_id)
+                    seen_datum_ids.add(pd_datum_id)
                 elif (
-                    actual_label == grouper_value
-                    and gt_datum_uid not in seen_datums
+                    groundtruth_label == grouper_value
+                    and gt_datum_id not in seen_datum_ids
                 ):
                     # if there was a prediction for a given datum, then it was a misclassification
-                    if (gt_dataset_name, gt_datum_uid) in pd_datums:
-                        fn["misclassifications"].append(
-                            (gt_dataset_name, gt_datum_uid)
-                        )
+                    if (
+                        gt_datum_id in pd_datum_ids_to_high_score
+                        and pd_datum_ids_to_high_score[gt_datum_id]
+                        >= threshold
+                    ):
+                        fn["misclassifications"].add(gt_datum_id)
                     else:
-                        fn["missed_detections"].append(
-                            (gt_dataset_name, gt_datum_uid)
-                        )
-                    seen_datums.add(gt_datum_uid)
-
-            # calculate metrics
-            tn = [
-                datum_uid_pair
-                for datum_uid_pair in unique_datums
-                if datum_uid_pair
-                not in tp
-                + fp["hallucinations"]
-                + fp["misclassifications"]
-                + fn["misclassifications"]
-                + fn["missed_detections"]
-                and None not in datum_uid_pair
-            ]
+                        fn["missed_detections"].add(gt_datum_id)
+                    seen_datum_ids.add(gt_datum_id)
+
+            tn = set(unique_datums.keys()) - seen_datum_ids
 
             tp_cnt, fp_cnt, fn_cnt, tn_cnt = (
                 len(tp),
                 len(fp["hallucinations"]) + len(fp["misclassifications"]),
@@ -264,6 +248,16 @@
                 enums.MetricType.DetailedPrecisionRecallCurve
                 in metrics_to_return
             ):
+                tp = [unique_datums[datum_id] for datum_id in tp]
+                fp = {
+                    key: [unique_datums[datum_id] for datum_id in fp[key]]
+                    for key in fp
+                }
+                tn = [unique_datums[datum_id] for datum_id in tn]
+                fn = {
+                    key: [unique_datums[datum_id] for datum_id in fn[key]]
+                    for key in fn
+                }
 
                 detailed_pr_output[grouper_value][threshold] = {
                     "tp": {
@@ -396,44 +390,56 @@ def _compute_binary_roc_auc(
         The binary ROC AUC score.
     """
     # query to get the datum_ids and label values of groundtruths that have the given label key
-    gts_filter = groundtruth_filter.model_copy()
-    gts_filter.labels = schemas.LogicalFunction.and_(
-        gts_filter.labels,
-        schemas.Condition(
-            lhs=schemas.Symbol(name=schemas.SupportedSymbol.LABEL_KEY),
-            rhs=schemas.Value.infer(label.key),
-            op=schemas.FilterOperator.EQ,
-        ),
-    )
-    gts_query = generate_select(
-        models.Annotation.datum_id.label("datum_id"),
-        models.Label.value.label("label_value"),
-        filters=gts_filter,
+
+    filtered_groundtruths = generate_select(
+        models.GroundTruth,
+        filters=groundtruth_filter,
         label_source=models.GroundTruth,
+    ).subquery()
+    gts_query = (
+        select(
+            models.Annotation.datum_id.label("datum_id"),
+            models.Label.value.label("label_value"),
+        )
+        .select_from(models.Annotation)
+        .join(
+            filtered_groundtruths,
+            filtered_groundtruths.c.annotation_id == models.Annotation.id,
+        )
+        .join(
+            models.Label,
+            and_(
+                models.Label.id == filtered_groundtruths.c.label_id,
+                models.Label.key == label.key,
+            ),
+        )
     ).subquery("groundtruth_subquery")
 
     # get the prediction scores for the given label (key and value)
-    preds_filter = prediction_filter.model_copy()
-    preds_filter.labels = schemas.LogicalFunction.and_(
-        preds_filter.labels,
-        schemas.Condition(
-            lhs=schemas.Symbol(name=schemas.SupportedSymbol.LABEL_KEY),
-            rhs=schemas.Value.infer(label.key),
-            op=schemas.FilterOperator.EQ,
-        ),
-        schemas.Condition(
-            lhs=schemas.Symbol(name=schemas.SupportedSymbol.LABEL_VALUE),
-            rhs=schemas.Value.infer(label.value),
-            op=schemas.FilterOperator.EQ,
-        ),
-    )
-
-    preds_query = generate_select(
-        models.Annotation.datum_id.label("datum_id"),
-        models.Prediction.score.label("score"),
-        models.Label.value.label("label_value"),
-        filters=preds_filter,
+    filtered_predictions = generate_select(
+        models.Prediction,
+        filters=prediction_filter,
         label_source=models.Prediction,
+    ).subquery()
+    preds_query = (
+        select(
+            models.Annotation.datum_id.label("datum_id"),
+            filtered_predictions.c.score.label("score"),
+            models.Label.value.label("label_value"),
+        )
+        .select_from(models.Annotation)
+        .join(
+            filtered_predictions,
+            filtered_predictions.c.annotation_id == models.Annotation.id,
+        )
+        .join(
+            models.Label,
+            and_(
+                models.Label.id == filtered_predictions.c.label_id,
+                models.Label.key == label.key,
+                models.Label.value == label.value,
+            ),
+        )
     ).subquery("prediction_subquery")
 
     # number of ground truth labels that match the given label value
@@ -604,8 +610,8 @@ def _compute_roc_auc(
 
 def _compute_confusion_matrix_at_grouper_key(
     db: Session,
-    predictions: Subquery | NamedFromClause,
-    groundtruths: Subquery | NamedFromClause,
+    predictions: CTE,
+    groundtruths: CTE,
     grouper_key: str,
     grouper_mappings: dict[str, dict[str, dict]],
 ) -> schemas.ConfusionMatrix | None:
@@ -616,10 +622,10 @@
     ----------
     db : Session
         The database Session to query against.
-    prediction_filter : schemas.Filter
-        The filter to be used to query predictions.
-    groundtruth_filter : schemas.Filter
-        The filter to be used to query groundtruths.
+    predictions: CTE
+        A CTE defining a set of predictions.
+    groundtruths: CTE
+        A CTE defining a set of ground truths.
     grouper_key: str
         The key of the grouper used to calculate the confusion matrix.
     grouper_mappings: dict[str, dict[str, dict]]
@@ -644,7 +650,7 @@
             models.Annotation.id == predictions.c.annotation_id,
         )
         .group_by(models.Annotation.datum_id)
-        .alias()
+        .subquery()
     )
 
     # 2. Remove duplicate scores per datum
@@ -669,7 +675,7 @@
             ),
         )
         .group_by(models.Annotation.datum_id)
-        .alias()
+        .subquery()
     )
 
     # 3. Get labels for hard predictions, organize per datum
@@ -687,7 +693,7 @@
             models.Label,
             models.Label.id == models.Prediction.label_id,
         )
-        .alias()
+        .subquery()
     )
 
     # 4. Link each label value to its corresponding grouper value
@@ -881,7 +887,7 @@
         models.Dataset.name.label("dataset_name"),
         filters=gFilter,
         label_source=models.GroundTruth,
-    ).alias()
+    ).cte()
 
     predictions = generate_select(
         models.Prediction,
@@ -889,7 +895,7 @@
         models.Dataset.name.label("dataset_name"),
         filters=pFilter,
         label_source=models.Prediction,
-    ).alias()
+    ).cte()
 
     confusion_matrix = _compute_confusion_matrix_at_grouper_key(
         db=db,
@@ -934,6 +940,7 @@
     # calculate the number of unique datums
     # used to determine the number of true negatives
     gt_datums = generate_query(
+        models.Datum.id,
         models.Dataset.name,
         models.Datum.uid,
         db=db,
@@ -941,13 +948,23 @@
         label_source=models.GroundTruth,
     ).all()
     pd_datums = generate_query(
+        models.Datum.id,
         models.Dataset.name,
         models.Datum.uid,
         db=db,
         filters=prediction_filter,
         label_source=models.Prediction,
     ).all()
-    unique_datums = set(gt_datums + pd_datums)
+    unique_datums = {
+        datum_id: (dataset_name, datum_uid)
+        for datum_id, dataset_name, datum_uid in gt_datums
+    }
+    unique_datums.update(
+        {
+            datum_id: (dataset_name, datum_uid)
+            for datum_id, dataset_name, datum_uid in pd_datums
+        }
+    )
 
     pr_curves = _compute_curves(
         db=db,
diff --git a/api/valor_api/backend/metrics/metric_utils.py b/api/valor_api/backend/metrics/metric_utils.py
index aaaace6b1..6ea9a9e06 100644
--- a/api/valor_api/backend/metrics/metric_utils.py
+++ b/api/valor_api/backend/metrics/metric_utils.py
@@ -83,7 +83,7 @@ def _create_classification_grouper_mappings(
     """Create grouper mappings for use when evaluating classifications."""
 
     # define mappers to connect groupers with labels
-    label_value_to_grouper_value = {}
+    label_value_to_grouper_value = dict()
     grouper_key_to_labels_mapping = defaultdict(lambda: defaultdict(set))
     grouper_key_to_label_keys_mapping = defaultdict(set)
diff --git a/api/valor_api/crud/_create.py b/api/valor_api/crud/_create.py
index 50733fbaa..5a1db6147 100644
--- a/api/valor_api/crud/_create.py
+++ b/api/valor_api/crud/_create.py
@@ -133,7 +133,6 @@ def create_or_get_evaluations(
                 )
             case _:
                 raise RuntimeError
-
         if task_handler:
             task_handler.add_task(
                 compute_func,
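[Reviewer note — illustrative sketch, not part of the patch. Both the tests and `_compute_confusion_matrix_and_metrics_at_grouper_key` now key `unique_datums` by `Datum.id` instead of collecting `(dataset_name, datum_uid)` pairs, so `_compute_curves` can derive true negatives with a set difference over integer ids. A self-contained toy run with made-up ids and uids:

    # rows shaped like generate_query(models.Datum.id, models.Dataset.name, models.Datum.uid, ...).all()
    gt_datums = [(1, "dset", "uid1"), (2, "dset", "uid2")]
    pd_datums = [(2, "dset", "uid2"), (3, "dset", "uid3")]

    # merge, deduplicating by primary key; prediction rows overwrite
    # ground-truth rows for shared ids, but map to the same pair anyway
    unique_datums = {
        datum_id: (dataset_name, datum_uid)
        for datum_id, dataset_name, datum_uid in gt_datums
    }
    unique_datums.update(
        {
            datum_id: (dataset_name, datum_uid)
            for datum_id, dataset_name, datum_uid in pd_datums
        }
    )
    assert unique_datums == {
        1: ("dset", "uid1"),
        2: ("dset", "uid2"),
        3: ("dset", "uid3"),
    }

    # inside _compute_curves, per grouper value and threshold:
    seen_datum_ids = {2, 3}  # ids already counted as tp/fp/fn
    tn = set(unique_datums.keys()) - seen_datum_ids
    assert tn == {1}

End note.]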