Skip to content

Commit

Permalink
Metric Ingestion Patch (#668)
Browse files Browse the repository at this point in the history
  • Loading branch information
czaloom authored Jul 11, 2024
1 parent 19cd451 commit 9b74e1f
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 141 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def test_compute_confusion_matrix_at_grouper_key(
)

grouper_mappings = create_grouper_mappings(
db=db,
labels=labels,
label_map=None,
evaluation_type=enums.TaskType.CLASSIFICATION,
Expand Down Expand Up @@ -405,6 +406,7 @@ def test_compute_confusion_matrix_at_grouper_key_and_filter(
)

grouper_mappings = create_grouper_mappings(
db=db,
labels=labels,
label_map=None,
evaluation_type=enums.TaskType.CLASSIFICATION,
Expand Down Expand Up @@ -555,6 +557,7 @@ def test_compute_confusion_matrix_at_grouper_key_using_label_map(
)

grouper_mappings = create_grouper_mappings(
db=db,
labels=labels,
label_map=label_map,
evaluation_type=enums.TaskType.CLASSIFICATION,
Expand Down Expand Up @@ -734,6 +737,7 @@ def test_compute_roc_auc(
)

grouper_mappings = create_grouper_mappings(
db=db,
labels=labels,
label_map=None,
evaluation_type=enums.TaskType.CLASSIFICATION,
Expand Down Expand Up @@ -843,6 +847,7 @@ def test_compute_roc_auc_groupby_metadata(
)

grouper_mappings = create_grouper_mappings(
db=db,
labels=labels,
label_map=None,
evaluation_type=enums.TaskType.CLASSIFICATION,
Expand Down Expand Up @@ -951,6 +956,7 @@ def test_compute_roc_auc_with_label_map(
)

grouper_mappings = create_grouper_mappings(
db=db,
labels=labels,
label_map=label_map,
evaluation_type=enums.TaskType.CLASSIFICATION,
Expand Down Expand Up @@ -1262,6 +1268,7 @@ def test__compute_curves(
)

grouper_mappings = create_grouper_mappings(
db=db,
labels=labels,
label_map=None,
evaluation_type=enums.TaskType.CLASSIFICATION,
Expand Down
2 changes: 0 additions & 2 deletions api/valor_api/backend/core/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,12 +243,10 @@ def _validate_evaluation_filter(

# generate filters
groundtruth_filter, prediction_filter = prepare_filter_for_evaluation(
db=db,
filters=filters,
dataset_names=evaluation.dataset_names,
model_name=evaluation.model_name,
task_type=parameters.task_type,
label_map=parameters.label_map,
)

datasets = (
Expand Down
30 changes: 6 additions & 24 deletions api/valor_api/backend/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
from valor_api import enums, schemas
from valor_api.backend import core, models
from valor_api.backend.metrics.metric_utils import (
commit_results,
create_grouper_mappings,
create_metric_mappings,
get_or_create_row,
log_evaluation_duration,
log_evaluation_item_counts,
prepare_filter_for_evaluation,
Expand Down Expand Up @@ -1035,6 +1034,7 @@ def _compute_clf_metrics(
)

grouper_mappings = create_grouper_mappings(
db=db,
labels=labels,
label_map=label_map,
evaluation_type=enums.TaskType.CLASSIFICATION,
Expand Down Expand Up @@ -1089,12 +1089,10 @@ def compute_clf_metrics(
# unpack filters and params
parameters = schemas.EvaluationParameters(**evaluation.parameters)
groundtruth_filter, prediction_filter = prepare_filter_for_evaluation(
db=db,
filters=schemas.Filter(**evaluation.filters),
dataset_names=evaluation.dataset_names,
model_name=evaluation.model_name,
task_type=parameters.task_type,
label_map=parameters.label_map,
)

log_evaluation_item_counts(
Expand All @@ -1120,36 +1118,20 @@ def compute_clf_metrics(
metrics_to_return=parameters.metrics_to_return,
)

confusion_matrices_mappings = create_metric_mappings(
# add confusion matrices to database
commit_results(
db=db,
metrics=confusion_matrices,
evaluation_id=evaluation.id,
)

for mapping in confusion_matrices_mappings:
get_or_create_row(
db,
models.ConfusionMatrix,
mapping,
)

metric_mappings = create_metric_mappings(
# add metrics to database
commit_results(
db=db,
metrics=metrics,
evaluation_id=evaluation.id,
)

for mapping in metric_mappings:
# ignore value since the other columns are unique identifiers
# and have empirically noticed value can slightly change due to floating
# point errors
get_or_create_row(
db,
models.Metric,
mapping,
columns_to_ignore=["value"],
)

log_evaluation_duration(
evaluation=evaluation,
db=db,
Expand Down
20 changes: 5 additions & 15 deletions api/valor_api/backend/metrics/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
from valor_api import enums, schemas
from valor_api.backend import core, models
from valor_api.backend.metrics.metric_utils import (
commit_results,
create_grouper_mappings,
create_metric_mappings,
get_or_create_row,
log_evaluation_duration,
log_evaluation_item_counts,
prepare_filter_for_evaluation,
Expand Down Expand Up @@ -739,6 +738,7 @@ def _annotation_type_to_geojson(
)

grouper_mappings = create_grouper_mappings(
db=db,
labels=labels,
label_map=parameters.label_map,
evaluation_type=enums.TaskType.OBJECT_DETECTION,
Expand Down Expand Up @@ -1108,6 +1108,7 @@ def _annotation_type_to_geojson(
)

grouper_mappings = create_grouper_mappings(
db=db,
labels=labels,
label_map=parameters.label_map,
evaluation_type=enums.TaskType.OBJECT_DETECTION,
Expand Down Expand Up @@ -1641,12 +1642,10 @@ def compute_detection_metrics(*_, db: Session, evaluation_id: int):
# unpack filters and params
parameters = schemas.EvaluationParameters(**evaluation.parameters)
groundtruth_filter, prediction_filter = prepare_filter_for_evaluation(
db=db,
filters=schemas.Filter(**evaluation.filters),
dataset_names=evaluation.dataset_names,
model_name=evaluation.model_name,
task_type=parameters.task_type,
label_map=parameters.label_map,
)

log_evaluation_item_counts(
Expand Down Expand Up @@ -1755,22 +1754,13 @@ def compute_detection_metrics(*_, db: Session, evaluation_id: int):
target_type=target_type,
)

metric_mappings = create_metric_mappings(
# add metrics to database
commit_results(
db=db,
metrics=metrics,
evaluation_id=evaluation_id,
)

for mapping in metric_mappings:
# ignore value since the other columns are unique identifiers
# and have empircally noticed value can slightly change due to floating
# point errors

get_or_create_row(
db, models.Metric, mapping, columns_to_ignore=["value"]
)
db.commit()

log_evaluation_duration(
evaluation=evaluation,
db=db,
Expand Down
Loading

0 comments on commit 9b74e1f

Please sign in to comment.