SDK: Add predefined functions based on torchvision (#6649)
These serve as a replacement for the YOLOv8n function that was removed in #6632.

To support these functions, I also add the ability to define
parameterized functions for use with the CLI.
SpecLad authored Aug 11, 2023
1 parent 096870e commit b0b8e49
Showing 10 changed files with 460 additions and 11 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -15,7 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
functionality for automatically annotating a task by running a
user-provided function on the local machine, and a corresponding CLI command
(`auto-annotate`)
(<https://github.com/opencv/cvat/pull/6483>)
(<https://github.com/opencv/cvat/pull/6483>,
<https://github.com/opencv/cvat/pull/6649>)
- Cached frames indication on the interface (<https://github.com/opencv/cvat/pull/6586>)

### Changed
10 changes: 9 additions & 1 deletion cvat-cli/src/cvat_cli/cli.py
@@ -8,7 +8,7 @@
import importlib.util
import json
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
from typing import Any, Dict, List, Optional, Sequence, Tuple

import cvat_sdk.auto_annotation as cvataa
from cvat_sdk import Client, models
@@ -151,6 +151,7 @@ def tasks_auto_annotate(
    *,
    function_module: Optional[str] = None,
    function_file: Optional[Path] = None,
    function_parameters: Dict[str, Any],
    clear_existing: bool = False,
    allow_unmatched_labels: bool = False,
) -> None:
@@ -163,6 +164,13 @@ def tasks_auto_annotate(
    else:
        assert False, "function identification arguments missing"

    if hasattr(function, "create"):
        # this is actually a function factory
        function = function.create(**function_parameters)
    else:
        if function_parameters:
            raise TypeError("function takes no parameters")

    cvataa.annotate_task(
        self.client,
        task_id,
45 changes: 45 additions & 0 deletions cvat-cli/src/cvat_cli/parser.py
@@ -11,6 +11,7 @@
import textwrap
from distutils.util import strtobool
from pathlib import Path
from typing import Any, Tuple

from cvat_sdk.core.proxies.tasks import ResourceType

@@ -41,6 +42,40 @@ def parse_resource_type(s: str) -> ResourceType:
return s


def parse_function_parameter(s: str) -> Tuple[str, Any]:
    key, sep, type_and_value = s.partition("=")

    if not sep:
        raise argparse.ArgumentTypeError("parameter value not specified")

    type_, sep, value = type_and_value.partition(":")

    if not sep:
        raise argparse.ArgumentTypeError("parameter type not specified")

    if type_ == "int":
        value = int(value)
    elif type_ == "float":
        value = float(value)
    elif type_ == "str":
        pass
    elif type_ == "bool":
        value = bool(strtobool(value))
    else:
        raise argparse.ArgumentTypeError(f"unsupported parameter type {type_!r}")

    return (key, value)


class BuildDictAction(argparse.Action):
    def __init__(self, option_strings, dest, default=None, **kwargs):
        super().__init__(option_strings, dest, default=default or {}, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        key, value = values
        getattr(namespace, self.dest)[key] = value


def make_cmdline_parser() -> argparse.ArgumentParser:
#######################################################################
# Command line interface definition
@@ -394,6 +429,16 @@ def make_cmdline_parser() -> argparse.ArgumentParser:
help="path to a Python source file to use as the function",
)

auto_annotate_task_parser.add_argument(
"--function-parameter",
"-p",
metavar="NAME=TYPE:VALUE",
type=parse_function_parameter,
action=BuildDictAction,
dest="function_parameters",
help="parameter for the function",
)

auto_annotate_task_parser.add_argument(
"--clear-existing", action="store_true", help="Remove existing annotations from the task"
)
41 changes: 41 additions & 0 deletions cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_detection.py
@@ -0,0 +1,41 @@
# Copyright (C) 2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

from functools import cached_property
from typing import List

import PIL.Image
import torchvision.models

import cvat_sdk.auto_annotation as cvataa
import cvat_sdk.models as models


class _TorchvisionDetectionFunction:
    def __init__(self, model_name: str, weights_name: str = "DEFAULT", **kwargs) -> None:
        weights_enum = torchvision.models.get_model_weights(model_name)
        self._weights = weights_enum[weights_name]
        self._transforms = self._weights.transforms()
        self._model = torchvision.models.get_model(model_name, weights=self._weights, **kwargs)
        self._model.eval()

    @cached_property
    def spec(self) -> cvataa.DetectionFunctionSpec:
        return cvataa.DetectionFunctionSpec(
            labels=[
                cvataa.label_spec(cat, i) for i, cat in enumerate(self._weights.meta["categories"])
            ]
        )

    def detect(self, context, image: PIL.Image.Image) -> List[models.LabeledShapeRequest]:
        results = self._model([self._transforms(image)])

        return [
            cvataa.rectangle(label.item(), [x.item() for x in box])
            for result in results
            for box, label in zip(result["boxes"], result["labels"])
        ]


create = _TorchvisionDetectionFunction
59 changes: 59 additions & 0 deletions cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_keypoint_detection.py
@@ -0,0 +1,59 @@
# Copyright (C) 2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

from functools import cached_property
from typing import List

import PIL.Image
import torchvision.models

import cvat_sdk.auto_annotation as cvataa
import cvat_sdk.models as models


class _TorchvisionKeypointDetectionFunction:
    def __init__(self, model_name: str, weights_name: str = "DEFAULT", **kwargs) -> None:
        weights_enum = torchvision.models.get_model_weights(model_name)
        self._weights = weights_enum[weights_name]
        self._transforms = self._weights.transforms()
        self._model = torchvision.models.get_model(model_name, weights=self._weights, **kwargs)
        self._model.eval()

    @cached_property
    def spec(self) -> cvataa.DetectionFunctionSpec:
        return cvataa.DetectionFunctionSpec(
            labels=[
                cvataa.skeleton_label_spec(
                    cat,
                    i,
                    [
                        cvataa.keypoint_spec(name, j)
                        for j, name in enumerate(self._weights.meta["keypoint_names"])
                    ],
                )
                for i, cat in enumerate(self._weights.meta["categories"])
            ]
        )

    def detect(self, context, image: PIL.Image.Image) -> List[models.LabeledShapeRequest]:
        results = self._model([self._transforms(image)])

        return [
            cvataa.skeleton(
                label.item(),
                elements=[
                    cvataa.keypoint(
                        keypoint_id,
                        [keypoint[0].item(), keypoint[1].item()],
                        occluded=not keypoint[2].item(),
                    )
                    for keypoint_id, keypoint in enumerate(keypoints)
                ],
            )
            for result in results
            for keypoints, label in zip(result["keypoints"], result["labels"])
        ]


create = _TorchvisionKeypointDetectionFunction
51 changes: 42 additions & 9 deletions site/content/en/docs/api_sdk/cli/_index.md
@@ -235,19 +235,52 @@ by using the [label constructor](/docs/manual/basics/creating_an_annotation_task

This command provides a command-line interface
to the [auto-annotation API](/docs/api_sdk/sdk/auto-annotation).
It can auto-annotate using AA functions implemented in one of the following ways:

1. As a Python module directly implementing the AA function protocol.
   Such a module must define the required attributes at the module level.

   For example:

   ```python
   import cvat_sdk.auto_annotation as cvataa

   spec = cvataa.DetectionFunctionSpec(...)

   def detect(context, image):
       ...
   ```

1. As a Python module implementing a factory function named `create`.
   This function must return an object implementing the AA function protocol.
   Any parameters specified on the command line using the `-p` option
   will be passed to `create`.

   For example:

   ```python
   import cvat_sdk.auto_annotation as cvataa

   class _MyFunction:
       def __init__(...):
           ...

       spec = cvataa.DetectionFunctionSpec(...)

       def detect(context, image):
           ...

   def create(...) -> cvataa.DetectionFunction:
       return _MyFunction(...)
   ```

- Annotate the task with id 137 with the predefined torchvision detection function,
  which is parameterized:

  ```bash
  cvat-cli auto-annotate 137 --function-module cvat_sdk.auto_annotation.functions.torchvision_detection \
    -p model_name=str:fasterrcnn_resnet50_fpn_v2 -p box_score_thresh=float:0.5
  ```

- Annotate the task with id 138 with an AA function defined in `my_func.py`:
```bash
54 changes: 54 additions & 0 deletions site/content/en/docs/api_sdk/sdk/auto-annotation.md
@@ -197,3 +197,57 @@ Same logic applies to sub-label IDs.

`annotate_task` will raise a `BadFunctionError` exception
if it detects that the function violated the AA function protocol.

## Predefined AA functions

This layer includes several predefined AA functions.
You can use them as-is, or as a base on which to build your own.

Each function is implemented as a module
to allow usage via the CLI `auto-annotate` command.
Therefore, in order to use it from the SDK,
you'll need to import the corresponding module.

### `cvat_sdk.auto_annotation.functions.torchvision_detection`

This AA function uses object detection models from
the [torchvision](https://pytorch.org/vision/stable/index.html) library.
It produces rectangle annotations.

To use it, install CVAT SDK with the `pytorch` extra:

```
$ pip install "cvat-sdk[pytorch]"
```

Usage from Python:

```python
from cvat_sdk.auto_annotation.functions.torchvision_detection import create as create_torchvision
annotate_task(<client>, <task ID>, create_torchvision(<model name>, ...))
```

Usage from the CLI:

```bash
cvat-cli auto-annotate "<task ID>" --function-module cvat_sdk.auto_annotation.functions.torchvision_detection \
-p model_name=str:"<model name>" ...
```

The `create` function accepts the following parameters:

- `model_name` (`str`) - the name of the model, such as `fasterrcnn_resnet50_fpn_v2`.
This parameter is required.
- `weights_name` (`str`) - the name of a weights enum value for the model, such as `COCO_V1`.
Defaults to `DEFAULT`.

It also accepts arbitrary additional parameters,
which are passed directly to the model constructor.
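
For instance, a minimal sketch of passing such an extra parameter from Python (assuming, as in the CLI example above, that the chosen model's constructor accepts `box_score_thresh`):

```python
from cvat_sdk.auto_annotation.functions.torchvision_detection import create as create_torchvision

# model_name and weights_name are consumed by create itself;
# any other keyword argument, like box_score_thresh here, is
# forwarded unchanged to the torchvision model constructor.
func = create_torchvision("fasterrcnn_resnet50_fpn_v2", box_score_thresh=0.5)
```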

### `cvat_sdk.auto_annotation.functions.torchvision_keypoint_detection`

This AA function is analogous to `torchvision_detection`,
except it uses torchvision's keypoint detection models and produces skeleton annotations.
Keypoints which the model marks as invisible will be marked as occluded in CVAT.

Refer to the previous section for usage instructions and parameter information.
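
As an illustrative sketch, a CLI invocation might look like this (the model name used here, `keypointrcnn_resnet50_fpn`, is just one of torchvision's keypoint detection models):

```bash
cvat-cli auto-annotate "<task ID>" \
  --function-module cvat_sdk.auto_annotation.functions.torchvision_keypoint_detection \
  -p model_name=str:keypointrcnn_resnet50_fpn
```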
32 changes: 32 additions & 0 deletions tests/python/cli/example_parameterized_function.py
@@ -0,0 +1,32 @@
# Copyright (C) 2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

from types import SimpleNamespace as namespace
from typing import List

import cvat_sdk.auto_annotation as cvataa
import cvat_sdk.models as models
import PIL.Image


def create(s: str, i: int, f: float, b: bool) -> cvataa.DetectionFunction:
    assert s == "string"
    assert i == 123
    assert f == 5.5
    assert b is False

    spec = cvataa.DetectionFunctionSpec(
        labels=[
            cvataa.label_spec("car", 0),
        ],
    )

    def detect(
        context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
    ) -> List[models.LabeledShapeRequest]:
        return [
            cvataa.rectangle(0, [1, 2, 3, 4]),
        ]

    return namespace(spec=spec, detect=detect)
17 changes: 17 additions & 0 deletions tests/python/cli/test_cli.py
@@ -328,3 +328,20 @@ def test_auto_annotate_with_file(self, fxt_new_task: Task):

    annotations = fxt_new_task.get_annotations()
    assert annotations.shapes

def test_auto_annotate_with_parameters(self, fxt_new_task: Task):
    annotations = fxt_new_task.get_annotations()
    assert not annotations.shapes

    self.run_cli(
        "auto-annotate",
        str(fxt_new_task.id),
        f"--function-module={__package__}.example_parameterized_function",
        "-ps=str:string",
        "-pi=int:123",
        "-pf=float:5.5",
        "-pb=bool:false",
    )

    annotations = fxt_new_task.get_annotations()
    assert annotations.shapes