-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add auto-annotation support to SDK and CLI
TODO: describe this
- Loading branch information
Showing
7 changed files
with
601 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Copyright (C) 2023 CVAT.ai Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
from .driver import BadFunctionError, annotate_task | ||
from .interface import DetectionFunction, DetectionFunctionContext, DetectionFunctionSpec |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,226 @@ | ||
# Copyright (C) 2023 CVAT.ai Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
import logging | ||
from typing import List, Mapping, Optional, Sequence | ||
|
||
import attrs | ||
|
||
import cvat_sdk.models as models | ||
from cvat_sdk.core import Client | ||
from cvat_sdk.datasets.task_dataset import TaskDataset | ||
|
||
from .interface import DetectionFunction, DetectionFunctionContext, DetectionFunctionSpec | ||
|
||
|
||
class BadFunctionError(Exception): | ||
pass | ||
|
||
|
||
class _AnnotationMapper: | ||
@attrs.frozen | ||
class _MappedLabel: | ||
id: int | ||
sublabel_mapping: Mapping[int, Optional[int]] | ||
expected_num_elements: int = 0 | ||
|
||
_label_mapping: Mapping[int, Optional[_MappedLabel]] | ||
|
||
def _build_mapped_label( | ||
self, fun_label: models.ILabel, ds_labels_by_name: Mapping[str, models.ILabel] | ||
) -> Optional[_MappedLabel]: | ||
if getattr(fun_label, "attributes", None): | ||
raise BadFunctionError(f"label attributes are currently not supported") | ||
|
||
ds_label = ds_labels_by_name.get(fun_label.name) | ||
if ds_label is None: | ||
self._logger.info( | ||
"label %r is not in dataset; any annotations using it will be ignored", | ||
fun_label.name, | ||
) | ||
return None | ||
|
||
sl_map = {} | ||
|
||
if getattr(fun_label, "sublabels", []): | ||
fun_label_type = getattr(fun_label, "type", "any") | ||
if fun_label_type != "skeleton": | ||
raise BadFunctionError( | ||
f"label {fun_label.name!r} with sublabels has type {fun_label_type!r} (should be 'skeleton')" | ||
) | ||
|
||
ds_sublabels_by_name = {ds_sl.name: ds_sl for ds_sl in ds_label.sublabels} | ||
|
||
for fun_sl in fun_label.sublabels: | ||
if not hasattr(fun_sl, "id"): | ||
raise BadFunctionError( | ||
f"sublabel {fun_sl.name!r} of label {fun_label.name!r} has no ID" | ||
) | ||
|
||
if fun_sl.id in sl_map: | ||
raise BadFunctionError( | ||
f"sublabel {fun_sl.name!r} of label {fun_label.name!r} has same ID as another sublabel ({fun_sl.id})" | ||
) | ||
|
||
ds_sl = ds_sublabels_by_name.get(fun_sl.name) | ||
if not ds_sl: | ||
self._logger.info( | ||
"sublabel %r of label %r is not in dataset; any annotations using it will be ignored", | ||
fun_sl.name, | ||
fun_label.name, | ||
) | ||
sl_map[fun_sl.id] = None | ||
continue | ||
|
||
sl_map[fun_sl.id] = ds_sl.id | ||
|
||
return self._MappedLabel( | ||
ds_label.id, sublabel_mapping=sl_map, expected_num_elements=len(ds_label.sublabels) | ||
) | ||
|
||
def __init__( | ||
self, | ||
logger: logging.Logger, | ||
fun_labels: Sequence[models.ILabel], | ||
ds_labels: Sequence[models.ILabel], | ||
) -> None: | ||
self._logger = logger | ||
|
||
ds_labels_by_name = {ds_label.name: ds_label for ds_label in ds_labels} | ||
|
||
self._label_mapping = {} | ||
|
||
for fun_label in fun_labels: | ||
if not hasattr(fun_label, "id"): | ||
raise BadFunctionError(f"label {fun_label.name!r} has no ID") | ||
|
||
if fun_label.id in self._label_mapping: | ||
raise BadFunctionError( | ||
f"label {fun_label.name} has same ID as another label ({fun_label.id})" | ||
) | ||
|
||
self._label_mapping[fun_label.id] = self._build_mapped_label( | ||
fun_label, ds_labels_by_name | ||
) | ||
|
||
def validate_and_remap(self, shapes: List[models.LabeledShapeRequest], ds_frame: int) -> None: | ||
new_shapes = [] | ||
|
||
for shape in shapes: | ||
if hasattr(shape, "id"): | ||
raise BadFunctionError("function output shape with preset id") | ||
|
||
if hasattr(shape, "source"): | ||
raise BadFunctionError("function output shape with preset source") | ||
shape.source = "auto" | ||
|
||
if shape.frame != 0: | ||
raise BadFunctionError( | ||
f"function output shape with unexpected frame number ({shape.frame})" | ||
) | ||
|
||
shape.frame = ds_frame | ||
|
||
try: | ||
mapped_label = self._label_mapping[shape.label_id] | ||
except KeyError: | ||
raise BadFunctionError( | ||
f"function output shape with unknown label ID ({shape.label_id})" | ||
) | ||
|
||
if not mapped_label: | ||
continue | ||
|
||
shape.label_id = mapped_label.id | ||
|
||
if getattr(shape, "attributes", None): | ||
raise BadFunctionError( | ||
"function output shape with attributes, which is not yet supported" | ||
) | ||
|
||
new_shapes.append(shape) | ||
|
||
if shape.type.value == "skeleton": | ||
new_elements = [] | ||
seen_sl_ids = set() | ||
|
||
for element in shape.elements: | ||
if hasattr(element, "id"): | ||
raise BadFunctionError("function output shape element with preset id") | ||
|
||
if hasattr(element, "source"): | ||
raise BadFunctionError("function output shape element with preset source") | ||
element.source = "auto" | ||
|
||
if element.frame != 0: | ||
raise BadFunctionError( | ||
f"function output shape with unexpected frame number ({element.frame})" | ||
) | ||
|
||
element.frame = ds_frame | ||
|
||
if element.type.value != "points": | ||
raise BadFunctionError( | ||
f"function output skeleton with element type other than 'points' ({element.type.value})" | ||
) | ||
|
||
try: | ||
mapped_sl_id = mapped_label.sublabel_mapping[element.label_id] | ||
except KeyError: | ||
raise BadFunctionError( | ||
f"function output shape with unknown sublabel ID ({element.label_id})" | ||
) | ||
|
||
if not mapped_sl_id: | ||
continue | ||
|
||
if mapped_sl_id in seen_sl_ids: | ||
raise BadFunctionError( | ||
"function output skeleton with multiple elements with same sublabel" | ||
) | ||
|
||
element.label_id = mapped_sl_id | ||
|
||
seen_sl_ids.add(mapped_sl_id) | ||
|
||
new_elements.append(element) | ||
|
||
if len(new_elements) != mapped_label.expected_num_elements: | ||
# new_elements could only be shorter than expected, | ||
# because the reverse would imply that there are more distinct sublabel IDs | ||
# than are actually defined in the dataset. | ||
assert len(new_elements) < mapped_label.expected_num_elements | ||
|
||
raise BadFunctionError( | ||
f"function output skeleton with fewer elements than expected ({len(new_elements)} vs {mapped_label.expected_num_elements})" | ||
) | ||
|
||
shape.elements[:] = new_elements | ||
else: | ||
if getattr(shape, "elements", None): | ||
raise BadFunctionError("function output non-skeleton shape with elements") | ||
|
||
shapes[:] = new_shapes | ||
|
||
|
||
def annotate_task(client: Client, task_id: int, function: DetectionFunction) -> None: | ||
dataset = TaskDataset(client, task_id) | ||
|
||
assert isinstance(function.spec, DetectionFunctionSpec) | ||
|
||
mapper = _AnnotationMapper(client.logger, function.spec.labels, dataset.labels) | ||
|
||
shapes = [] | ||
|
||
context = DetectionFunctionContext() | ||
|
||
for sample in dataset.samples: | ||
frame_shapes = function.detect(context, sample.media.load_image()) | ||
mapper.validate_and_remap(frame_shapes, sample.frame_index) | ||
shapes.extend(frame_shapes) | ||
|
||
client.logger.info("Uploading annotations to task %d", task_id) | ||
client.tasks.api.update_annotations( | ||
task_id, task_annotations_update_request=models.LabeledDataRequest(shapes=shapes) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Copyright (C) 2023 CVAT.ai Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
from typing import List, Sequence | ||
|
||
import attrs | ||
import PIL.Image | ||
from typing_extensions import Protocol | ||
|
||
import cvat_sdk.models as models | ||
|
||
|
||
@attrs.frozen(kw_only=True) | ||
class DetectionFunctionSpec: | ||
labels: Sequence[models.PatchedLabelRequest] | ||
|
||
|
||
class DetectionFunctionContext: | ||
# This class exists so that the SDK can provide additional information | ||
# to the function in a backwards-compatible way. There's nothing here for now. | ||
pass | ||
|
||
|
||
class DetectionFunction(Protocol): | ||
@property | ||
def spec(self) -> DetectionFunctionSpec: | ||
... | ||
|
||
def detect( | ||
self, context: DetectionFunctionContext, image: PIL.Image.Image | ||
) -> List[models.LabeledShapeRequest]: | ||
... |
Oops, something went wrong.