Fixing edge cases in Query (#417)
Co-authored-by: Nick L <[email protected]>
czaloom and ntlind authored Feb 9, 2024
1 parent 70401e1 commit 5ae1cd7
Showing 25 changed files with 1,907 additions and 925 deletions.
94 changes: 94 additions & 0 deletions api/tests/functional-tests/backend/core/test_geometry.py
@@ -0,0 +1,94 @@
import pytest
from sqlalchemy.orm import Session

from velour_api import enums, schemas
from velour_api.backend.core import fetch_dataset
from velour_api.backend.core.geometry import (
    convert_geometry,
    get_annotation_type,
)
from velour_api.crud import create_dataset, create_groundtruth


@pytest.fixture
def create_clf_dataset(db: Session, dataset_name: str):
    create_dataset(db=db, dataset=schemas.Dataset(name=dataset_name))
    create_groundtruth(
        db=db,
        groundtruth=schemas.GroundTruth(
            datum=schemas.Datum(uid="uid1", dataset_name=dataset_name),
            annotations=[
                schemas.Annotation(
                    task_type=enums.TaskType.CLASSIFICATION,
                    labels=[schemas.Label(key="k1", value="v1")],
                )
            ],
        ),
    )


def test_get_annotation_type(
    db: Session, dataset_name: str, create_clf_dataset
):
    # tests uncovered case where `AnnotationType.NONE` is returned.
    dataset = fetch_dataset(db, dataset_name)
    assert (
        get_annotation_type(db, enums.TaskType.CLASSIFICATION, dataset)
        == enums.AnnotationType.NONE
    )


def test_convert_geometry(
    db: Session, dataset_name: str, dataset_model_create
):
    dataset = fetch_dataset(db, dataset_name)

    with pytest.raises(ValueError) as e:
        convert_geometry(
            db=db,
            source_type=enums.AnnotationType.NONE,
            target_type=enums.AnnotationType.BOX,
            dataset=None,
            model=None,
        )
    assert "Source type" in str(e)

    with pytest.raises(ValueError) as e:
        convert_geometry(
            db=db,
            source_type=enums.AnnotationType.BOX,
            target_type=enums.AnnotationType.NONE,
            dataset=None,
            model=None,
        )
    assert "Target type" in str(e)

    with pytest.raises(ValueError) as e:
        convert_geometry(
            db=db,
            source_type=enums.AnnotationType.BOX,
            target_type=enums.AnnotationType.RASTER,
            dataset=None,
            model=None,
        )
    assert "not capable of being converted" in str(e)

    with pytest.raises(NotImplementedError) as e:
        convert_geometry(
            db=db,
            source_type=enums.AnnotationType.MULTIPOLYGON,
            target_type=enums.AnnotationType.BOX,
            dataset=dataset,
            model=None,
        )
    assert "currently unsupported" in str(e)

    with pytest.raises(NotImplementedError) as e:
        convert_geometry(
            db=db,
            source_type=enums.AnnotationType.MULTIPOLYGON,
            target_type=enums.AnnotationType.POLYGON,
            dataset=dataset,
            model=None,
        )
    assert "currently unsupported" in str(e)
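The error paths pinned down above suggest how calling code might defend against unsupported conversions. Below is a minimal, illustrative sketch, not part of this commit: the helper name try_convert_geometry is hypothetical, and it assumes only the convert_geometry signature and error messages exercised in the tests above.

# Hypothetical helper -- illustration only, not part of this commit.
from velour_api import enums
from velour_api.backend.core.geometry import convert_geometry


def try_convert_geometry(db, source_type, target_type, dataset=None, model=None):
    """Attempt a geometry conversion, returning False for cases the backend rejects."""
    if enums.AnnotationType.NONE in (source_type, target_type):
        return False  # a NONE source or target type raises ValueError upstream
    try:
        convert_geometry(
            db=db,
            source_type=source_type,
            target_type=target_type,
            dataset=dataset,
            model=model,
        )
        return True
    except (ValueError, NotImplementedError):
        # e.g. BOX -> RASTER is "not capable of being converted";
        # MULTIPOLYGON conversions are "currently unsupported".
        return False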
@@ -1,5 +1,6 @@
 import numpy
 import pytest
+from sqlalchemy import distinct, func
 from sqlalchemy.orm import Session

 from velour_api import crud, enums, schemas
@@ -615,23 +616,23 @@ def test_query_datasets(
     model_sim,
 ):
     # Check that passing a non-InstrumentedAttribute returns None
-    q = Query("not_a_valid_attribute")
-    assert len(q._selected) == 0
+    with pytest.raises(NotImplementedError):
+        Query("not_a_valid_attribute")

     # Q: Get names for datasets where label class=cat exists in groundtruths.
     f = schemas.Filter(labels=[{"class": "cat"}])
-    query_obj = Query(models.Dataset.name)
+    query_obj = Query(distinct(models.Dataset.name))
     assert len(query_obj._selected) == 1

     q = query_obj.filter(f).groundtruths()
-    dataset_names = db.query(q).distinct().all()

+    dataset_names = db.query(q).all()
     assert len(dataset_names) == 1
     assert (dset_name,) in dataset_names

     # Q: Get names for datasets where label=tree exists in groundtruths
     f = schemas.Filter(labels=[{"class": "tree"}])
     q = Query(models.Dataset.name).filter(f).groundtruths()
-    dataset_names = db.query(q).distinct().all()
+    dataset_names = db.query(q).all()
     assert len(dataset_names) == 0

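Taken together, the hunk above changes two behaviors: a selector that is not a SQLAlchemy column or expression now raises NotImplementedError instead of yielding a Query with nothing selected, and de-duplication moves from an outer .distinct() on the session query into a distinct() expression passed to Query itself. A short, hedged recap of the new calling convention (illustration only, using the db session, models, schemas, and Query objects from these tests):

from sqlalchemy import distinct

# Invalid selectors now fail fast.
try:
    Query("not_a_valid_attribute")
except NotImplementedError:
    pass

# De-duplication is expressed inside Query rather than on the outer query.
f = schemas.Filter(labels=[{"class": "cat"}])
q = Query(distinct(models.Dataset.name)).filter(f).groundtruths()
dataset_names = db.query(q).all()  # no trailing .distinct() needed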

@@ -2471,3 +2472,56 @@ def test_annotation_datetime_queries(
     _test_annotation_datetime_query(db, date_key, date_metadata)
     _test_annotation_datetime_query(db, time_key, time_metadata)
     _test_annotation_datetime_query(db, duration_key, duration_metadata)
+
+
+def test_query_expression_types(
+    db: Session,
+    model_sim,
+):
+    # Test `distinct`
+    f = schemas.Filter(labels=[{"class": "cat"}])
+    q = Query(distinct(models.Dataset.name)).filter(f).groundtruths()
+    dataset_names = db.query(q).all()
+    assert len(dataset_names) == 1
+    assert (dset_name,) in dataset_names
+
+    # Test `func.count`, note this returns 10 b/c of joins.
+    f = schemas.Filter(labels=[{"class": "cat"}])
+    q = (
+        Query(func.count(models.Dataset.name))
+        .filter(f)
+        .groundtruths(as_subquery=False)
+    )
+    assert db.scalar(q) == 10
+
+    # Test `func.count` with nested distinct.
+    f = schemas.Filter(labels=[{"class": "cat"}])
+    q = (
+        Query(func.count(distinct(models.Dataset.name)))
+        .filter(f)
+        .groundtruths(as_subquery=False)
+    )
+    assert db.scalar(q) == 1
+
+    # Test distinct with nested `func.count`
+    # This is to test the recursive table search
+    # querying with this order-of-ops will fail.
+    f = schemas.Filter(labels=[{"class": "cat"}])
+    q = Query(distinct(func.count(models.Dataset.name))).filter(f)
+    assert q._selected == {models.Dataset}
+
+    # Test `func.count` without args, note this returns 10 b/c of joins.
+    f = schemas.Filter(labels=[{"class": "cat"}])
+    q = (
+        Query(func.count())
+        .select_from(models.Dataset)
+        .filter(f)
+        .groundtruths(as_subquery=False)
+    )
+    assert db.scalar(q) == 10
+
+    # Test nested functions
+    q = Query(func.max(func.ST_Area(models.Annotation.box))).groundtruths(
+        as_subquery=False
+    )
+    assert db.scalar(q) == 100.0
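The new test above exercises Query with SQLAlchemy expression types (distinct, func.count, and nested functions such as func.max(func.ST_Area(...))) rather than plain columns, and shows that groundtruths(as_subquery=False) returns a statement usable with db.scalar. A hedged usage sketch assembled only from the calls exercised above (illustration only, not part of the commit):

from sqlalchemy import distinct, func

f = schemas.Filter(labels=[{"class": "cat"}])

# Count distinct dataset names matching the filter; scalar result, not a subquery.
stmt = (
    Query(func.count(distinct(models.Dataset.name)))
    .filter(f)
    .groundtruths(as_subquery=False)
)
n_datasets = db.scalar(stmt)

# Largest bounding-box area among matching annotations, via nested functions.
stmt = Query(func.max(func.ST_Area(models.Annotation.box))).groundtruths(
    as_subquery=False
)
max_box_area = db.scalar(stmt)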
8 changes: 0 additions & 8 deletions api/tests/functional-tests/conftest.py
@@ -127,14 +127,6 @@ def img2_gt_mask_bytes1():
     return random_mask_bytes(size=img2_size)


-@pytest.fixture
-def dset(db: Session) -> models.Dataset:
-    dset = models.Dataset(name="dset")
-    db.add(dset)
-    db.commit()
-    return dset
-
-
 @pytest.fixture
 def images() -> list[schemas.Datum]:
     return [
