Skip to content

Commit

Permalink
Improve Ingestion Performance (#662)
Browse files Browse the repository at this point in the history
  • Loading branch information
czaloom authored Jul 9, 2024
1 parent 5605b48 commit 0788666
Show file tree
Hide file tree
Showing 9 changed files with 290 additions and 195 deletions.
2 changes: 1 addition & 1 deletion api/tests/functional-tests/backend/core/test_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_create_annotation_already_exists_error(
core.create_predictions(db, empty_predictions)
with pytest.raises(exceptions.DatumsAlreadyExistError):
core.create_groundtruths(db, empty_groundtruths[0:1])
with pytest.raises(exceptions.AnnotationAlreadyExistsError):
with pytest.raises(exceptions.PredictionAlreadyExistsError):
core.create_predictions(db, empty_predictions[0:1])


Expand Down
66 changes: 66 additions & 0 deletions api/tests/unit-tests/backend/core/test_annotation_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import pytest

from valor_api import enums, schemas
from valor_api.backend import models
from valor_api.backend.core.annotation import (
create_annotations,
delete_dataset_annotations,
delete_model_annotations,
)


def test_malformed_input_create_annotations():
    """Check that `create_annotations` raises ValueError on mismatched input lengths.

    Each case makes exactly one of `annotations`, `datum_ids` or `models_`
    disagree in length with the other two. No database is needed, so `db`
    is passed as None.
    """
    mismatched_inputs = [
        # two annotation lists and two datum ids, but only one model entry
        dict(
            annotations=[[schemas.Annotation()], [schemas.Annotation()]],
            datum_ids=[1, 2],
            models_=[None],
        ),
        # one annotation list and one model entry, but two datum ids
        dict(
            annotations=[[schemas.Annotation()]],
            datum_ids=[1, 2],
            models_=[None],
        ),
        # one annotation list and one datum id, but two model entries
        dict(
            annotations=[[schemas.Annotation()]],
            datum_ids=[1],
            models_=[None, None],
        ),
    ]

    for kwargs in mismatched_inputs:
        with pytest.raises(ValueError):
            create_annotations(db=None, **kwargs)  # type: ignore - testing


def test_malformed_input_delete_dataset_annotations():
    """Check that `delete_dataset_annotations` raises RuntimeError unless the dataset is DELETING.

    Every `TableStatus` other than DELETING is tried in turn. No database
    is needed, so `db` is passed as None.
    """
    disallowed_statuses = [
        candidate
        for candidate in enums.TableStatus
        if candidate != enums.TableStatus.DELETING
    ]

    for status in disallowed_statuses:
        dataset = models.Dataset(name="dataset", status=status)

        with pytest.raises(RuntimeError):
            delete_dataset_annotations(db=None, dataset=dataset)  # type: ignore - testing


def test_malformed_input_delete_model_annotations():
    """Check that `delete_model_annotations` raises RuntimeError unless the model is DELETING.

    Every `ModelStatus` other than DELETING is tried in turn. No database
    is needed, so `db` is passed as None.
    """
    disallowed_statuses = [
        candidate
        for candidate in enums.ModelStatus
        if candidate != enums.ModelStatus.DELETING
    ]

    for status in disallowed_statuses:
        model = models.Model(name="model", status=status)

        with pytest.raises(RuntimeError):
            delete_model_annotations(db=None, model=model)  # type: ignore - testing
158 changes: 71 additions & 87 deletions api/valor_api/backend/core/annotation.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,33 @@
from typing import Any

from geoalchemy2.functions import ST_AsGeoJSON
from sqlalchemy import and_, delete, insert, select
from sqlalchemy import ScalarSelect, and_, delete, insert, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from valor_api import schemas
from valor_api.backend import models
from valor_api.backend.core.geometry import _raster_to_png_b64
from valor_api.backend.query import generate_query
from valor_api.enums import ModelStatus, TableStatus
from valor_api.enums import ModelStatus, TableStatus, TaskType


def _format_box(box: schemas.Box | None) -> str | None:
    """Return `box.to_wkt()`, or None when `box` is None (or otherwise falsy)."""
    if not box:
        return None
    return box.to_wkt()


def _format_polygon(polygon: schemas.Polygon | None) -> str | None:
    """Return `polygon.to_wkt()`, or None when `polygon` is None (or otherwise falsy)."""
    if not polygon:
        return None
    return polygon.to_wkt()


def _format_raster(
    raster: schemas.Raster | None,
) -> ScalarSelect | bytes | None:
    """Return `raster.to_psql()`, or None when `raster` is None (or otherwise falsy)."""
    if not raster:
        return None
    return raster.to_psql()


def _create_embedding(
db: Session,
value: list[float],
) -> int:
value: list[float] | None,
) -> int | None:
"""
Creates a row in the embedding table.
Expand All @@ -31,6 +43,8 @@ def _create_embedding(
int
The row id of the embedding.
"""
if not value:
return None
try:
row = models.Embedding(value=value)
db.add(row)
Expand All @@ -41,76 +55,25 @@ def _create_embedding(
return row.id


def _create_annotation(
    db: Session,
    annotation: schemas.Annotation,
    datum: models.Datum,
    model: models.Model | None = None,
) -> dict[str, Any]:
    """
    Convert an individual annotation's attributes into a dictionary for upload to psql.

    Parameters
    ----------
    db : Session
        The database session, passed to `_create_embedding` when the
        annotation carries an embedding.
    annotation : schemas.Annotation
        The annotation to map.
    datum : models.Datum
        The datum associated with the annotation.
    model : models.Model, optional
        The model associated with the annotation.

    Returns
    ----------
    dict[str, Any]
        The column values for a single annotation row, keyed by column name.
    """
    # Geometry values are converted first and the embedding row is created
    # last, before any datum/model attribute is read.
    box = (
        annotation.bounding_box.to_wkt() if annotation.bounding_box else None
    )
    polygon = annotation.polygon.to_wkt() if annotation.polygon else None
    raster = annotation.raster.to_psql() if annotation.raster else None
    embedding_id = (
        _create_embedding(db=db, value=annotation.embedding)
        if annotation.embedding
        else None
    )

    return {
        "datum_id": datum.id,
        "model_id": model.id if model else None,
        "meta": annotation.metadata,
        "box": box,
        "polygon": polygon,
        "raster": raster,
        "embedding_id": embedding_id,
        "is_instance": annotation.is_instance,
        "implied_task_types": annotation.implied_task_types,
    }


def create_annotations(
db: Session,
annotations: list[list[schemas.Annotation]],
datums: list[models.Datum],
models_: list[models.Model] | None | list[None] = None,
datum_ids: list[int],
models_: list[models.Model] | list[None] | None = None,
) -> list[list[models.Annotation]]:
"""
Create a list of annotations and associated labels in psql.
Parameters
----------
db
db : Session
The database Session you want to query against.
annotations
annotations : list[list[schemas.Annotation]]
The list of annotations to create.
datum
The datum associated with the annotation.
model
The model associated with the annotation.
datum_ids : list[int]
The row id of the datum that each inner list of annotations belongs to.
models_: list[models.Model], optional
The model(s) associated with the annotations.
Returns
----------
Expand All @@ -122,41 +85,58 @@ def create_annotations(
exceptions.AnnotationAlreadyExistsError
If the provided datum already has existing annotations for that dataset or model.
"""
models_ = models_ or [None] * len(datums)

assert len(models_) == len(datums) == len(annotations)
# cache model ids
models_ = models_ or [None] * len(datum_ids)
model_ids = [
model.id if isinstance(model, models.Model) else model
for model in models_
]

# create annotations
annotation_mappings = [
_create_annotation(
db=db, annotation=annotation, datum=datum, model=model
)
for annotations_per_datum, datum, model in zip(
annotations, datums, models_
if not (len(model_ids) == len(datum_ids) == len(annotations)):
raise ValueError("Length mismatch between annotation elements.")

values = [
{
"datum_id": datum_id,
"model_id": model_id,
"meta": annotation.metadata,
"box": _format_box(annotation.bounding_box),
"polygon": _format_polygon(annotation.polygon),
"raster": _format_raster(annotation.raster),
"embedding_id": _create_embedding(
db=db, value=annotation.embedding
),
"is_instance": annotation.is_instance,
"implied_task_types": annotation.implied_task_types,
}
for annotations_per_datum, datum_id, model_id in zip(
annotations, datum_ids, model_ids
)
for annotation in annotations_per_datum
]

try:
insert_stmt = (
insert(models.Annotation)
.values(annotation_mappings)
.values(values)
.returning(models.Annotation.id)
)
annotation_id_list = db.execute(insert_stmt).scalars().all()
annotation_ids = list(db.execute(insert_stmt).scalars().all())
db.commit()
except IntegrityError as e:
db.rollback()
raise e

annotation_ids = []
grouped_annotation_row_ids = []
idx = 0
for annotations_per_datum in annotations:
annotation_ids.append(
annotation_id_list[idx : idx + len(annotations_per_datum)]
grouped_annotation_row_ids.append(
annotation_ids[idx : idx + len(annotations_per_datum)]
)
idx += len(annotations_per_datum)

return annotation_ids
return grouped_annotation_row_ids


def create_skipped_annotations(
Expand All @@ -177,16 +157,21 @@ def create_skipped_annotations(
The model associated with the annotation.
"""
annotation_list = [
_create_annotation(
db=db,
annotation=schemas.Annotation(),
datum=datum,
model=model,
models.Annotation(
datum_id=datum.id,
model_id=model.id if model else None,
meta=dict(),
box=None,
polygon=None,
raster=None,
embedding_id=None,
is_instance=False,
implied_task_types=[TaskType.EMPTY],
)
for datum in datums
]
try:
db.bulk_insert_mappings(models.Annotation, annotation_list)
db.add_all(annotation_list)
db.commit()
except IntegrityError as e:
db.rollback()
Expand Down Expand Up @@ -262,7 +247,6 @@ def get_annotation(
datum = db.scalar(
select(models.Datum).where(models.Datum.id == annotation.datum_id)
)

if datum is None:
raise RuntimeError(
"psql unexpectedly returned None instead of a Datum."
Expand Down
Loading

0 comments on commit 0788666

Please sign in to comment.