Skip to content

Commit

Permalink
Optimize ROC AUC Computation (#667)
Browse files Browse the repository at this point in the history
  • Loading branch information
czaloom authored Jul 11, 2024
1 parent eedc003 commit 19cd451
Show file tree
Hide file tree
Showing 3 changed files with 266 additions and 200 deletions.
69 changes: 58 additions & 11 deletions api/tests/functional-tests/backend/metrics/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,22 @@ def test_compute_roc_auc(
)
)

groundtruths = generate_select(
models.GroundTruth,
models.Annotation.datum_id.label("datum_id"),
models.Dataset.name.label("dataset_name"),
filters=groundtruth_filter,
label_source=models.GroundTruth,
).cte()

predictions = generate_select(
models.Prediction,
models.Annotation.datum_id.label("datum_id"),
models.Dataset.name.label("dataset_name"),
filters=prediction_filter,
label_source=models.Prediction,
).cte()

labels = fetch_union_of_labels(
db=db,
rhs=prediction_filter,
Expand All @@ -726,8 +742,8 @@ def test_compute_roc_auc(
assert (
_compute_roc_auc(
db=db,
prediction_filter=prediction_filter,
groundtruth_filter=groundtruth_filter,
groundtruths=groundtruths,
predictions=predictions,
grouper_key="animal",
grouper_mappings=grouper_mappings,
)
Expand All @@ -736,19 +752,18 @@ def test_compute_roc_auc(
assert (
_compute_roc_auc(
db=db,
prediction_filter=prediction_filter,
groundtruth_filter=groundtruth_filter,
groundtruths=groundtruths,
predictions=predictions,
grouper_key="color",
grouper_mappings=grouper_mappings,
)
== 0.43125
)

assert (
_compute_roc_auc(
db=db,
prediction_filter=prediction_filter,
groundtruth_filter=groundtruth_filter,
groundtruths=groundtruths,
predictions=predictions,
grouper_key="not a key",
grouper_mappings=grouper_mappings,
)
Expand Down Expand Up @@ -833,11 +848,27 @@ def test_compute_roc_auc_groupby_metadata(
evaluation_type=enums.TaskType.CLASSIFICATION,
)

groundtruths = generate_select(
models.GroundTruth,
models.Annotation.datum_id.label("datum_id"),
models.Dataset.name.label("dataset_name"),
filters=groundtruth_filter,
label_source=models.GroundTruth,
).cte()

predictions = generate_select(
models.Prediction,
models.Annotation.datum_id.label("datum_id"),
models.Dataset.name.label("dataset_name"),
filters=prediction_filter,
label_source=models.Prediction,
).cte()

assert (
_compute_roc_auc(
db,
prediction_filter=prediction_filter,
groundtruth_filter=groundtruth_filter,
groundtruths=groundtruths,
predictions=predictions,
grouper_key="animal",
grouper_mappings=grouper_mappings,
)
Expand Down Expand Up @@ -925,10 +956,26 @@ def test_compute_roc_auc_with_label_map(
evaluation_type=enums.TaskType.CLASSIFICATION,
)

groundtruths = generate_select(
models.GroundTruth,
models.Annotation.datum_id.label("datum_id"),
models.Dataset.name.label("dataset_name"),
filters=groundtruth_filter,
label_source=models.GroundTruth,
).cte()

predictions = generate_select(
models.Prediction,
models.Annotation.datum_id.label("datum_id"),
models.Dataset.name.label("dataset_name"),
filters=prediction_filter,
label_source=models.Prediction,
).cte()

roc_auc = _compute_roc_auc(
db=db,
prediction_filter=prediction_filter,
groundtruth_filter=groundtruth_filter,
groundtruths=groundtruths,
predictions=predictions,
grouper_key="animal",
grouper_mappings=grouper_mappings,
)
Expand Down
Loading

0 comments on commit 19cd451

Please sign in to comment.