diff --git a/CHANGELOG.md b/CHANGELOG.md
index ba25166343e..cecc6f5074c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,7 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   functionality for automatically annotating a task by running a
   user-provided function on the local machine, and a corresponding CLI command
   (`auto-annotate`)
-  ()
+  (,
+  )
 - Cached frames indication on the interface ()

 ### Changed
diff --git a/cvat-cli/src/cvat_cli/cli.py b/cvat-cli/src/cvat_cli/cli.py
index d0417944aa6..114e5bed894 100644
--- a/cvat-cli/src/cvat_cli/cli.py
+++ b/cvat-cli/src/cvat_cli/cli.py
@@ -8,7 +8,7 @@
 import importlib.util
 import json
 from pathlib import Path
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Any, Dict, List, Optional, Sequence, Tuple

 import cvat_sdk.auto_annotation as cvataa
 from cvat_sdk import Client, models
@@ -151,6 +151,7 @@ def tasks_auto_annotate(
         *,
         function_module: Optional[str] = None,
         function_file: Optional[Path] = None,
+        function_parameters: Dict[str, Any],
         clear_existing: bool = False,
         allow_unmatched_labels: bool = False,
     ) -> None:
@@ -163,6 +164,13 @@
         else:
             assert False, "function identification arguments missing"

+        if hasattr(function, "create"):
+            # this is actually a function factory
+            function = function.create(**function_parameters)
+        else:
+            if function_parameters:
+                raise TypeError("function takes no parameters")
+
         cvataa.annotate_task(
             self.client,
             task_id,
diff --git a/cvat-cli/src/cvat_cli/parser.py b/cvat-cli/src/cvat_cli/parser.py
index c1a7e6c3abd..f03a52f9b41 100644
--- a/cvat-cli/src/cvat_cli/parser.py
+++ b/cvat-cli/src/cvat_cli/parser.py
@@ -11,6 +11,7 @@
 import textwrap
 from distutils.util import strtobool
 from pathlib import Path
+from typing import Any, Tuple

 from cvat_sdk.core.proxies.tasks import ResourceType

@@ -41,6 +42,40 @@ def parse_resource_type(s: str) -> ResourceType:
     return s


+def parse_function_parameter(s: str) -> Tuple[str, Any]:
+    key, sep, type_and_value = s.partition("=")
+
+    if not sep:
+        raise argparse.ArgumentTypeError("parameter value not specified")
+
+    type_, sep, value = type_and_value.partition(":")
+
+    if not sep:
+        raise argparse.ArgumentTypeError("parameter type not specified")
+
+    if type_ == "int":
+        value = int(value)
+    elif type_ == "float":
+        value = float(value)
+    elif type_ == "str":
+        pass
+    elif type_ == "bool":
+        value = bool(strtobool(value))
+    else:
+        raise argparse.ArgumentTypeError(f"unsupported parameter type {type_!r}")
+
+    return (key, value)
+
+
+class BuildDictAction(argparse.Action):
+    def __init__(self, option_strings, dest, default=None, **kwargs):
+        super().__init__(option_strings, dest, default=default or {}, **kwargs)
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        key, value = values
+        getattr(namespace, self.dest)[key] = value
+
+
 def make_cmdline_parser() -> argparse.ArgumentParser:
     #######################################################################
     # Command line interface definition
@@ -394,6 +429,16 @@ def make_cmdline_parser() -> argparse.ArgumentParser:
         help="path to a Python source file to use as the function",
     )

+    auto_annotate_task_parser.add_argument(
+        "--function-parameter",
+        "-p",
+        metavar="NAME=TYPE:VALUE",
+        type=parse_function_parameter,
+        action=BuildDictAction,
+        dest="function_parameters",
+        help="parameter for the function",
+    )
+
     auto_annotate_task_parser.add_argument(
         "--clear-existing", action="store_true", help="Remove existing annotations from the task"
     )
diff --git a/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_detection.py b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_detection.py
new file mode 100644
index 00000000000..57457d74225
--- /dev/null
+++ b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_detection.py
@@ -0,0 +1,41 @@
+# Copyright (C) 2023 CVAT.ai Corporation
+#
+# SPDX-License-Identifier: MIT
+
+from functools import cached_property
+from typing import List
+
+import PIL.Image
+import torchvision.models
+
+import cvat_sdk.auto_annotation as cvataa
+import cvat_sdk.models as models
+
+
+class _TorchvisionDetectionFunction:
+    def __init__(self, model_name: str, weights_name: str = "DEFAULT", **kwargs) -> None:
+        weights_enum = torchvision.models.get_model_weights(model_name)
+        self._weights = weights_enum[weights_name]
+        self._transforms = self._weights.transforms()
+        self._model = torchvision.models.get_model(model_name, weights=self._weights, **kwargs)
+        self._model.eval()
+
+    @cached_property
+    def spec(self) -> cvataa.DetectionFunctionSpec:
+        return cvataa.DetectionFunctionSpec(
+            labels=[
+                cvataa.label_spec(cat, i) for i, cat in enumerate(self._weights.meta["categories"])
+            ]
+        )
+
+    def detect(self, context, image: PIL.Image.Image) -> List[models.LabeledShapeRequest]:
+        results = self._model([self._transforms(image)])
+
+        return [
+            cvataa.rectangle(label.item(), [x.item() for x in box])
+            for result in results
+            for box, label in zip(result["boxes"], result["labels"])
+        ]
+
+
+create = _TorchvisionDetectionFunction
diff --git a/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_keypoint_detection.py b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_keypoint_detection.py
new file mode 100644
index 00000000000..b4eb47d476d
--- /dev/null
+++ b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_keypoint_detection.py
@@ -0,0 +1,59 @@
+# Copyright (C) 2023 CVAT.ai Corporation
+#
+# SPDX-License-Identifier: MIT
+
+from functools import cached_property
+from typing import List
+
+import PIL.Image
+import torchvision.models
+
+import cvat_sdk.auto_annotation as cvataa
+import cvat_sdk.models as models
+
+
+class _TorchvisionKeypointDetectionFunction:
+    def __init__(self, model_name: str, weights_name: str = "DEFAULT", **kwargs) -> None:
+        weights_enum = torchvision.models.get_model_weights(model_name)
+        self._weights = weights_enum[weights_name]
+        self._transforms = self._weights.transforms()
+        self._model = torchvision.models.get_model(model_name, weights=self._weights, **kwargs)
+        self._model.eval()
+
+    @cached_property
+    def spec(self) -> cvataa.DetectionFunctionSpec:
+        return cvataa.DetectionFunctionSpec(
+            labels=[
+                cvataa.skeleton_label_spec(
+                    cat,
+                    i,
+                    [
+                        cvataa.keypoint_spec(name, j)
+                        for j, name in enumerate(self._weights.meta["keypoint_names"])
+                    ],
+                )
+                for i, cat in enumerate(self._weights.meta["categories"])
+            ]
+        )
+
+    def detect(self, context, image: PIL.Image.Image) -> List[models.LabeledShapeRequest]:
+        results = self._model([self._transforms(image)])
+
+        return [
+            cvataa.skeleton(
+                label.item(),
+                elements=[
+                    cvataa.keypoint(
+                        keypoint_id,
+                        [keypoint[0].item(), keypoint[1].item()],
+                        occluded=not keypoint[2].item(),
+                    )
+                    for keypoint_id, keypoint in enumerate(keypoints)
+                ],
+            )
+            for result in results
+            for keypoints, label in zip(result["keypoints"], result["labels"])
+        ]
+
+
+create = _TorchvisionKeypointDetectionFunction
diff --git a/site/content/en/docs/api_sdk/cli/_index.md b/site/content/en/docs/api_sdk/cli/_index.md
index 1067d934a14..bca0a070264 100644
--- a/site/content/en/docs/api_sdk/cli/_index.md
+++ b/site/content/en/docs/api_sdk/cli/_index.md
@@ -235,19 +235,52 @@ by using the [label constructor](/docs/manual/basics/creating_an_annotation_task
 This command provides a command-line interface
 to the [auto-annotation API](/docs/api_sdk/sdk/auto-annotation).

-To use it, create a Python module that implements the AA function protocol.
-In other words, this module must define the required attributes on the module level.
-For example:
+It can auto-annotate using AA functions implemented in one of the following ways:

-```python
-import cvat_sdk.auto_annotation as cvataa
+1. As a Python module directly implementing the AA function protocol.
+   Such a module must define the required attributes at the module level.

-spec = cvataa.DetectionFunctionSpec(...)
+   For example:

-def detect(context, image):
-    ...
-```
+   ```python
+   import cvat_sdk.auto_annotation as cvataa
+
+   spec = cvataa.DetectionFunctionSpec(...)
+
+   def detect(context, image):
+       ...
+   ```
+
+1. As a Python module implementing a factory function named `create`.
+   This function must return an object implementing the AA function protocol.
+   Any parameters specified on the command line using the `-p` option
+   will be passed to `create`.
+
+   For example:
+
+   ```python
+   import cvat_sdk.auto_annotation as cvataa
+
+   class _MyFunction:
+       def __init__(...):
+           ...
+
+       spec = cvataa.DetectionFunctionSpec(...)
+
+       def detect(context, image):
+           ...
+
+   def create(...) -> cvataa.DetectionFunction:
+       return _MyFunction(...)
+   ```
+
+- Annotate the task with id 137 with the predefined torchvision detection function,
+  which is parameterized:
+  ```bash
+  cvat-cli auto-annotate 137 --function-module cvat_sdk.auto_annotation.functions.torchvision_detection \
+    -p model_name=str:fasterrcnn_resnet50_fpn_v2 -p box_score_thresh=float:0.5
+  ```

 - Annotate the task with id 138 with an AA function defined in `my_func.py`:
   ```bash
diff --git a/site/content/en/docs/api_sdk/sdk/auto-annotation.md b/site/content/en/docs/api_sdk/sdk/auto-annotation.md
index d6281f7e168..b85ab7b067b 100644
--- a/site/content/en/docs/api_sdk/sdk/auto-annotation.md
+++ b/site/content/en/docs/api_sdk/sdk/auto-annotation.md
@@ -197,3 +197,57 @@ Same logic applies to sub-label IDs.

 `annotate_task` will raise a `BadFunctionError` exception
 if it detects that the function violated the AA function protocol.
+
+## Predefined AA functions
+
+This layer includes several predefined AA functions.
+You can use them as-is, or as a base on which to build your own.
+
+Each function is implemented as a module
+to allow usage via the CLI `auto-annotate` command.
+Therefore, in order to use one from the SDK,
+you'll need to import the corresponding module.
+
+### `cvat_sdk.auto_annotation.functions.torchvision_detection`
+
+This AA function uses object detection models from
+the [torchvision](https://pytorch.org/vision/stable/index.html) library.
+It produces rectangle annotations.
+
+To use it, install CVAT SDK with the `pytorch` extra:
+
+```
+$ pip install "cvat-sdk[pytorch]"
+```
+
+Usage from Python:
+
+```python
+from cvat_sdk.auto_annotation.functions.torchvision_detection import create as create_torchvision
+annotate_task(<client>, <task ID>, create_torchvision(<model name>, ...))
+```
+
+Usage from the CLI:
+
+```bash
+cvat-cli auto-annotate "<task ID>" --function-module cvat_sdk.auto_annotation.functions.torchvision_detection \
+    -p model_name=str:"<model name>" ...
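+# Further -p options are forwarded to the factory; for example, to pick
+# specific weights (see the parameter list below):
+#     -p weights_name=str:COCO_V1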
+```
+
+The `create` function accepts the following parameters:
+
+- `model_name` (`str`) - the name of the model, such as `fasterrcnn_resnet50_fpn_v2`.
+  This parameter is required.
+- `weights_name` (`str`) - the name of a weights enum value for the model, such as `COCO_V1`.
+  Defaults to `DEFAULT`.
+
+It also accepts arbitrary additional parameters,
+which are passed directly to the model constructor.
+
+### `cvat_sdk.auto_annotation.functions.torchvision_keypoint_detection`
+
+This AA function is analogous to `torchvision_detection`,
+except it uses torchvision's keypoint detection models and produces skeleton annotations.
+Keypoints which the model marks as invisible will be marked as occluded in CVAT.
+
+Refer to the previous section for usage instructions and parameter information.
diff --git a/tests/python/cli/example_parameterized_function.py b/tests/python/cli/example_parameterized_function.py
new file mode 100644
index 00000000000..29d9038e78b
--- /dev/null
+++ b/tests/python/cli/example_parameterized_function.py
@@ -0,0 +1,32 @@
+# Copyright (C) 2023 CVAT.ai Corporation
+#
+# SPDX-License-Identifier: MIT
+
+from types import SimpleNamespace as namespace
+from typing import List
+
+import cvat_sdk.auto_annotation as cvataa
+import cvat_sdk.models as models
+import PIL.Image
+
+
+def create(s: str, i: int, f: float, b: bool) -> cvataa.DetectionFunction:
+    assert s == "string"
+    assert i == 123
+    assert f == 5.5
+    assert b is False
+
+    spec = cvataa.DetectionFunctionSpec(
+        labels=[
+            cvataa.label_spec("car", 0),
+        ],
+    )
+
+    def detect(
+        context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
+    ) -> List[models.LabeledShapeRequest]:
+        return [
+            cvataa.rectangle(0, [1, 2, 3, 4]),
+        ]
+
+    return namespace(spec=spec, detect=detect)
diff --git a/tests/python/cli/test_cli.py b/tests/python/cli/test_cli.py
index fbb6f73fe5f..66749f992aa 100644
--- a/tests/python/cli/test_cli.py
+++ b/tests/python/cli/test_cli.py
@@ -328,3 +328,20 @@ def test_auto_annotate_with_file(self, fxt_new_task: Task):
         annotations = fxt_new_task.get_annotations()

         assert annotations.shapes
+
+    def test_auto_annotate_with_parameters(self, fxt_new_task: Task):
+        annotations = fxt_new_task.get_annotations()
+        assert not annotations.shapes
+
+        self.run_cli(
+            "auto-annotate",
+            str(fxt_new_task.id),
+            f"--function-module={__package__}.example_parameterized_function",
+            "-ps=str:string",
+            "-pi=int:123",
+            "-pf=float:5.5",
+            "-pb=bool:false",
+        )
+
+        annotations = fxt_new_task.get_annotations()
+        assert annotations.shapes
diff --git a/tests/python/sdk/test_auto_annotation.py b/tests/python/sdk/test_auto_annotation.py
index c41b655d6c5..142c4354c4d 100644
--- a/tests/python/sdk/test_auto_annotation.py
+++ b/tests/python/sdk/test_auto_annotation.py
@@ -18,6 +18,11 @@
 from .util import make_pbar

+try:
+    import torchvision.models as torchvision_models
+except ModuleNotFoundError:
+    torchvision_models = None
+

 @pytest.fixture(autouse=True)
 def _common_setup(
@@ -553,3 +558,157 @@ def test_non_skeleton_with_elements(self):
             ],
             "non-skeleton shape with elements",
         )
+
+
+if torchvision_models is not None:
+    import torch
+    import torch.nn as nn
+
+    class FakeTorchvisionDetector(nn.Module):
+        def __init__(self, label_id: int) -> None:
+            super().__init__()
+            self._label_id = label_id
+
+        def forward(self, images: List[torch.Tensor]) -> List[dict]:
+            assert isinstance(images, list)
+            assert all(isinstance(t, torch.Tensor) for t in images)
+
+            return [
+                {
+                    "boxes": torch.tensor([[1, 2, 3, 4]]),
+                    "labels": torch.tensor([self._label_id]),
+                }
+            ]
+
+    def fake_get_detection_model(name: str, weights, test_param):
+        assert test_param == "expected_value"
+
+        car_label_id = weights.meta["categories"].index("car")
+
+        return FakeTorchvisionDetector(label_id=car_label_id)
+
+    class FakeTorchvisionKeypointDetector(nn.Module):
+        def __init__(self, label_id: int, keypoint_names: List[str]) -> None:
+            super().__init__()
+            self._label_id = label_id
+            self._keypoint_names = keypoint_names
+
+        def forward(self, images: List[torch.Tensor]) -> List[dict]:
+            assert isinstance(images, list)
+            assert all(isinstance(t, torch.Tensor) for t in images)
+
+            return [
+                {
+                    "labels": torch.tensor([self._label_id]),
+                    "keypoints": torch.tensor(
+                        [
+                            [
+                                [hash(name) % 100, 0, 1 if name.startswith("right_") else 0]
+                                for name in self._keypoint_names
+                            ]
+                        ]
+                    ),
+                }
+            ]
+
+    def fake_get_keypoint_detection_model(name: str, weights, test_param):
+        assert test_param == "expected_value"
+
+        person_label_id = weights.meta["categories"].index("person")
+
+        return FakeTorchvisionKeypointDetector(
+            label_id=person_label_id, keypoint_names=weights.meta["keypoint_names"]
+        )
+
+
+@pytest.mark.skipif(torchvision_models is None, reason="torchvision is not installed")
+class TestAutoAnnotationFunctions:
+    @pytest.fixture(autouse=True)
+    def setup(
+        self,
+        tmp_path: Path,
+        fxt_login: Tuple[Client, str],
+    ):
+        self.client = fxt_login[0]
+        self.image = generate_image_file("1.png", size=(100, 100))
+
+        image_dir = tmp_path / "images"
+        image_dir.mkdir()
+
+        image_path = image_dir / self.image.name
+        image_path.write_bytes(self.image.getbuffer())
+
+        self.task = self.client.tasks.create_from_data(
+            models.TaskWriteRequest(
+                "Auto-annotation test task",
+                labels=[
+                    models.PatchedLabelRequest(
+                        name="person",
+                        type="skeleton",
+                        sublabels=[
+                            models.SublabelRequest(name="left_eye"),
+                            models.SublabelRequest(name="right_eye"),
+                        ],
+                    ),
+                    models.PatchedLabelRequest(name="car"),
+                ],
+            ),
+            resources=[image_path],
+        )
+
+        task_labels = self.task.get_labels()
+        self.task_labels_by_id = {label.id: label for label in task_labels}
+
+        person_label = next(label for label in task_labels if label.name == "person")
+        self.person_sublabels_by_id = {sl.id: sl for sl in person_label.sublabels}
+
+    def test_torchvision_detection(self, monkeypatch: pytest.MonkeyPatch):
+        monkeypatch.setattr(torchvision_models, "get_model", fake_get_detection_model)
+
+        import cvat_sdk.auto_annotation.functions.torchvision_detection as td
+
+        cvataa.annotate_task(
+            self.client,
+            self.task.id,
+            td.create("fasterrcnn_resnet50_fpn_v2", "COCO_V1", test_param="expected_value"),
+            allow_unmatched_labels=True,
+        )
+
+        annotations = self.task.get_annotations()
+
+        assert len(annotations.shapes) == 1
+        assert self.task_labels_by_id[annotations.shapes[0].label_id].name == "car"
+        assert annotations.shapes[0].type.value == "rectangle"
+        assert annotations.shapes[0].points == [1, 2, 3, 4]
+
+    def test_torchvision_keypoint_detection(self, monkeypatch: pytest.MonkeyPatch):
+        monkeypatch.setattr(torchvision_models, "get_model", fake_get_keypoint_detection_model)
+
+        import cvat_sdk.auto_annotation.functions.torchvision_keypoint_detection as tkd
+
+        cvataa.annotate_task(
+            self.client,
+            self.task.id,
+            tkd.create("keypointrcnn_resnet50_fpn", "COCO_V1", test_param="expected_value"),
+            allow_unmatched_labels=True,
+        )
+
+        annotations = self.task.get_annotations()
+
+        assert len(annotations.shapes) == 1
+        assert self.task_labels_by_id[annotations.shapes[0].label_id].name == "person"
+        assert annotations.shapes[0].type.value == "skeleton"
+        assert len(annotations.shapes[0].elements) == 2
+
+        elements = sorted(
+            annotations.shapes[0].elements,
+            key=lambda e: self.person_sublabels_by_id[e.label_id].name,
+        )
+
+        assert self.person_sublabels_by_id[elements[0].label_id].name == "left_eye"
+        assert elements[0].points[0] == hash("left_eye") % 100
+        assert elements[0].occluded
+
+        assert self.person_sublabels_by_id[elements[1].label_id].name == "right_eye"
+        assert elements[1].points[0] == hash("right_eye") % 100
+        assert not elements[1].occluded
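
For reference, the `-p` plumbing introduced above can be exercised on its own. The sketch below reuses `parse_function_parameter` and `BuildDictAction` exactly as added to `cvat-cli/src/cvat_cli/parser.py` in this patch; the throwaway parser and the example argument values are illustrative only, not part of the patch:

```python
# Standalone sketch: how repeated -p NAME=TYPE:VALUE options accumulate into
# the function_parameters dict that tasks_auto_annotate() later splats into a
# factory's create(**function_parameters) call.
import argparse

from cvat_cli.parser import BuildDictAction, parse_function_parameter

parser = argparse.ArgumentParser()
parser.add_argument(
    "--function-parameter",
    "-p",
    metavar="NAME=TYPE:VALUE",
    type=parse_function_parameter,  # "name=float:0.5" -> ("name", 0.5)
    action=BuildDictAction,  # folds the (key, value) pairs into one dict
    dest="function_parameters",
)

args = parser.parse_args(
    ["-p", "model_name=str:fasterrcnn_resnet50_fpn_v2", "-p", "box_score_thresh=float:0.5"]
)
assert args.function_parameters == {
    "model_name": "fasterrcnn_resnet50_fpn_v2",
    "box_score_thresh": 0.5,
}
```

Splitting the work this way keeps type conversion (and its error reporting) in argparse's `type=` hook, so a bad value such as `-p x=int:abc` fails at parse time with a usage error rather than partway through an annotation run.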