Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Classification Performance #637

Merged
merged 42 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
2fb63f8
fixed base performance
czaloom Jun 27, 2024
7501c28
classification fixes
czaloom Jun 27, 2024
284c16b
added results
czaloom Jun 27, 2024
e70bb62
remove benmarks
czaloom Jun 27, 2024
5846b99
pr curve performance improvements
czaloom Jun 27, 2024
07fb5a5
pr curve performance improvements
czaloom Jun 27, 2024
f7048c1
perf improvements
czaloom Jul 1, 2024
0df6437
fixed post timeouts
czaloom Jul 1, 2024
b5b6ffb
added timeout controls
czaloom Jul 1, 2024
5d3aa2d
added vacuum analyze to dataset, model finalization
czaloom Jul 1, 2024
0734f30
fixed for python 3_8
czaloom Jul 1, 2024
57d35f2
fixed args
czaloom Jul 1, 2024
f2ff1a0
fixed lack of db error in testing
czaloom Jul 1, 2024
d444d12
fixed test
czaloom Jul 1, 2024
b090cc0
merged client timeout pr
czaloom Jul 1, 2024
7506484
merged vacuum analyze pr
czaloom Jul 1, 2024
0431989
passing precommit
czaloom Jul 1, 2024
f192d9e
removing commented code
czaloom Jul 1, 2024
50282eb
fixed validate labels
czaloom Jul 1, 2024
09b3476
validate matching label keys is more straightforward
czaloom Jul 1, 2024
22f4d93
Merge branch 'czaloom-644-fix-validate_matching_label_keys' into czal…
czaloom Jul 1, 2024
20eb80b
passing python integration tests
czaloom Jul 1, 2024
b2fa72c
remove comments
czaloom Jul 1, 2024
ba9810c
updated analysis.py
czaloom Jul 1, 2024
bf88ea7
Update test_classification.py
czaloom Jul 1, 2024
5dfd9e7
Delete examples/benchmarks/analysis.py
czaloom Jul 1, 2024
a9b2036
Delete examples/benchmarks/results.json
czaloom Jul 1, 2024
7bb793d
Delete examples/benchmarks/pr-curve-oom-data.json
czaloom Jul 1, 2024
fcedd4b
Update test_classification.py
czaloom Jul 1, 2024
8299f44
added docstring
czaloom Jul 2, 2024
83d0bdf
change default to 10 for creating gts and pds
czaloom Jul 2, 2024
791f1d3
revert
czaloom Jul 2, 2024
40e75c9
Merge branch 'czaloom-add-vacuum-analyze' into czaloom-patch-581-perf…
czaloom Jul 2, 2024
2f17e3f
Merge branch 'czaloom-644-fix-validate_matching_label_keys' into czal…
czaloom Jul 2, 2024
d7dabf1
Merge branch 'czaloom-639-bug-bulk-add-error' into czaloom-patch-581-…
czaloom Jul 2, 2024
45f2e58
Merge branch 'main' into czaloom-patch-581-performance-issues
czaloom Jul 2, 2024
17c35ce
reverted test
czaloom Jul 3, 2024
fa84825
unrelated test failing due to list ordering in db
czaloom Jul 3, 2024
7ece6c5
fix typo
czaloom Jul 3, 2024
bf3c1c5
merge main
czaloom Jul 3, 2024
967e633
merged main
czaloom Jul 3, 2024
5501b26
reverted integration tests
czaloom Jul 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions api/tests/functional-tests/backend/metrics/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,12 @@ def test_compute_confusion_matrix_at_grouper_key(
models.Annotation.datum_id.label("datum_id"),
filters=gFilter,
label_source=models.GroundTruth,
).alias()
).cte()
predictions = generate_select(
models.Prediction,
filters=pFilter,
label_source=models.Prediction,
).alias()
).cte()

