Skip to content

Commit

Permalink
Refactor object detection in valor_core to improve speed (#724)
Browse files Browse the repository at this point in the history
  • Loading branch information
ntlind authored Aug 28, 2024
1 parent 64845fe commit dba5246
Show file tree
Hide file tree
Showing 7 changed files with 446 additions and 344 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ repos:
"numpy",
"pandas>=2.2.2",
"pandas-stubs", # fixes pyright issues with pandas
"pandas[performance]",
"pytest",
"python-dotenv",
"SQLAlchemy>=2.0",
Expand Down
1 change: 1 addition & 0 deletions core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies = [
"importlib_metadata; python_version < '3.8'",
"pandas>=2.2.2",
"pandas-stubs",
"pandas[performance]",
"tqdm",
"requests",
"shapely"
Expand Down
237 changes: 0 additions & 237 deletions core/tests/conftest_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3046,243 +3046,6 @@ def mammal_label_map():
}


@pytest.fixture
def evaluate_classification_with_label_maps_expected():

cat_expected_metrics = [
{"type": "Accuracy", "parameters": {"label_key": "k3"}, "value": 0.0},
{"type": "ROCAUC", "parameters": {"label_key": "k3"}, "value": 1.0},
{
"type": "Precision",
"value": 0.0,
"label": {"key": "k3", "value": "v1"},
},
{
"type": "Recall",
"value": 0.0,
"label": {"key": "k3", "value": "v1"},
},
{"type": "F1", "value": 0.0, "label": {"key": "k3", "value": "v1"}},
{
"type": "Precision",
"value": 0.0,
"label": {"key": "k3", "value": "v3"},
},
{
"type": "Recall",
"value": 0.0,
"label": {"key": "k3", "value": "v3"},
},
{"type": "F1", "value": 0.0, "label": {"key": "k3", "value": "v3"}},
{"type": "Accuracy", "parameters": {"label_key": "k5"}, "value": 0.0},
{"type": "ROCAUC", "parameters": {"label_key": "k5"}, "value": 1.0},
{
"type": "Precision",
"value": 0.0,
"label": {"key": "k5", "value": "v5"},
},
{
"type": "Recall",
"value": 0.0,
"label": {"key": "k5", "value": "v5"},
},
{"type": "F1", "value": 0.0, "label": {"key": "k5", "value": "v5"}},
{
"type": "Precision",
"value": 0.0,
"label": {"key": "k5", "value": "v1"},
},
{
"type": "Recall",
"value": 0.0,
"label": {"key": "k5", "value": "v1"},
},
{"type": "F1", "value": 0.0, "label": {"key": "k5", "value": "v1"}},
{
"type": "Accuracy",
"parameters": {"label_key": "special_class"},
"value": 1.0,
},
{
"type": "ROCAUC",
"parameters": {"label_key": "special_class"},
"value": 1.0,
},
{
"type": "Precision",
"value": 1.0,
"label": {"key": "special_class", "value": "cat_type1"},
},
{
"type": "Recall",
"value": 1.0,
"label": {"key": "special_class", "value": "cat_type1"},
},
{
"type": "F1",
"value": 1.0,
"label": {"key": "special_class", "value": "cat_type1"},
},
{"type": "Accuracy", "parameters": {"label_key": "k4"}, "value": 0.5},
{
"type": "ROCAUC",
"parameters": {
"label_key": "k4",
},
"value": 1.0,
},
{
"type": "Precision",
"value": -1.0,
"label": {"key": "k4", "value": "v5"},
},
{
"type": "Recall",
"value": -1.0,
"label": {"key": "k4", "value": "v5"},
},
{"type": "F1", "value": -1.0, "label": {"key": "k4", "value": "v5"}},
{
"type": "Precision",
"value": -1.0,
"label": {"key": "k4", "value": "v1"},
},
{
"type": "Recall",
"value": -1.0,
"label": {"key": "k4", "value": "v1"},
},
{"type": "F1", "value": -1.0, "label": {"key": "k4", "value": "v1"}},
{
"type": "Precision",
"value": 1.0,
"label": {"key": "k4", "value": "v4"},
},
{
"type": "Recall",
"value": 0.5,
"label": {"key": "k4", "value": "v4"},
},
{
"type": "F1",
"value": 0.6666666666666666,
"label": {"key": "k4", "value": "v4"},
},
{
"type": "Precision",
"value": 0.0,
"label": {"key": "k4", "value": "v8"},
},
{
"type": "Recall",
"value": 0.0,
"label": {"key": "k4", "value": "v8"},
},
{"type": "F1", "value": 0.0, "label": {"key": "k4", "value": "v8"}},
]

cat_expected_cm = [
{
"label_key": "special_class",
"entries": [
{
"prediction": "cat_type1",
"groundtruth": "cat_type1",
"count": 3,
}
],
}
# other label keys not included for testing purposes
]

pr_expected_values = {
# k3
(0, "k3", "v1", "0.1", "fp"): 1,
(0, "k3", "v1", "0.1", "tn"): 2,
(0, "k3", "v3", "0.1", "fn"): 1,
(0, "k3", "v3", "0.1", "tn"): 2,
(0, "k3", "v3", "0.1", "accuracy"): 2 / 3,
(0, "k3", "v3", "0.1", "precision"): -1,
(0, "k3", "v3", "0.1", "recall"): 0,
(0, "k3", "v3", "0.1", "f1_score"): -1,
# k4
(1, "k4", "v1", "0.1", "fp"): 1,
(1, "k4", "v1", "0.1", "tn"): 2,
(1, "k4", "v4", "0.1", "fn"): 1,
(1, "k4", "v4", "0.1", "tn"): 1,
(1, "k4", "v4", "0.1", "tp"): 1,
(1, "k4", "v4", "0.9", "tp"): 0,
(1, "k4", "v4", "0.9", "tn"): 1,
(1, "k4", "v4", "0.9", "fn"): 2,
(1, "k4", "v5", "0.1", "fp"): 1,
(1, "k4", "v5", "0.1", "tn"): 2,
(1, "k4", "v5", "0.3", "fp"): 0,
(1, "k4", "v5", "0.3", "tn"): 3,
(1, "k4", "v8", "0.1", "tn"): 2,
(1, "k4", "v8", "0.6", "fp"): 0,
(1, "k4", "v8", "0.6", "tn"): 3,
# k5
(2, "k5", "v1", "0.1", "fp"): 1,
(2, "k5", "v1", "0.1", "tn"): 2,
(2, "k5", "v5", "0.1", "fn"): 1,
(
2,
"k5",
"v5",
"0.1",
"tn",
): 2,
(2, "k5", "v1", "0.1", "accuracy"): 2 / 3,
(2, "k5", "v1", "0.1", "precision"): 0,
(2, "k5", "v1", "0.1", "recall"): -1,
(2, "k5", "v1", "0.1", "f1_score"): -1,
# special_class
(3, "special_class", "cat_type1", "0.1", "tp"): 3,
(3, "special_class", "cat_type1", "0.1", "tn"): 0,
(3, "special_class", "cat_type1", "0.95", "tp"): 3,
}

detailed_pr_expected_answers = {
# k3
(0, "v1", "0.1", "tp"): {"all": 0, "total": 0},
(0, "v1", "0.1", "fp"): {
"misclassifications": 1,
"total": 1,
},
(0, "v1", "0.1", "tn"): {"all": 2, "total": 2},
(0, "v1", "0.1", "fn"): {
"no_predictions": 0,
"misclassifications": 0,
"total": 0,
},
# k4
(1, "v1", "0.1", "tp"): {"all": 0, "total": 0},
(1, "v1", "0.1", "fp"): {
"misclassifications": 1,
"total": 1,
},
(1, "v1", "0.1", "tn"): {"all": 2, "total": 2},
(1, "v1", "0.1", "fn"): {
"no_predictions": 0,
"misclassifications": 0,
"total": 0,
},
(1, "v4", "0.1", "fn"): {
"no_predictions": 0,
"misclassifications": 1,
"total": 1,
},
(1, "v8", "0.1", "tn"): {"all": 2, "total": 2},
}

return (
cat_expected_metrics,
cat_expected_cm,
pr_expected_values,
detailed_pr_expected_answers,
)


@pytest.fixture
def multiclass_pr_curve_groundtruths():
return [
Expand Down
Loading

0 comments on commit dba5246

Please sign in to comment.