Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Classification Performance #637

Merged
merged 42 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
2fb63f8
fixed base performance
czaloom Jun 27, 2024
7501c28
classification fixes
czaloom Jun 27, 2024
284c16b
added results
czaloom Jun 27, 2024
e70bb62
remove benchmarks
czaloom Jun 27, 2024
5846b99
pr curve performance improvements
czaloom Jun 27, 2024
07fb5a5
pr curve performance improvements
czaloom Jun 27, 2024
f7048c1
perf improvements
czaloom Jul 1, 2024
0df6437
fixed post timeouts
czaloom Jul 1, 2024
b5b6ffb
added timeout controls
czaloom Jul 1, 2024
5d3aa2d
added vacuum analyze to dataset, model finalization
czaloom Jul 1, 2024
0734f30
fixed for python 3_8
czaloom Jul 1, 2024
57d35f2
fixed args
czaloom Jul 1, 2024
f2ff1a0
fixed lack of db error in testing
czaloom Jul 1, 2024
d444d12
fixed test
czaloom Jul 1, 2024
b090cc0
merged client timeout pr
czaloom Jul 1, 2024
7506484
merged vacuum analyze pr
czaloom Jul 1, 2024
0431989
passing precommit
czaloom Jul 1, 2024
f192d9e
removing commented code
czaloom Jul 1, 2024
50282eb
fixed validate labels
czaloom Jul 1, 2024
09b3476
validate matching label keys is more straightforward
czaloom Jul 1, 2024
22f4d93
Merge branch 'czaloom-644-fix-validate_matching_label_keys' into czal…
czaloom Jul 1, 2024
20eb80b
passing python integration tests
czaloom Jul 1, 2024
b2fa72c
remove comments
czaloom Jul 1, 2024
ba9810c
updated analysis.py
czaloom Jul 1, 2024
bf88ea7
Update test_classification.py
czaloom Jul 1, 2024
5dfd9e7
Delete examples/benchmarks/analysis.py
czaloom Jul 1, 2024
a9b2036
Delete examples/benchmarks/results.json
czaloom Jul 1, 2024
7bb793d
Delete examples/benchmarks/pr-curve-oom-data.json
czaloom Jul 1, 2024
fcedd4b
Update test_classification.py
czaloom Jul 1, 2024
8299f44
added docstring
czaloom Jul 2, 2024
83d0bdf
change default to 10 for creating gts and pds
czaloom Jul 2, 2024
791f1d3
revert
czaloom Jul 2, 2024
40e75c9
Merge branch 'czaloom-add-vacuum-analyze' into czaloom-patch-581-perf…
czaloom Jul 2, 2024
2f17e3f
Merge branch 'czaloom-644-fix-validate_matching_label_keys' into czal…
czaloom Jul 2, 2024
d7dabf1
Merge branch 'czaloom-639-bug-bulk-add-error' into czaloom-patch-581-…
czaloom Jul 2, 2024
45f2e58
Merge branch 'main' into czaloom-patch-581-performance-issues
czaloom Jul 2, 2024
17c35ce
reverted test
czaloom Jul 3, 2024
fa84825
unrelated test failing due to list ordering in db
czaloom Jul 3, 2024
7ece6c5
fix typo
czaloom Jul 3, 2024
bf3c1c5
merge main
czaloom Jul 3, 2024
967e633
merged main
czaloom Jul 3, 2024
5501b26
reverted integration tests
czaloom Jul 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions api/tests/functional-tests/backend/metrics/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,12 @@ def test_compute_confusion_matrix_at_grouper_key(
models.Annotation.datum_id.label("datum_id"),
filters=gFilter,
label_source=models.GroundTruth,
).alias()
).cte()
predictions = generate_select(
models.Prediction,
filters=pFilter,
label_source=models.Prediction,
).alias()
).cte()

cm = _compute_confusion_matrix_at_grouper_key(
db=db,
Expand Down Expand Up @@ -294,12 +294,12 @@ def test_compute_confusion_matrix_at_grouper_key(
models.Annotation.datum_id.label("datum_id"),
filters=gFilter,
label_source=models.GroundTruth,
).alias()
).cte()
predictions = generate_select(
models.Prediction,
filters=pFilter,
label_source=models.Prediction,
).alias()
).cte()

