Skip to content

Commit

Permalink
Object Detection Benchmarking (#704)
Browse files Browse the repository at this point in the history
  • Loading branch information
czaloom authored Aug 16, 2024
1 parent 60622a3 commit 40f7b3b
Show file tree
Hide file tree
Showing 11 changed files with 6,300 additions and 1,702 deletions.
8 changes: 2 additions & 6 deletions api/tests/functional-tests/crud/test_create_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,15 +1064,11 @@ def test_gt_seg_as_mask_or_polys(
assert len(segs.annotations) == 2

assert segs.annotations[0].raster and segs.annotations[1].raster
decoded_mask0 = np.array(
_bytes_to_pil(b64decode(segs.annotations[0].raster.mask))
)
decoded_mask0 = segs.annotations[0].raster.array
assert decoded_mask0.shape == mask.shape
np.testing.assert_equal(decoded_mask0, mask)

decoded_mask1 = np.array(
_bytes_to_pil(b64decode(segs.annotations[1].raster.mask))
)
decoded_mask1 = segs.annotations[1].raster.array
assert decoded_mask1.shape == mask.shape
np.testing.assert_equal(decoded_mask1, mask)

Expand Down
62 changes: 43 additions & 19 deletions api/valor_api/backend/metrics/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Sequence, Tuple

from geoalchemy2 import functions as gfunc
from sqlalchemy import CTE, and_, func, or_, select
from sqlalchemy import CTE, and_, case, func, or_, select
from sqlalchemy.orm import Session, aliased

from valor_api import enums, schemas
Expand Down Expand Up @@ -458,10 +458,14 @@ def _compute_detailed_curves(
if misclassification_detected:
fn["misclassifications"].append(
(dataset_name, datum_uid, gt_geojson)
if gt_geojson is not None
else (dataset_name, datum_uid)
)
else:
fn["no_predictions"].append(
(dataset_name, datum_uid, gt_geojson)
if gt_geojson is not None
else (dataset_name, datum_uid)
)

if label_id in predictions_per_label:
Expand Down Expand Up @@ -501,10 +505,14 @@ def _compute_detailed_curves(
if misclassification_detected:
fp["misclassifications"].append(
(dataset_name, datum_uid, pd_geojson)
if pd_geojson is not None
else (dataset_name, datum_uid)
)
elif hallucination_detected:
fp["hallucinations"].append(
(dataset_name, datum_uid, pd_geojson)
if pd_geojson is not None
else (dataset_name, datum_uid)
)

# calculate metrics
Expand Down Expand Up @@ -1158,14 +1166,22 @@ def _annotation_type_to_column(
select(
gt_pd_counts.c.gt_annotation_id,
gt_pd_counts.c.pd_annotation_id,
func.coalesce(
gt_pd_counts.c.intersection
/ (
case(
(
gt_pd_counts.c.gt_count
+ gt_pd_counts.c.pd_count
- gt_pd_counts.c.intersection
== 0,
0,
),
else_=(
gt_pd_counts.c.intersection
/ (
gt_pd_counts.c.gt_count
+ gt_pd_counts.c.pd_count
- gt_pd_counts.c.intersection
)
),
0,
).label("iou"),
)
.select_from(gt_pd_counts)
Expand All @@ -1183,7 +1199,10 @@ def _annotation_type_to_column(
select(
gt_pd_pairs.c.gt_annotation_id,
gt_pd_pairs.c.pd_annotation_id,
iou_computation.label("iou"),
case(
(gfunc.ST_Area(gunion) == 0, 0),
else_=iou_computation,
).label("iou"),
)
.select_from(gt_pd_pairs)
.join(
Expand All @@ -1209,10 +1228,7 @@ def _annotation_type_to_column(
gt.c.label_id.label("gt_label_id"),
pd.c.label_id.label("pd_label_id"),
pd.c.score.label("score"),
func.coalesce(
gt_pd_ious.c.iou,
0,
).label("iou"),
gt_pd_ious.c.iou,
gt.c.geojson.label("gt_geojson"),
)
.select_from(pd)
Expand Down Expand Up @@ -1576,14 +1592,22 @@ def _annotation_type_to_column(
select(
gt_pd_counts.c.gt_annotation_id,
gt_pd_counts.c.pd_annotation_id,
func.coalesce(
gt_pd_counts.c.intersection
/ (
case(
(
gt_pd_counts.c.gt_count
+ gt_pd_counts.c.pd_count
- gt_pd_counts.c.intersection
== 0,
0,
),
else_=(
gt_pd_counts.c.intersection
/ (
gt_pd_counts.c.gt_count
+ gt_pd_counts.c.pd_count
- gt_pd_counts.c.intersection
)
),
0,
).label("iou"),
)
.select_from(gt_pd_counts)
Expand All @@ -1601,7 +1625,10 @@ def _annotation_type_to_column(
select(
gt_pd_pairs.c.gt_annotation_id,
gt_pd_pairs.c.pd_annotation_id,
iou_computation.label("iou"),
case(
(gfunc.ST_Area(gunion) == 0, 0),
else_=iou_computation,
).label("iou"),
)
.select_from(gt_pd_pairs)
.join(
Expand All @@ -1627,10 +1654,7 @@ def _annotation_type_to_column(
gt.c.label_id.label("gt_label_id"),
pd.c.label_id.label("pd_label_id"),
pd.c.score.label("score"),
func.coalesce(
gt_pd_ious.c.iou,
0,
).label("iou"),
gt_pd_ious.c.iou,
gt.c.geojson.label("gt_geojson"),
(gt.c.label_id == pd.c.label_id).label("is_match"),
)
Expand Down
24 changes: 9 additions & 15 deletions api/valor_api/schemas/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
ST_GeomFromText,
ST_MakeEmptyRaster,
ST_MapAlgebra,
ST_SnapToGrid,
)
from pydantic import (
BaseModel,
Expand Down Expand Up @@ -1028,8 +1027,8 @@ def to_psql(self) -> ScalarSelect | bytes:
if self.geometry:
empty_raster = ST_AddBand(
ST_MakeEmptyRaster(
self.width,
self.height,
self.width, # width
self.height, # height
0, # upperleftx
0, # upperlefty
1, # scalex
Expand All @@ -1038,23 +1037,18 @@ def to_psql(self) -> ScalarSelect | bytes:
0, # skewy
0, # srid
),
"1BB",
)
geom_raster = ST_AsRaster(
ST_SnapToGrid(
ST_GeomFromText(self.geometry.to_wkt()),
1.0,
),
1.0, # scalex
1.0, # scaley
"1BB", # pixeltype
1, # value
0, # nodataval
)
return select(
ST_MapAlgebra(
empty_raster,
geom_raster,
ST_AsRaster(
ST_GeomFromText(self.geometry.to_wkt()),
empty_raster,
"1BB",
1,
0,
),
"[rast2]",
"1BB",
"UNION",
Expand Down
2 changes: 1 addition & 1 deletion client/unit-tests/schemas/test_geojson.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_multipolygon():
MultiPolygon([coords]) # type: ignore - testing
with pytest.raises(TypeError):
MultiPolygon([[coords], 123]) # type: ignore - testing
with pytest.raises(TypeError):
with pytest.raises(ValueError):
MultiPolygon([[[coords]]]) # type: ignore - testing


Expand Down
6 changes: 4 additions & 2 deletions client/valor/schemas/symbolic/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,12 +769,14 @@ def __init__(

@classmethod
def __validate__(cls, value: typing.Any):
if not isinstance(value, tuple):
if not isinstance(value, (tuple, list)):
raise TypeError(
f"Expected type 'typing.Tuple[float, float]' received type '{type(value).__name__}'"
)
elif len(value) != 2:
raise ValueError("")
raise ValueError(
"A point should contain only two x-y coordinates."
)
for item in value:
if not isinstance(item, (int, float, np.floating)):
raise TypeError(
Expand Down
Loading

0 comments on commit 40f7b3b

Please sign in to comment.