-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
SDK: Add predefined functions based on torchvision (#6649)
These serve as a replacement for YOLOv8n that was removed in #6632. To support these functions, I also add an ability to define parameterized functions for use with the CLI.
- Loading branch information
Showing
10 changed files
with
460 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
41 changes: 41 additions & 0 deletions
41
cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_detection.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# Copyright (C) 2023 CVAT.ai Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
from functools import cached_property | ||
from typing import List | ||
|
||
import PIL.Image | ||
import torchvision.models | ||
|
||
import cvat_sdk.auto_annotation as cvataa | ||
import cvat_sdk.models as models | ||
|
||
|
||
class _TorchvisionDetectionFunction: | ||
def __init__(self, model_name: str, weights_name: str = "DEFAULT", **kwargs) -> None: | ||
weights_enum = torchvision.models.get_model_weights(model_name) | ||
self._weights = weights_enum[weights_name] | ||
self._transforms = self._weights.transforms() | ||
self._model = torchvision.models.get_model(model_name, weights=self._weights, **kwargs) | ||
self._model.eval() | ||
|
||
@cached_property | ||
def spec(self) -> cvataa.DetectionFunctionSpec: | ||
return cvataa.DetectionFunctionSpec( | ||
labels=[ | ||
cvataa.label_spec(cat, i) for i, cat in enumerate(self._weights.meta["categories"]) | ||
] | ||
) | ||
|
||
def detect(self, context, image: PIL.Image.Image) -> List[models.LabeledShapeRequest]: | ||
results = self._model([self._transforms(image)]) | ||
|
||
return [ | ||
cvataa.rectangle(label.item(), [x.item() for x in box]) | ||
for result in results | ||
for box, label in zip(result["boxes"], result["labels"]) | ||
] | ||
|
||
|
||
create = _TorchvisionDetectionFunction |
59 changes: 59 additions & 0 deletions
59
cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_keypoint_detection.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Copyright (C) 2023 CVAT.ai Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
from functools import cached_property | ||
from typing import List | ||
|
||
import PIL.Image | ||
import torchvision.models | ||
|
||
import cvat_sdk.auto_annotation as cvataa | ||
import cvat_sdk.models as models | ||
|
||
|
||
class _TorchvisionKeypointDetectionFunction: | ||
def __init__(self, model_name: str, weights_name: str = "DEFAULT", **kwargs) -> None: | ||
weights_enum = torchvision.models.get_model_weights(model_name) | ||
self._weights = weights_enum[weights_name] | ||
self._transforms = self._weights.transforms() | ||
self._model = torchvision.models.get_model(model_name, weights=self._weights, **kwargs) | ||
self._model.eval() | ||
|
||
@cached_property | ||
def spec(self) -> cvataa.DetectionFunctionSpec: | ||
return cvataa.DetectionFunctionSpec( | ||
labels=[ | ||
cvataa.skeleton_label_spec( | ||
cat, | ||
i, | ||
[ | ||
cvataa.keypoint_spec(name, j) | ||
for j, name in enumerate(self._weights.meta["keypoint_names"]) | ||
], | ||
) | ||
for i, cat in enumerate(self._weights.meta["categories"]) | ||
] | ||
) | ||
|
||
def detect(self, context, image: PIL.Image.Image) -> List[models.LabeledShapeRequest]: | ||
results = self._model([self._transforms(image)]) | ||
|
||
return [ | ||
cvataa.skeleton( | ||
label.item(), | ||
elements=[ | ||
cvataa.keypoint( | ||
keypoint_id, | ||
[keypoint[0].item(), keypoint[1].item()], | ||
occluded=not keypoint[2].item(), | ||
) | ||
for keypoint_id, keypoint in enumerate(keypoints) | ||
], | ||
) | ||
for result in results | ||
for keypoints, label in zip(result["keypoints"], result["labels"]) | ||
] | ||
|
||
|
||
create = _TorchvisionKeypointDetectionFunction |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# Copyright (C) 2023 CVAT.ai Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
from types import SimpleNamespace as namespace | ||
from typing import List | ||
|
||
import cvat_sdk.auto_annotation as cvataa | ||
import cvat_sdk.models as models | ||
import PIL.Image | ||
|
||
|
||
def create(s: str, i: int, f: float, b: bool) -> cvataa.DetectionFunction: | ||
assert s == "string" | ||
assert i == 123 | ||
assert f == 5.5 | ||
assert b is False | ||
|
||
spec = cvataa.DetectionFunctionSpec( | ||
labels=[ | ||
cvataa.label_spec("car", 0), | ||
], | ||
) | ||
|
||
def detect( | ||
context: cvataa.DetectionFunctionContext, image: PIL.Image.Image | ||
) -> List[models.LabeledShapeRequest]: | ||
return [ | ||
cvataa.rectangle(0, [1, 2, 3, 4]), | ||
] | ||
|
||
return namespace(spec=spec, detect=detect) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.