cm = _compute_confusion_matrix_at_grouper_key(
db=db,
Expand Down Expand Up @@ -445,12 +445,12 @@ def test_compute_confusion_matrix_at_grouper_key_and_filter(
models.Annotation.datum_id.label("datum_id"),
filters=gFilter,
label_source=models.GroundTruth,
).alias()
).cte()
predictions = generate_select(
models.Prediction,
filters=pFilter,
label_source=models.Prediction,
).alias()
).cte()

cm = _compute_confusion_matrix_at_grouper_key(
db,
Expand Down Expand Up @@ -595,12 +595,12 @@ def test_compute_confusion_matrix_at_grouper_key_using_label_map(
models.Annotation.datum_id.label("datum_id"),
filters=gFilter,
label_source=models.GroundTruth,
).alias()
).cte()
predictions = generate_select(
models.Prediction,
filters=pFilter,
label_source=models.Prediction,
).alias()
).cte()

cm = _compute_confusion_matrix_at_grouper_key(
db,
Expand Down Expand Up @@ -1256,32 +1256,44 @@ def test__compute_curves(
models.Dataset.name.label("dataset_name"),
filters=gFilter,
label_source=models.GroundTruth,
).alias()
).cte()
predictions = generate_select(
models.Prediction,
models.Annotation.datum_id.label("datum_id"),
models.Dataset.name.label("dataset_name"),
filters=pFilter,
label_source=models.Prediction,
).alias()
).cte()

# calculate the number of unique datums
# used to determine the number of true negatives

gt_datums = generate_query(
models.Datum.id,
models.Dataset.name,
models.Datum.uid,
db=db,
filters=groundtruth_filter,
label_source=models.GroundTruth,
).all()
pd_datums = generate_query(
models.Datum.id,
models.Dataset.name,
models.Datum.uid,
db=db,
filters=prediction_filter,
label_source=models.Prediction,
).all()
unique_datums = set(pd_datums + gt_datums)
unique_datums = {
datum_id: (dataset_name, datum_uid)
for datum_id, dataset_name, datum_uid in gt_datums
}
unique_datums.update(
{
datum_id: (dataset_name, datum_uid)
for datum_id, dataset_name, datum_uid in pd_datums
}
)

curves = _compute_curves(
db=db,
Expand Down
43 changes: 26 additions & 17 deletions api/valor_api/backend/core/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,30 @@
def _check_if_datum_has_prediction(
db: Session, datum: schemas.Datum, model_name: str, dataset_name: str
) -> None:
"""Checks to see if datum has existing annotations."""
if db.query(
select(models.Annotation.id)
.join(models.Model)
.join(models.Datum)
.join(models.Dataset)
.where(
.select_from(models.Annotation)
.join(
models.Model,
and_(
models.Dataset.name == dataset_name,
models.Datum.dataset_id == models.Dataset.id,
models.Datum.uid == datum.uid,
models.Model.id == models.Annotation.model_id,
models.Model.name == model_name,
models.Annotation.datum_id == models.Datum.id,
models.Annotation.model_id == models.Model.id,
)
),
)
.join(
models.Datum,
and_(
models.Datum.id == models.Annotation.datum_id,
models.Datum.uid == datum.uid,
),
)
.join(
models.Dataset,
and_(
models.Dataset.id == models.Datum.dataset_id,
models.Dataset.name == dataset_name,
),
)
.subquery()
).all():
Expand Down Expand Up @@ -119,15 +129,14 @@ def create_predictions(
for i, annotation in enumerate(prediction.annotations):
for label in annotation.labels:
prediction_mappings.append(
{
"annotation_id": annotation_ids_per_prediction[i],
"label_id": label_dict[(label.key, label.value)],
"score": label.score,
}
models.Prediction(
annotation_id=annotation_ids_per_prediction[i],
label_id=label_dict[(label.key, label.value)],
score=label.score,
)
)

try:
db.bulk_insert_mappings(models.Prediction, prediction_mappings)
db.add_all(prediction_mappings)
db.commit()
except IntegrityError as e:
db.rollback()
Expand Down
Loading
Loading