Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Ingestion Performance #662

Merged
merged 4 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_create_annotation_already_exists_error(
core.create_predictions(db, empty_predictions)
with pytest.raises(exceptions.DatumsAlreadyExistError):
core.create_groundtruths(db, empty_groundtruths[0:1])
with pytest.raises(exceptions.AnnotationAlreadyExistsError):
with pytest.raises(exceptions.PredictionAlreadyExistsError):
core.create_predictions(db, empty_predictions[0:1])


Expand Down
66 changes: 66 additions & 0 deletions api/tests/unit-tests/backend/core/test_annotation_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import pytest

from valor_api import enums, schemas
from valor_api.backend import models
from valor_api.backend.core.annotation import (
create_annotations,
delete_dataset_annotations,
delete_model_annotations,
)


def test_malformed_input_create_annotations():

with pytest.raises(ValueError):
create_annotations(
db=None, # type: ignore - testing
annotations=[[schemas.Annotation()], [schemas.Annotation()]],
datum_ids=[1, 2],
models_=[None],
)

with pytest.raises(ValueError):
create_annotations(
db=None, # type: ignore - testing
annotations=[[schemas.Annotation()]],
datum_ids=[1, 2],
models_=[None],
)

with pytest.raises(ValueError):
create_annotations(
db=None, # type: ignore - testing
annotations=[[schemas.Annotation()]],
datum_ids=[1],
models_=[None, None],
)


def test_malformed_input_delete_dataset_annotations():

for status in enums.TableStatus:
if status == enums.TableStatus.DELETING:
continue

dataset = models.Dataset(
name="dataset",
status=status,
)

with pytest.raises(RuntimeError):
delete_dataset_annotations(db=None, dataset=dataset) # type: ignore - testing


def test_malformed_input_delete_model_annotations():

for status in enums.ModelStatus:
if status == enums.ModelStatus.DELETING:
continue

model = models.Model(
name="model",
status=status,
)

with pytest.raises(RuntimeError):
delete_model_annotations(db=None, model=model) # type: ignore - testing
158 changes: 71 additions & 87 deletions api/valor_api/backend/core/annotation.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,33 @@
from typing import Any

from geoalchemy2.functions import ST_AsGeoJSON
from sqlalchemy import and_, delete, insert, select
from sqlalchemy import ScalarSelect, and_, delete, insert, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from valor_api import schemas
from valor_api.backend import models
from valor_api.backend.core.geometry import _raster_to_png_b64
from valor_api.backend.query import generate_query
from valor_api.enums import ModelStatus, TableStatus
from valor_api.enums import ModelStatus, TableStatus, TaskType


def _format_box(box: schemas.Box | None) -> str | None:
return box.to_wkt() if box else None


def _format_polygon(polygon: schemas.Polygon | None) -> str | None:
return polygon.to_wkt() if polygon else None


def _format_raster(
raster: schemas.Raster | None,
) -> ScalarSelect | bytes | None:
return raster.to_psql() if raster else None


def _create_embedding(
db: Session,
value: list[float],
) -> int:
value: list[float] | None,
) -> int | None:
"""
Creates a row in the embedding table.

Expand All @@ -31,6 +43,8 @@ def _create_embedding(
int
The row id of the embedding.
"""
if not value:
return None
try:
row = models.Embedding(value=value)
db.add(row)
Expand All @@ -41,76 +55,25 @@ def _create_embedding(
return row.id


def _create_annotation(
db: Session,
annotation: schemas.Annotation,
datum: models.Datum,
model: models.Model | None = None,
) -> dict[str, Any]:
"""
Convert an individual annotation's attributes into a dictionary for upload to psql.

Parameters
----------
annotation : schemas.Annotation
The annotation tom ap.
datum : models.Datum
The datum associated with the annotation.
model : models.Model, optional
The model associated with the annotation

Returns
----------
dict[str, Any]
A populated models.Annotation object.
"""
box = None
polygon = None
raster = None
embedding_id = None

if annotation.bounding_box:
box = annotation.bounding_box.to_wkt()
if annotation.polygon:
polygon = annotation.polygon.to_wkt()
if annotation.raster:
raster = annotation.raster.to_psql()
if annotation.embedding:
embedding_id = _create_embedding(db=db, value=annotation.embedding)

mapping = {
"datum_id": datum.id,
"model_id": model.id if model else None,
"meta": annotation.metadata,
"box": box,
"polygon": polygon,
"raster": raster,
"embedding_id": embedding_id,
"is_instance": annotation.is_instance,
"implied_task_types": annotation.implied_task_types,
}
return mapping


def create_annotations(
db: Session,
annotations: list[list[schemas.Annotation]],
datums: list[models.Datum],
models_: list[models.Model] | None | list[None] = None,
datum_ids: list[int],
models_: list[models.Model] | list[None] | None = None,
) -> list[list[models.Annotation]]:
"""
Create a list of annotations and associated labels in psql.

Parameters
----------
db
db : Session
The database Session you want to query against.
annotations
annotations : list[list[schemas.Annotation]]
The list of annotations to create.
datum
The datum associated with the annotation.
model
The model associated with the annotation.
datums : dict[tuple[int, str], int]
A mapping of (dataset_id, datum_uid) to a datum's row id.
models_: list[models.Model], optional
The model(s) associated with the annotations.

Returns
----------
Expand All @@ -122,41 +85,58 @@ def create_annotations(
exceptions.AnnotationAlreadyExistsError
If the provided datum already has existing annotations for that dataset or model.
"""
models_ = models_ or [None] * len(datums)

assert len(models_) == len(datums) == len(annotations)
# cache model ids
models_ = models_ or [None] * len(datum_ids)
model_ids = [
model.id if isinstance(model, models.Model) else model
for model in models_
]

# create annotations
annotation_mappings = [
_create_annotation(
db=db, annotation=annotation, datum=datum, model=model
)
for annotations_per_datum, datum, model in zip(
annotations, datums, models_
if not (len(model_ids) == len(datum_ids) == len(annotations)):
raise ValueError("Length mismatch between annotation elements.")

values = [
{
"datum_id": datum_id,
"model_id": model_id,
"meta": annotation.metadata,
"box": _format_box(annotation.bounding_box),
"polygon": _format_polygon(annotation.polygon),
"raster": _format_raster(annotation.raster),
"embedding_id": _create_embedding(
db=db, value=annotation.embedding
),
"is_instance": annotation.is_instance,
"implied_task_types": annotation.implied_task_types,
}
for annotations_per_datum, datum_id, model_id in zip(
annotations, datum_ids, model_ids
)
for annotation in annotations_per_datum
]

try:
insert_stmt = (
insert(models.Annotation)
.values(annotation_mappings)
.values(values)
.returning(models.Annotation.id)
)
annotation_id_list = db.execute(insert_stmt).scalars().all()
annotation_ids = list(db.execute(insert_stmt).scalars().all())
db.commit()
except IntegrityError as e:
db.rollback()
raise e

annotation_ids = []
grouped_annotation_row_ids = []
idx = 0
for annotations_per_datum in annotations:
annotation_ids.append(
annotation_id_list[idx : idx + len(annotations_per_datum)]
grouped_annotation_row_ids.append(
annotation_ids[idx : idx + len(annotations_per_datum)]
)
idx += len(annotations_per_datum)

return annotation_ids
return grouped_annotation_row_ids


def create_skipped_annotations(
Expand All @@ -177,16 +157,21 @@ def create_skipped_annotations(
The model associated with the annotation.
"""
annotation_list = [
_create_annotation(
db=db,
annotation=schemas.Annotation(),
datum=datum,
model=model,
models.Annotation(
datum_id=datum.id,
model_id=model.id if model else None,
meta=dict(),
box=None,
polygon=None,
raster=None,
embedding_id=None,
is_instance=False,
implied_task_types=[TaskType.EMPTY],
)
for datum in datums
]
try:
db.bulk_insert_mappings(models.Annotation, annotation_list)
db.add_all(annotation_list)
db.commit()
except IntegrityError as e:
db.rollback()
Expand Down Expand Up @@ -262,7 +247,6 @@ def get_annotation(
datum = db.scalar(
select(models.Datum).where(models.Datum.id == annotation.datum_id)
)

if datum is None:
raise RuntimeError(
"psql unexpectedly returned None instead of a Datum."
Expand Down
Loading
Loading