Skip to content

Commit

Permalink
Add auto-annotation support to SDK and CLI
Browse files Browse the repository at this point in the history
TODO: describe this
  • Loading branch information
SpecLad committed Jul 17, 2023
1 parent e18626f commit 6c1f16b
Show file tree
Hide file tree
Showing 7 changed files with 601 additions and 0 deletions.
1 change: 1 addition & 0 deletions cvat-cli/src/cvat_cli/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def main(args: List[str] = None):
"upload": CLI.tasks_upload,
"export": CLI.tasks_export,
"import": CLI.tasks_import,
"auto-annotate": CLI.tasks_auto_annotate,
}
parser = make_cmdline_parser()
parsed_args = parser.parse_args(args)
Expand Down
6 changes: 6 additions & 0 deletions cvat-cli/src/cvat_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

from __future__ import annotations

import importlib
import json
from typing import Dict, List, Sequence, Tuple

import cvat_sdk.auto_annotation as cvataa
import tqdm
from cvat_sdk import Client, models
from cvat_sdk.core.helpers import TqdmProgressReporter
Expand Down Expand Up @@ -138,6 +140,10 @@ def tasks_import(self, filename: str, *, status_check_period: int = 2) -> None:
filename=filename, status_check_period=status_check_period, pbar=self._make_pbar()
)

def tasks_auto_annotate(self, task_id: int, function_module: str) -> None:
    """Automatically annotate a task using a function loaded from a module.

    `function_module` is the importable name of a module; the imported module
    object itself is handed to the SDK as the detection function.
    """
    detection_function = importlib.import_module(function_module)
    cvataa.annotate_task(client=self.client, task_id=task_id, function=detection_function)

def _make_pbar(self, title: str = None) -> TqdmProgressReporter:
return TqdmProgressReporter(
tqdm.tqdm(unit_scale=True, unit="B", unit_divisor=1024, desc=title)
Expand Down
11 changes: 11 additions & 0 deletions cvat-cli/src/cvat_cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,17 @@ def make_cmdline_parser() -> argparse.ArgumentParser:
help="time interval between checks if archive processing was finished, in seconds",
)

#######################################################################
# Auto-annotate
#######################################################################
auto_annotate_task_parser = task_subparser.add_parser(
"auto-annotate", description="Automatically annotate a CVAT task."
)
auto_annotate_task_parser.add_argument("task_id", type=int, help="task ID")
auto_annotate_task_parser.add_argument(
"function_module", help="name of module to use as the function"
)

return parser


Expand Down
6 changes: 6 additions & 0 deletions cvat-sdk/cvat_sdk/auto_annotation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (C) 2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

from .driver import BadFunctionError, annotate_task
from .interface import DetectionFunction, DetectionFunctionContext, DetectionFunctionSpec
226 changes: 226 additions & 0 deletions cvat-sdk/cvat_sdk/auto_annotation/driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
# Copyright (C) 2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

import logging
from typing import List, Mapping, Optional, Sequence

import attrs

import cvat_sdk.models as models
from cvat_sdk.core import Client
from cvat_sdk.datasets.task_dataset import TaskDataset

from .interface import DetectionFunction, DetectionFunctionContext, DetectionFunctionSpec


class BadFunctionError(Exception):
    """Raised when a detection function's spec or output fails validation."""