cm = _compute_confusion_matrix_at_grouper_key(
db=db,
Expand Down Expand Up @@ -294,12 +294,12 @@ def test_compute_confusion_matrix_at_grouper_key(
models.Annotation.datum_id.label("datum_id"),
filters=gFilter,
label_source=models.GroundTruth,
).alias()
).cte()
predictions = generate_select(
models.Prediction,
filters=pFilter,
label_source=models.Prediction,
).alias()
).cte()

cm = _compute_confusion_matrix_at_grouper_key(
db=db,
Expand Down Expand Up @@ -445,12 +445,12 @@ def test_compute_confusion_matrix_at_grouper_key_and_filter(
models.Annotation.datum_id.label("datum_id"),
filters=gFilter,
label_source=models.GroundTruth,
).alias()
).cte()
predictions = generate_select(
models.Prediction,
filters=pFilter,
label_source=models.Prediction,
).alias()
).cte()

cm = _compute_confusion_matrix_at_grouper_key(
db,
Expand Down Expand Up @@ -595,12 +595,12 @@ def test_compute_confusion_matrix_at_grouper_key_using_label_map(
models.Annotation.datum_id.label("datum_id"),
filters=gFilter,
label_source=models.GroundTruth,
).alias()
).cte()
predictions = generate_select(
models.Prediction,
filters=pFilter,
label_source=models.Prediction,
).alias()
).cte()

cm = _compute_confusion_matrix_at_grouper_key(
db,
Expand Down Expand Up @@ -1256,32 +1256,44 @@ def test__compute_curves(
models.Dataset.name.label("dataset_name"),
filters=gFilter,
label_source=models.GroundTruth,
).alias()
).cte()
predictions = generate_select(
models.Prediction,
models.Annotation.datum_id.label("datum_id"),
models.Dataset.name.label("dataset_name"),
filters=pFilter,
label_source=models.Prediction,
).alias()
).cte()

# calculate the number of unique datums
# used to determine the number of true negatives

gt_datums = generate_query(
models.Datum.id,
models.Dataset.name,
models.Datum.uid,
db=db,
filters=groundtruth_filter,
label_source=models.GroundTruth,
).all()
pd_datums = generate_query(
models.Datum.id,
models.Dataset.name,
models.Datum.uid,
db=db,
filters=prediction_filter,
label_source=models.Prediction,
).all()
unique_datums = set(pd_datums + gt_datums)
unique_datums = {
datum_id: (dataset_name, datum_uid)
for datum_id, dataset_name, datum_uid in gt_datums
}
unique_datums.update(
{
datum_id: (dataset_name, datum_uid)
for datum_id, dataset_name, datum_uid in pd_datums
}
)

