Update PR Curve definitions for Classification (#652)
czaloom authored Jul 3, 2024
1 parent f539c0d commit da65b83
Showing 7 changed files with 73 additions and 99 deletions.
@@ -1361,39 +1361,36 @@ def test__compute_curves(
# bird
("bird", 0.05, "tp"): {"all": 3, "total": 3},
("bird", 0.05, "fp"): {
"hallucinations": 0,
"misclassifications": 1,
"total": 1,
},
("bird", 0.05, "tn"): {"all": 2, "total": 2},
("bird", 0.05, "fn"): {
"missed_detections": 0,
"no_predictions": 0,
"misclassifications": 0,
"total": 0,
},
# dog
("dog", 0.05, "tp"): {"all": 2, "total": 2},
("dog", 0.05, "fp"): {
"hallucinations": 0,
"misclassifications": 3,
"total": 3,
},
("dog", 0.05, "tn"): {"all": 1, "total": 1},
("dog", 0.8, "fn"): {
"missed_detections": 1,
"no_predictions": 1,
"misclassifications": 1,
"total": 2,
},
# cat
("cat", 0.05, "tp"): {"all": 1, "total": 1},
("cat", 0.05, "fp"): {
"hallucinations": 0,
"misclassifications": 5,
"total": 5,
},
("cat", 0.05, "tn"): {"all": 0, "total": 0},
("cat", 0.8, "fn"): {
"missed_detections": 0,
"no_predictions": 0,
"misclassifications": 0,
"total": 0,
},
28 changes: 14 additions & 14 deletions api/tests/functional-tests/backend/metrics/test_detection.py
@@ -826,20 +826,20 @@ def test__compute_detailed_curves(db: Session):
# (class, 4)
("4", 0.05, "tp"): {"all": 2, "total": 2},
("4", 0.05, "fn"): {
"missed_detections": 0,
"no_predictions": 0,
"misclassifications": 0,
"total": 0,
},
# (class, 2)
("2", 0.05, "tp"): {"all": 1, "total": 1},
("2", 0.05, "fn"): {
"missed_detections": 0,
"no_predictions": 0,
"misclassifications": 1,
"total": 1,
},
("2", 0.75, "tp"): {"all": 0, "total": 0},
("2", 0.75, "fn"): {
"missed_detections": 2,
"no_predictions": 2,
"misclassifications": 0,
"total": 2,
},
@@ -855,14 +855,14 @@ def test__compute_detailed_curves(db: Session):
# (class, 1)
("1", 0.05, "tp"): {"all": 1, "total": 1},
("1", 0.8, "fn"): {
"missed_detections": 1,
"no_predictions": 1,
"misclassifications": 0,
"total": 1,
},
# (class, 0)
("0", 0.05, "tp"): {"all": 5, "total": 5},
("0", 0.95, "fn"): {
"missed_detections": 4,
"no_predictions": 4,
"misclassifications": 0,
"total": 4,
},
@@ -891,7 +891,7 @@ def test__compute_detailed_curves(db: Session):
# spot check number of examples
assert (
len(
output[1].value["0"][0.95]["fn"]["observations"]["missed_detections"][ # type: ignore - we know this element is a dict
output[1].value["0"][0.95]["fn"]["observations"]["no_predictions"][ # type: ignore - we know this element is a dict
"examples"
]
)
@@ -956,20 +956,20 @@ def test__compute_detailed_curves(db: Session):
# (class, 4)
("4", 0.05, "tp"): {"all": 0, "total": 0},
("4", 0.05, "fn"): {
"missed_detections": 2, # below IOU threshold of .9
"no_predictions": 2, # below IOU threshold of .9
"misclassifications": 0,
"total": 2,
},
# (class, 2)
("2", 0.05, "tp"): {"all": 1, "total": 1},
("2", 0.05, "fn"): {
"missed_detections": 1,
"no_predictions": 1,
"misclassifications": 0,
"total": 1,
},
("2", 0.75, "tp"): {"all": 0, "total": 0},
("2", 0.75, "fn"): {
"missed_detections": 2,
"no_predictions": 2,
"misclassifications": 0,
"total": 2,
},
@@ -985,14 +985,14 @@ def test__compute_detailed_curves(db: Session):
# (class, 1)
("1", 0.05, "tp"): {"all": 0, "total": 0},
("1", 0.8, "fn"): {
"missed_detections": 1,
"no_predictions": 1,
"misclassifications": 0,
"total": 1,
},
# (class, 0)
("0", 0.05, "tp"): {"all": 1, "total": 1},
("0", 0.95, "fn"): {
"missed_detections": 5,
"no_predictions": 5,
"misclassifications": 0,
"total": 5,
},
@@ -1021,7 +1021,7 @@ def test__compute_detailed_curves(db: Session):
# spot check number of examples
assert (
len(
second_output[1].value["0"][0.95]["fn"]["observations"]["missed_detections"][ # type: ignore - we know this element is a dict
second_output[1].value["0"][0.95]["fn"]["observations"]["no_predictions"][ # type: ignore - we know this element is a dict
"examples"
]
)
@@ -1069,7 +1069,7 @@ def test__compute_detailed_curves(db: Session):
# spot check number of examples
assert (
len(
second_output[1].value["0"][0.95]["fn"]["observations"]["missed_detections"][ # type: ignore - we know this element is a dict
second_output[1].value["0"][0.95]["fn"]["observations"]["no_predictions"][ # type: ignore - we know this element is a dict
"examples"
]
)
@@ -1117,7 +1117,7 @@ def test__compute_detailed_curves(db: Session):
# spot check number of examples
assert (
len(
second_output[1].value["0"][0.95]["fn"]["observations"]["missed_detections"][ # type: ignore - we know this element is a dict
second_output[1].value["0"][0.95]["fn"]["observations"]["no_predictions"][ # type: ignore - we know this element is a dict
"examples"
]
)
52 changes: 10 additions & 42 deletions api/valor_api/backend/metrics/classification.py
@@ -69,24 +69,7 @@ def _compute_curves(
]

# create sets of all datums for which there is a prediction / groundtruth
# used when separating hallucinations/misclassifications/missed_detections
gt_datum_ids = {
datum_id
for datum_id in db.scalars(
select(groundtruths.c.datum_id)
.select_from(groundtruths)
.join(models.Datum, models.Datum.id == groundtruths.c.datum_id)
.join(
models.Label,
and_(
models.Label.id == groundtruths.c.label_id,
models.Label.key.in_(label_keys),
),
)
.distinct()
).all()
}

# used when separating misclassifications/no_predictions
pd_datum_ids_to_high_score = {
datum_id: high_score
for datum_id, high_score in db.query(
@@ -188,10 +171,7 @@ def _compute_curves(
seen_datum_ids.add(pd_datum_id)
elif predicted_label == grouper_value and score >= threshold:
# if there was a groundtruth for a given datum, then it was a misclassification
if pd_datum_id in gt_datum_ids:
fp["misclassifications"].add(pd_datum_id)
else:
fp["hallucinations"].add(pd_datum_id)
fp["misclassifications"].add(pd_datum_id)
seen_datum_ids.add(pd_datum_id)
elif (
groundtruth_label == grouper_value
@@ -205,14 +185,14 @@
):
fn["misclassifications"].add(gt_datum_id)
else:
fn["missed_detections"].add(gt_datum_id)
fn["no_predictions"].add(gt_datum_id)
seen_datum_ids.add(gt_datum_id)

tn = set(unique_datums.keys()) - seen_datum_ids
tp_cnt, fp_cnt, fn_cnt, tn_cnt = (
len(tp),
len(fp["hallucinations"]) + len(fp["misclassifications"]),
len(fn["missed_detections"]) + len(fn["misclassifications"]),
len(fp["misclassifications"]),
len(fn["no_predictions"]) + len(fn["misclassifications"]),
len(tn),
)

@@ -301,16 +281,16 @@ def _compute_curves(
else fn["misclassifications"]
),
},
"missed_detections": {
"count": len(fn["missed_detections"]),
"no_predictions": {
"count": len(fn["no_predictions"]),
"examples": (
random.sample(
fn["missed_detections"],
fn["no_predictions"],
pr_curve_max_examples,
)
if len(fn["missed_detections"])
if len(fn["no_predictions"])
>= pr_curve_max_examples
else fn["missed_detections"]
else fn["no_predictions"]
),
},
},
@@ -330,18 +310,6 @@ def _compute_curves(
else fp["misclassifications"]
),
},
"hallucinations": {
"count": len(fp["hallucinations"]),
"examples": (
random.sample(
fp["hallucinations"],
pr_curve_max_examples,
)
if len(fp["hallucinations"])
>= pr_curve_max_examples
else fp["hallucinations"]
),
},
},
},
}
16 changes: 8 additions & 8 deletions api/valor_api/backend/metrics/detection.py
@@ -364,7 +364,7 @@ def _compute_detailed_curves(

# transform sorted_ranked_pairs into two sets (groundtruths and predictions)
# we'll use these dictionaries to look up the IOU overlap between specific groundtruths and predictions
# to separate misclassifications from hallucinations/missed_detections
# to separate misclassifications from hallucinations/no_predictions
pd_datums = defaultdict(lambda: defaultdict(list))
gt_datums = defaultdict(lambda: defaultdict(list))

@@ -466,7 +466,7 @@ def _compute_detailed_curves(
(dataset_name, datum_uid, gt_geojson)
)
else:
fn["missed_detections"].append(
fn["no_predictions"].append(
(dataset_name, datum_uid, gt_geojson)
)

@@ -521,7 +521,7 @@ def _compute_detailed_curves(
tp_cnt, fp_cnt, fn_cnt = (
len(tp),
len(fp["hallucinations"]) + len(fp["misclassifications"]),
len(fn["missed_detections"]) + len(fn["misclassifications"]),
len(fn["no_predictions"]) + len(fn["misclassifications"]),
)
precision = (
tp_cnt / (tp_cnt + fp_cnt) if (tp_cnt + fp_cnt) > 0 else -1
@@ -575,16 +575,16 @@ def _compute_detailed_curves(
else fn["misclassifications"]
),
},
"missed_detections": {
"count": len(fn["missed_detections"]),
"no_predictions": {
"count": len(fn["no_predictions"]),
"examples": (
random.sample(
fn["missed_detections"],
fn["no_predictions"],
pr_curve_max_examples,
)
if len(fn["missed_detections"])
if len(fn["no_predictions"])
>= pr_curve_max_examples
else fn["missed_detections"]
else fn["no_predictions"]
),
},
},
19 changes: 15 additions & 4 deletions docs/metrics.md
@@ -218,12 +218,23 @@ The `PrecisionRecallCurve` values differ from the precision-recall curves used t
Valor also includes a more detailed version of `PrecisionRecallCurve` which can be useful for debugging your model's false positives and false negatives. When calculating `DetailedPrecisionRecallCurve`, Valor classifies false positives as either `hallucinations` or `misclassifications` and false negatives as either `no_predictions` or `misclassifications` using the following logic:

#### Classification Tasks
- A **false positive** is a `misclassification` if there is a qualified prediction (with `score >= score_threshold`) with the same `Label.key` as the groundtruth on the datum, but the `Label.value` is incorrect. For example: if there's a photo with one groundtruth label on it (e.g., `Label(key='animal', value='dog')`), and we predicted another label value (e.g., `Label(key='animal', value='cat')`) on that datum, we'd say it's a `misclassification` since the key was correct but the value was not. Any false positives that do not meet this criteria are considered to be `hallucinations`.
- Similarly, a **false negative** is a `misclassification` if there is a prediction with the same `Label.key` as the groundtruth on the datum, but the `Label.value` is incorrect. Any false negatives that do not meet this criteria are considered to be `missed_detections`.
- A **false positive** occurs when there is a qualified prediction (with `score >= score_threshold`) with the same `Label.key` as the groundtruth on the datum, but the `Label.value` is incorrect.
- **Example**: if there's a photo with one groundtruth label on it (e.g., `Label(key='animal', value='dog')`), and we predicted another label value (e.g., `Label(key='animal', value='cat')`) on that datum, we'd say it's a `misclassification` since the key was correct but the value was not.
- Similarly, a **false negative** occurs when the groundtruth `Label.value` on a datum does not receive a qualified prediction (with `score >= score_threshold`).
- Stratifications of False Negatives:
- `misclassification`: Occurs when a different label value passes the score threshold.
- `no_predictions`: Occurs when no label passes the score threshold.
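
To make the classification logic above concrete, here is a minimal, self-contained sketch of the stratification for a single label value at a single score threshold. The function name and data shapes are illustrative assumptions, not part of the Valor API.

```python
# Illustrative sketch only: names and data shapes are hypothetical, not the Valor API.
def stratify_errors(groundtruths, predictions, label_value, score_threshold):
    """Bucket datums into tp / fp / fn for one label value at one score threshold.

    groundtruths: dict mapping datum_id -> groundtruth label value
    predictions:  dict mapping datum_id -> {label value: score}
    """
    tp = set()
    fp = {"misclassifications": set()}
    fn = {"misclassifications": set(), "no_predictions": set()}

    for datum_id, gt_value in groundtruths.items():
        scores = predictions.get(datum_id, {})
        if gt_value == label_value:
            if scores.get(label_value, 0.0) >= score_threshold:
                tp.add(datum_id)  # correct value passed the threshold
            elif any(
                value != label_value and score >= score_threshold
                for value, score in scores.items()
            ):
                # a different label value qualified -> misclassification
                fn["misclassifications"].add(datum_id)
            else:
                # no label value qualified for this datum -> no_predictions
                fn["no_predictions"].add(datum_id)
        elif scores.get(label_value, 0.0) >= score_threshold:
            # we predicted this value on a datum whose groundtruth value differs
            fp["misclassifications"].add(datum_id)

    return tp, fp, fn


# Example: at a 0.5 threshold, "img1" is a true positive for "dog", "img2" is a
# false positive (misclassification), and "img3" is a false negative (no_predictions).
tp, fp, fn = stratify_errors(
    groundtruths={"img1": "dog", "img2": "cat", "img3": "dog"},
    predictions={
        "img1": {"dog": 0.9, "cat": 0.1},
        "img2": {"dog": 0.8, "cat": 0.2},
        "img3": {"dog": 0.2, "cat": 0.1},
    },
    label_value="dog",
    score_threshold=0.5,
)
```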

#### Object Detection Tasks
- A **false positive** is a `misclassification` if a) there is a qualified prediction with the same `Label.key` as the groundtruth on the datum, but the `Label.value` is incorrect, and b) the qualified prediction and groundtruth have an IOU >= `pr_curve_iou_threshold`. For example: if there's a photo with one groundtruth label on it (e.g., `Label(key='animal', value='dog')`), and we predicted another bounding box directly over that same object (e.g., `Label(key='animal', value='cat')`), we'd say it's a `misclassification`. Any false positives that do not meet this criteria are considered to be `hallucinations`.
- A **false negative** is determined to be a `misclassification` if the two criteria above are met: a) there is a qualified prediction with the same `Label.key` as the groundtruth on the datum, but the `Label.value` is incorrect, and b) the qualified prediction and groundtruth have an IOU >= `pr_curve_iou_threshold`. Any false negatives that do not meet this criteria are considered to be `missed_detections`.
- A **false positive** is a `misclassification` if the following conditions are met:
1. There is a qualified prediction with the same `Label.key` as the groundtruth on the datum, but the `Label.value` is incorrect.
2. The qualified prediction and groundtruth have an IOU >= `pr_curve_iou_threshold`.
- A **false positive** that does not meet the `misclassification` criteria is considered to be part of the `hallucinations` set.
- A **false negative** is determined to be a `misclassification` if the following criteria are met:
1. There is a qualified prediction with the same `Label.key` as the groundtruth on the datum, but the `Label.value` is incorrect.
2. The qualified prediction and groundtruth have an IOU >= `pr_curve_iou_threshold`.
- A **false negative** that does not meet these criteria is considered to be part of the `no_predictions` set.
- **Example**: if there's a photo with one groundtruth label on it (e.g., `Label(key='animal', value='dog')`), and we predicted another bounding box directly over that same object (e.g., `Label(key='animal', value='cat')`), we'd say it's a `misclassification`.
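
A comparable sketch of the object detection stratification, assuming a simple axis-aligned box format and a hypothetical IOU helper; this is an illustration of the rules above, not Valor's implementation.

```python
# Illustrative sketch only; the box format, helper names, and data shapes are hypothetical.
def iou(box_a, box_b):
    """Intersection-over-union of two (xmin, ymin, xmax, ymax) boxes."""
    ix1, iy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    ix2, iy2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def classify_false_positive(pred_box, pred_value, groundtruths, pr_curve_iou_threshold):
    """Label an unmatched, qualified prediction as a misclassification or a hallucination.

    groundtruths: list of (box, label_value) pairs sharing the prediction's Label.key.
    """
    for gt_box, gt_value in groundtruths:
        if gt_value != pred_value and iou(pred_box, gt_box) >= pr_curve_iou_threshold:
            return "misclassifications"  # overlapping groundtruth, wrong label value
    return "hallucinations"  # no sufficiently overlapping groundtruth


def classify_false_negative(gt_box, gt_value, predictions, pr_curve_iou_threshold):
    """Label an unmatched groundtruth as a misclassification or a missing prediction.

    predictions: list of (box, label_value) pairs that passed the score threshold.
    """
    for pred_box, pred_value in predictions:
        if pred_value != gt_value and iou(gt_box, pred_box) >= pr_curve_iou_threshold:
            return "misclassifications"  # an overlapping prediction used the wrong value
    return "no_predictions"  # no qualified prediction overlapped this groundtruth
```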

The `DetailedPrecisionRecallOutput` also includes up to `n` examples of each type of error, where `n` is set using `pr_curve_max_examples`. An example output is as follows:

8 changes: 3 additions & 5 deletions integration_tests/client/metrics/test_classification.py
@@ -1111,31 +1111,29 @@ def test_evaluate_classification_with_label_maps(
# k3
(0, "v1", "0.1", "tp"): {"all": 0, "total": 0},
(0, "v1", "0.1", "fp"): {
"hallucinations": 0,
"misclassifications": 1,
"total": 1,
},
(0, "v1", "0.1", "tn"): {"all": 2, "total": 2},
(0, "v1", "0.1", "fn"): {
"missed_detections": 0,
"no_predictions": 0,
"misclassifications": 0,
"total": 0,
},
# k4
(1, "v1", "0.1", "tp"): {"all": 0, "total": 0},
(1, "v1", "0.1", "fp"): {
"hallucinations": 0,
"misclassifications": 1,
"total": 1,
},
(1, "v1", "0.1", "tn"): {"all": 2, "total": 2},
(1, "v1", "0.1", "fn"): {
"missed_detections": 0,
"no_predictions": 0,
"misclassifications": 0,
"total": 0,
},
(1, "v4", "0.1", "fn"): {
"missed_detections": 0,
"no_predictions": 0,
"misclassifications": 1,
"total": 1,
},