class _AnnotationMapper:
    """Validate detection function output and remap it onto a dataset's labels.

    On construction, every label declared by the function is matched by name
    against the dataset's labels (and every sublabel against the sublabels of
    the matched dataset label). `validate_and_remap` then checks the shapes
    produced by the function for a single frame and rewrites their frame
    numbers, sources and (sub)label IDs in place, dropping shapes and skeleton
    elements whose (sub)label does not exist in the dataset.
    """

    @attrs.frozen
    class _MappedLabel:
        # ID of the dataset label that the function label was matched to.
        id: int
        # Function sublabel ID -> dataset sublabel ID,
        # or None when the sublabel is absent from the dataset.
        sublabel_mapping: Mapping[int, Optional[int]]
        # Number of sublabels of the dataset label; a remapped skeleton
        # must end up with exactly this many elements.
        expected_num_elements: int = 0

    # Function label ID -> mapping info, or None when the label is absent from the dataset.
    _label_mapping: Mapping[int, Optional[_MappedLabel]]

    def _build_mapped_label(
        self, fun_label: models.ILabel, ds_labels_by_name: Mapping[str, models.ILabel]
    ) -> Optional[_MappedLabel]:
        """Match one function label (and its sublabels) to the dataset by name.

        Returns None (after logging) if the dataset has no label with this name.

        Raises:
            BadFunctionError: if the label has attributes, has sublabels without
                being of type 'skeleton', or has sublabels with missing or
                duplicate IDs.
        """
        if getattr(fun_label, "attributes", None):
            raise BadFunctionError("label attributes are currently not supported")

        ds_label = ds_labels_by_name.get(fun_label.name)
        if ds_label is None:
            self._logger.info(
                "label %r is not in dataset; any annotations using it will be ignored",
                fun_label.name,
            )
            return None

        sl_map = {}

        if getattr(fun_label, "sublabels", []):
            fun_label_type = getattr(fun_label, "type", "any")
            if fun_label_type != "skeleton":
                raise BadFunctionError(
                    f"label {fun_label.name!r} with sublabels has type {fun_label_type!r} (should be 'skeleton')"
                )

            ds_sublabels_by_name = {ds_sl.name: ds_sl for ds_sl in ds_label.sublabels}

            for fun_sl in fun_label.sublabels:
                if not hasattr(fun_sl, "id"):
                    raise BadFunctionError(
                        f"sublabel {fun_sl.name!r} of label {fun_label.name!r} has no ID"
                    )

                if fun_sl.id in sl_map:
                    raise BadFunctionError(
                        f"sublabel {fun_sl.name!r} of label {fun_label.name!r} has same ID as another sublabel ({fun_sl.id})"
                    )

                ds_sl = ds_sublabels_by_name.get(fun_sl.name)
                if ds_sl is None:
                    self._logger.info(
                        "sublabel %r of label %r is not in dataset; any annotations using it will be ignored",
                        fun_sl.name,
                        fun_label.name,
                    )
                    # Remember the ID so that elements using it are recognized
                    # (and silently dropped) rather than rejected as unknown.
                    sl_map[fun_sl.id] = None
                    continue

                sl_map[fun_sl.id] = ds_sl.id

        return self._MappedLabel(
            ds_label.id, sublabel_mapping=sl_map, expected_num_elements=len(ds_label.sublabels)
        )

    def __init__(
        self,
        logger: logging.Logger,
        fun_labels: Sequence[models.ILabel],
        ds_labels: Sequence[models.ILabel],
    ) -> None:
        """Build the function-label-ID -> dataset-label mapping.

        Args:
            logger: receives informational messages about (sub)labels that
                are not present in the dataset.
            fun_labels: labels declared by the detection function.
            ds_labels: labels of the dataset being annotated.

        Raises:
            BadFunctionError: if a function label has no ID, shares its ID
                with another label, or fails the per-label checks.
        """
        self._logger = logger

        ds_labels_by_name = {ds_label.name: ds_label for ds_label in ds_labels}

        self._label_mapping = {}

        for fun_label in fun_labels:
            if not hasattr(fun_label, "id"):
                raise BadFunctionError(f"label {fun_label.name!r} has no ID")

            if fun_label.id in self._label_mapping:
                raise BadFunctionError(
                    f"label {fun_label.name!r} has same ID as another label ({fun_label.id})"
                )

            self._label_mapping[fun_label.id] = self._build_mapped_label(
                fun_label, ds_labels_by_name
            )

    @staticmethod
    def _claim_for_frame(obj, what: str, ds_frame: int) -> None:
        """Check id/source/frame of a shape or element and stamp source and frame.

        `what` names the kind of object in error messages
        ("shape" or "shape element").
        """
        if hasattr(obj, "id"):
            raise BadFunctionError(f"function output {what} with preset id")

        if hasattr(obj, "source"):
            raise BadFunctionError(f"function output {what} with preset source")
        obj.source = "auto"

        # The function sees one image at a time, so it must report frame 0;
        # the real frame number is filled in here.
        if obj.frame != 0:
            raise BadFunctionError(
                f"function output {what} with unexpected frame number ({obj.frame})"
            )

        obj.frame = ds_frame

    def _remap_skeleton_elements(
        self, shape: models.LabeledShapeRequest, mapped_label: _MappedLabel, ds_frame: int
    ) -> None:
        """Validate a skeleton's elements and remap them in place to dataset sublabels."""
        new_elements = []
        seen_sl_ids = set()

        for element in shape.elements:
            self._claim_for_frame(element, "shape element", ds_frame)

            if element.type.value != "points":
                raise BadFunctionError(
                    f"function output skeleton with element type other than 'points' ({element.type.value})"
                )

            try:
                mapped_sl_id = mapped_label.sublabel_mapping[element.label_id]
            except KeyError:
                raise BadFunctionError(
                    f"function output shape with unknown sublabel ID ({element.label_id})"
                ) from None

            # Compare against None explicitly: a falsy ID such as 0 is still a valid mapping.
            if mapped_sl_id is None:
                continue

            if mapped_sl_id in seen_sl_ids:
                raise BadFunctionError(
                    "function output skeleton with multiple elements with same sublabel"
                )

            element.label_id = mapped_sl_id

            seen_sl_ids.add(mapped_sl_id)

            new_elements.append(element)

        if len(new_elements) != mapped_label.expected_num_elements:
            # new_elements could only be shorter than expected,
            # because the reverse would imply that there are more distinct sublabel IDs
            # than are actually defined in the dataset.
            assert len(new_elements) < mapped_label.expected_num_elements

            raise BadFunctionError(
                f"function output skeleton with fewer elements than expected ({len(new_elements)} vs {mapped_label.expected_num_elements})"
            )

        shape.elements[:] = new_elements

    def validate_and_remap(self, shapes: List[models.LabeledShapeRequest], ds_frame: int) -> None:
        """Validate the shapes output for one frame and remap them in place.

        Each kept shape gets its source set to "auto", its frame set to
        `ds_frame` and its label ID replaced with the dataset label ID.
        Shapes whose label is not in the dataset are removed from `shapes`.

        Args:
            shapes: shapes returned by the detection function; modified in place.
            ds_frame: frame number to assign to the shapes and their elements.

        Raises:
            BadFunctionError: if any shape or skeleton element violates the
                expectations checked here (preset id/source, non-zero frame,
                unknown (sub)label ID, attributes, wrong element type,
                duplicate or missing skeleton elements, elements on a
                non-skeleton shape).
        """
        new_shapes = []

        for shape in shapes:
            self._claim_for_frame(shape, "shape", ds_frame)

            try:
                mapped_label = self._label_mapping[shape.label_id]
            except KeyError:
                raise BadFunctionError(
                    f"function output shape with unknown label ID ({shape.label_id})"
                ) from None

            # Label is known to the function but absent from the dataset: drop the shape.
            if mapped_label is None:
                continue

            shape.label_id = mapped_label.id

            if getattr(shape, "attributes", None):
                raise BadFunctionError(
                    "function output shape with attributes, which is not yet supported"
                )

            new_shapes.append(shape)

            if shape.type.value == "skeleton":
                self._remap_skeleton_elements(shape, mapped_label, ds_frame)
            elif getattr(shape, "elements", None):
                raise BadFunctionError("function output non-skeleton shape with elements")

        shapes[:] = new_shapes