curves = _compute_curves(
db=db,
Expand Down Expand Up @@ -1368,8 +1380,8 @@ def test__compute_curves(
},
("dog", 0.05, "tn"): {"all": 1, "total": 1},
("dog", 0.8, "fn"): {
"missed_detections": 1,
"misclassifications": 1,
"missed_detections": 0,
Copy link
Contributor

@ntlind ntlind Jul 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A prediction having a score less than the threshold is still a valid prediction though

what is the point of the score threshold in that case?

the score threshold is meant to mean "only consider predictions with a score greater than x to be valid predictions"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The point of a score threshold is to determine whether the prediction is positive vs negative.

Whether that prediction is correct determines its truth (True, False).

Combine these and you get TP, FP, FN and TN.

The variation of missing_detection doesnt really map well to the classification task (as compared to the obj det task) as we enforce the existence of predictions to groundtruths at ingestion time. (See validate_matching_label_keys)

This logic also applies to hallucination for FP, which, if you look at that test never gets a value counted.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reached out to Matt and I think this is a definition issue. Missing detection doesnt make sense for classification. The condition of FN that you are referring to fits something closer to a "no winner" condition.

Matt suggested "No prediction" and im wondering if "Null Prediction" would make more sense.

How does all this sound to you?

"misclassifications": 2,
"total": 2,
},
# cat
Expand Down
14 changes: 12 additions & 2 deletions api/tests/functional-tests/crud/test_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,8 @@ def test_get_dataset_summary(
enums.TaskType.CLASSIFICATION,
enums.TaskType.EMPTY,
}
assert summary.datum_metadata == [

expected_datum_metadata = [
{
"width": 32,
"height": 80,
Expand All @@ -339,10 +340,19 @@ def test_get_dataset_summary(
"height": 100,
},
]
assert summary.annotation_metadata == [
for item in summary.datum_metadata:
assert item in expected_datum_metadata
for item in expected_datum_metadata:
assert item in summary.datum_metadata

expected_annotation_metadata = [
{"int_key": 1},
{
"string_key": "string_val",
"int_key": 1,
},
]
for item in summary.annotation_metadata:
assert item in expected_annotation_metadata
for item in expected_annotation_metadata:
assert item in summary.annotation_metadata
4 changes: 2 additions & 2 deletions api/valor_api/backend/core/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,9 @@ def _validate_evaluation_filter(
if parameters.task_type == enums.TaskType.CLASSIFICATION:
core.validate_matching_label_keys(
db=db,
dataset_names=evaluation.dataset_names,
model_name=evaluation.model_name,
label_map=parameters.label_map,
groundtruth_filter=groundtruth_filter,
prediction_filter=predictions_filter,
)


Expand Down
110 changes: 67 additions & 43 deletions api/valor_api/backend/core/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@

from valor_api import api_utils, schemas
from valor_api.backend import models
from valor_api.backend.query import generate_query, generate_select
from valor_api.backend.query import generate_query
from valor_api.backend.query.types import TableTypeAlias

LabelMapType = list[list[list[str]]]


def validate_matching_label_keys(
db: Session,
dataset_names: list[str],
model_name: str,
label_map: LabelMapType | None,
prediction_filter: schemas.Filter,
groundtruth_filter: schemas.Filter,
) -> None:
"""
Validates that every datum has the same set of label keys for both ground truths and predictions. This check is only needed for classification tasks.
Expand All @@ -26,71 +26,95 @@ def validate_matching_label_keys(
----------
db : Session
The database Session to query against.
prediction_filter : schemas.Filter
The filter to be used to query predictions.
groundtruth_filter : schemas.Filter
The filter to be used to query groundtruths.
dataset_names : list[str]
The list of required datasets by name.
model_name : str
The required model by name.
label_map: LabelMapType, optional
Optional mapping of individual labels to a grouper label. Useful when you need to evaluate performance using labels that differ across datasets and models.


Raises
-------
ValueError
If the distinct ground truth label keys don't match the distinct prediction label keys for any datum.
"""

gts = generate_select(
models.Annotation.datum_id.label("datum_id"),
models.Label.key.label("label_key"),
models.Label.value.label("label_value"),
filters=groundtruth_filter,
label_source=models.GroundTruth,
).alias()

gt_label_keys_by_datum = (
gt_labels_by_datum = (
select(
gts.c.datum_id,
func.array_agg(gts.c.label_key + ", " + gts.c.label_value).label(
models.Datum.id.label("datum_id"),
func.array_agg(models.Label.key + ", " + models.Label.value).label(
"gt_labels"
),
)
.select_from(gts)
.group_by(gts.c.datum_id)
.select_from(models.Datum)
.join(
models.Dataset,
and_(
models.Dataset.id == models.Datum.dataset_id,
models.Dataset.name.in_(dataset_names),
),
)
.join(
models.Annotation,
and_(
models.Annotation.datum_id == models.Datum.id,
models.Annotation.model_id.is_(None),
),
)
.join(
models.GroundTruth,
models.GroundTruth.annotation_id == models.Annotation.id,
)
.join(models.Label, models.Label.id == models.GroundTruth.label_id)
.group_by(models.Datum.id)
.subquery()
)

preds = generate_select(
models.Annotation.datum_id.label("datum_id"),
models.Label.key.label("label_key"),
models.Label.value.label("label_value"),
filters=prediction_filter,
label_source=models.Prediction,
).alias()

preds_label_keys_by_datum = (
pred_labels_by_datum = (
select(
preds.c.datum_id,
func.array_agg(
preds.c.label_key + ", " + preds.c.label_value
).label("pred_labels"),
models.Datum.id.label("datum_id"),
func.array_agg(models.Label.key + ", " + models.Label.value).label(
"pred_labels"
),
)
.select_from(models.Datum)
.join(
models.Dataset,
and_(
models.Dataset.id == models.Datum.dataset_id,
models.Dataset.name.in_(dataset_names),
),
)
.join(
models.Annotation,
models.Annotation.datum_id == models.Datum.id,
)
.join(
models.Model,
and_(
models.Model.id == models.Annotation.model_id,
models.Model.name == model_name,
),
)
.join(
models.Prediction,
models.Prediction.annotation_id == models.Annotation.id,
)
.select_from(preds)
.group_by(preds.c.datum_id)
.join(models.Label, models.Label.id == models.Prediction.label_id)
.group_by(models.Datum.id)
.subquery()
)

joined = (
select(
preds_label_keys_by_datum.c.datum_id,
preds_label_keys_by_datum.c.pred_labels,
gt_label_keys_by_datum.c.gt_labels,
pred_labels_by_datum.c.datum_id,
pred_labels_by_datum.c.pred_labels,
gt_labels_by_datum.c.gt_labels,
)
.select_from(preds_label_keys_by_datum)
.select_from(gt_labels_by_datum)
.join(
gt_label_keys_by_datum,
gt_label_keys_by_datum.c.datum_id
== preds_label_keys_by_datum.c.datum_id,
pred_labels_by_datum,
pred_labels_by_datum.c.datum_id == gt_labels_by_datum.c.datum_id,
)
.subquery()
)
Expand Down
43 changes: 26 additions & 17 deletions api/valor_api/backend/core/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,30 @@
def _check_if_datum_has_prediction(
db: Session, datum: schemas.Datum, model_name: str, dataset_name: str
) -> None:
"""Checks to see if datum has existing annotations."""
if db.query(
select(models.Annotation.id)
.join(models.Model)
.join(models.Datum)
.join(models.Dataset)
.where(
.select_from(models.Annotation)
.join(
models.Model,
and_(
models.Dataset.name == dataset_name,
models.Datum.dataset_id == models.Dataset.id,
models.Datum.uid == datum.uid,
models.Model.id == models.Annotation.model_id,
models.Model.name == model_name,
models.Annotation.datum_id == models.Datum.id,
models.Annotation.model_id == models.Model.id,
)
),
)
.join(
models.Datum,
and_(
models.Datum.id == models.Annotation.datum_id,
models.Datum.uid == datum.uid,
),
)
.join(
models.Dataset,
and_(
models.Dataset.id == models.Datum.dataset_id,
models.Dataset.name == dataset_name,
),
)
.subquery()
).all():
Expand Down Expand Up @@ -119,15 +129,14 @@ def create_predictions(
for i, annotation in enumerate(prediction.annotations):
for label in annotation.labels:
prediction_mappings.append(
{
"annotation_id": annotation_ids_per_prediction[i],
"label_id": label_dict[(label.key, label.value)],
"score": label.score,
}
models.Prediction(
annotation_id=annotation_ids_per_prediction[i],
label_id=label_dict[(label.key, label.value)],
score=label.score,
)
)

try:
db.bulk_insert_mappings(models.Prediction, prediction_mappings)
db.add_all(prediction_mappings)
db.commit()
except IntegrityError as e:
db.rollback()
Expand Down
Loading
Loading