def annotate_task(client: Client, task_id: int, function: DetectionFunction) -> None:
    """Run a detection function over every sample of a task and upload the results.

    A `TaskDataset` is created for the task, `function.detect` is called once
    per dataset sample with the sample's loaded image, and each frame's output
    is validated and remapped to the dataset's labels. All resulting shapes
    are then sent in a single `update_annotations` request.

    Args:
        client: client used to access the task and to log progress.
        task_id: ID of the task to annotate.
        function: object providing `spec` and `detect`.

    Raises:
        BadFunctionError: if `function.spec` is not a `DetectionFunctionSpec`,
            or if the spec or any output of the function fails validation.
    """
    # The function is user-supplied, so validate it explicitly instead of with
    # `assert` (which is stripped under `python -O`), and do it before any
    # dataset work so that a bad function fails fast.
    if not isinstance(function.spec, DetectionFunctionSpec):
        raise BadFunctionError(
            f"function spec has unexpected type {type(function.spec).__name__!r}"
            " (should be 'DetectionFunctionSpec')"
        )

    dataset = TaskDataset(client, task_id)

    mapper = _AnnotationMapper(client.logger, function.spec.labels, dataset.labels)

    shapes = []

    context = DetectionFunctionContext()

    for sample in dataset.samples:
        frame_shapes = function.detect(context, sample.media.load_image())
        # Rewrites frame_shapes in place (frame numbers, sources, label IDs).
        mapper.validate_and_remap(frame_shapes, sample.frame_index)
        shapes.extend(frame_shapes)

    client.logger.info("Uploading annotations to task %d", task_id)
    client.tasks.api.update_annotations(
        task_id, task_annotations_update_request=models.LabeledDataRequest(shapes=shapes)
    )
33 changes: 33 additions & 0 deletions cvat-sdk/cvat_sdk/auto_annotation/interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (C) 2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

from typing import List, Sequence

import attrs
import PIL.Image
from typing_extensions import Protocol

import cvat_sdk.models as models


@attrs.frozen(kw_only=True)
class DetectionFunctionSpec:
    """Static description of a detection function.

    Instances are immutable and must be constructed with keyword arguments.
    """

    # Labels that the function can output. The auto-annotation driver matches
    # them to dataset labels by name and requires each label (and each
    # sublabel) to carry an ID that is unique among its siblings.
    labels: Sequence[models.PatchedLabelRequest]


class DetectionFunctionContext:
    """Extra information passed by the SDK to a detection function.

    This class exists so that the SDK can provide additional information
    to the function in a backwards-compatible way. There's nothing here for now.
    """


class DetectionFunction(Protocol):
    """Structural interface that an auto-annotation detection function must satisfy.

    Any object with a matching `spec` attribute and `detect` callable qualifies.
    """

    @property
    def spec(self) -> DetectionFunctionSpec:
        """Return the function's specification, including the labels it can output."""

    def detect(
        self, context: DetectionFunctionContext, image: PIL.Image.Image
    ) -> List[models.LabeledShapeRequest]:
        """Return the shapes detected on a single image.

        The auto-annotation driver expects each returned shape to use frame
        number 0, to have no preset `id` or `source`, and to reference label
        IDs declared in `spec`.
        """
Loading

0 comments on commit 6c1f16b

Please sign in to comment.