diff --git a/changelog.d/20241212_193004_roman_cli_agent.md b/changelog.d/20241212_193004_roman_cli_agent.md new file mode 100644 index 000000000000..f7fd8c0a5be4 --- /dev/null +++ b/changelog.d/20241212_193004_roman_cli_agent.md @@ -0,0 +1,4 @@ +### Added + +- \[CLI\] Added commands for working with native functions + () diff --git a/cvat-cli/README.md b/cvat-cli/README.md index bbd98c0980c9..fcee05dae1c4 100644 --- a/cvat-cli/README.md +++ b/cvat-cli/README.md @@ -22,6 +22,11 @@ The following subcommands are supported: - `backup` - back up a task - `auto-annotate` - automatically annotate a task using a local function +- Functions (Enterprise/Cloud only): + - `create-native` - create a function that can be powered by an agent + - `delete` - delete a function + - `run-agent` - process requests for a native function + ## Installation `pip install cvat-cli` diff --git a/cvat-cli/requirements/base.txt b/cvat-cli/requirements/base.txt index 94b064e0ace5..a53fd13b472e 100644 --- a/cvat-cli/requirements/base.txt +++ b/cvat-cli/requirements/base.txt @@ -1,3 +1,5 @@ -cvat-sdk~=2.24.1 +cvat-sdk==2.24.1 + +attrs>=24.2.0 Pillow>=10.3.0 setuptools>=70.0.0 # not directly required, pinned by Snyk to avoid a vulnerability diff --git a/cvat-cli/src/cvat_cli/__main__.py b/cvat-cli/src/cvat_cli/__main__.py index c93569182c08..7c649747cb31 100755 --- a/cvat-cli/src/cvat_cli/__main__.py +++ b/cvat-cli/src/cvat_cli/__main__.py @@ -11,7 +11,12 @@ from cvat_sdk import exceptions from ._internal.commands_all import COMMANDS -from ._internal.common import build_client, configure_common_arguments, configure_logger +from ._internal.common import ( + CriticalError, + build_client, + configure_common_arguments, + configure_logger, +) from ._internal.utils import popattr logger = logging.getLogger(__name__) @@ -29,7 +34,7 @@ def main(args: list[str] = None): try: with build_client(parsed_args, logger=logger) as client: popattr(parsed_args, "_executor")(client, **vars(parsed_args)) - except (exceptions.ApiException, urllib3.exceptions.HTTPError) as e: + except (exceptions.ApiException, urllib3.exceptions.HTTPError, CriticalError) as e: logger.critical(e) return 1 diff --git a/cvat-cli/src/cvat_cli/_internal/agent.py b/cvat-cli/src/cvat_cli/_internal/agent.py new file mode 100644 index 000000000000..820a758e54d2 --- /dev/null +++ b/cvat-cli/src/cvat_cli/_internal/agent.py @@ -0,0 +1,351 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import concurrent.futures +import json +import multiprocessing +import random +import secrets +import shutil +import tempfile +import time +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Optional + +import cvat_sdk.auto_annotation as cvataa +import cvat_sdk.datasets as cvatds +import urllib3.exceptions +from cvat_sdk import Client, models +from cvat_sdk.auto_annotation.driver import ( + _AnnotationMapper, + _DetectionFunctionContextImpl, + _LabelNameMapping, + _SpecNameMapping, +) +from cvat_sdk.exceptions import ApiException + +from .common import CriticalError, FunctionLoader + +FUNCTION_PROVIDER_NATIVE = "native" +FUNCTION_KIND_DETECTOR = "detector" + +_POLLING_INTERVAL_MEAN = timedelta(seconds=60) +_POLLING_INTERVAL_MAX_OFFSET = timedelta(seconds=10) + +_UPDATE_INTERVAL = timedelta(seconds=30) + + +class _RecoverableExecutor: + # A wrapper around ProcessPoolExecutor that recreates the underlying + # executor when a worker crashes. + def __init__(self, initializer, initargs): + self._mp_context = multiprocessing.get_context("spawn") + self._initializer = initializer + self._initargs = initargs + self._executor = self._new_executor() + + def _new_executor(self): + return concurrent.futures.ProcessPoolExecutor( + max_workers=1, + mp_context=self._mp_context, + initializer=self._initializer, + initargs=self._initargs, + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._executor.shutdown() + + def submit(self, func, /, *args, **kwargs): + return self._executor.submit(func, *args, **kwargs) + + def result(self, future: concurrent.futures.Future): + try: + return future.result() + except concurrent.futures.BrokenExecutor: + self._executor.shutdown() + self._executor = self._new_executor() + raise + + +_current_function: cvataa.DetectionFunction + + +def _worker_init(function_loader: FunctionLoader): + global _current_function + _current_function = function_loader.load() + + +def _worker_job_get_function_spec(): + return _current_function.spec + + +def _worker_job_detect(context, image): + return _current_function.detect(context, image) + + +class _Agent: + def __init__(self, client: Client, executor: _RecoverableExecutor, function_id: int): + self._rng = random.Random() # nosec + + self._client = client + self._executor = executor + self._function_id = function_id + self._function_spec = self._executor.result( + self._executor.submit(_worker_job_get_function_spec) + ) + + _, response = self._client.api_client.call_api( + "/api/functions/{function_id}", + "GET", + path_params={"function_id": self._function_id}, + ) + + remote_function = json.loads(response.data) + + self._validate_function_compatibility(remote_function) + + self._agent_id = secrets.token_hex(16) + self._client.logger.info("Agent starting with ID %r", self._agent_id) + + self._cached_task_id = None + + def _validate_function_compatibility(self, remote_function: dict) -> None: + function_id = remote_function["id"] + + if remote_function["provider"] != FUNCTION_PROVIDER_NATIVE: + raise CriticalError( + f"Function #{function_id} has provider {remote_function['provider']!r}. " + f"Agents can only be run for functions with provider {FUNCTION_PROVIDER_NATIVE!r}." + ) + + if isinstance(self._function_spec, cvataa.DetectionFunctionSpec): + self._validate_detection_function_compatibility(remote_function) + self._calculate_result_for_ar = self._calculate_result_for_detection_ar + else: + raise CriticalError( + f"Unsupported function spec type: {type(self._function_spec).__name__}" + ) + + def _validate_detection_function_compatibility(self, remote_function: dict) -> None: + incompatible_msg = ( + f"Function #{remote_function['id']} is incompatible with function object: " + ) + + if remote_function["kind"] != FUNCTION_KIND_DETECTOR: + raise CriticalError( + incompatible_msg + + f"kind is {remote_function['kind']!r} (expected {FUNCTION_KIND_DETECTOR!r})." + ) + + labels_by_name = {label.name: label for label in self._function_spec.labels} + + for remote_label in remote_function["labels_v2"]: + label = labels_by_name.get(remote_label["name"]) + + if not label: + raise CriticalError( + incompatible_msg + f"label {remote_label['name']!r} is not supported." + ) + + if ( + remote_label["type"] not in {"any", "unknown"} + and remote_label["type"] != label.type + ): + raise CriticalError( + incompatible_msg + + f"label {remote_label['name']!r} has type {remote_label['type']!r}, " + f"but the function object expects type {label.type!r}." + ) + + if remote_label["attributes"]: + raise CriticalError( + incompatible_msg + + f"label {remote_label['name']!r} has attributes, which is not supported." + ) + + def _wait_between_polls(self): + # offset the interval randomly to avoid synchronization between workers + max_offset_sec = _POLLING_INTERVAL_MAX_OFFSET.total_seconds() + offset_sec = self._rng.uniform(-max_offset_sec, max_offset_sec) + time.sleep(_POLLING_INTERVAL_MEAN.total_seconds() + offset_sec) + + def run(self, *, burst: bool) -> None: + if burst: + while ar_assignment := self._poll_for_ar(): + self._process_ar(ar_assignment) + self._client.logger.info("No annotation requests left in queue; exiting.") + else: + while True: + if ar_assignment := self._poll_for_ar(): + self._process_ar(ar_assignment) + else: + self._wait_between_polls() + + def _process_ar(self, ar_assignment: dict) -> None: + self._client.logger.info("Got annotation request assignment: %r", ar_assignment) + + ar_id = ar_assignment["ar_id"] + + try: + result = self._calculate_result_for_ar(ar_id, ar_assignment["ar_params"]) + + self._client.logger.info("Submitting result for AR %r...", ar_id) + self._client.api_client.call_api( + "/api/functions/queues/{queue_id}/requests/{request_id}/complete", + "POST", + path_params={"queue_id": f"function:{self._function_id}", "request_id": ar_id}, + body={"agent_id": self._agent_id, "annotations": result}, + ) + self._client.logger.info("AR %r completed", ar_id) + except Exception as ex: + self._client.logger.error("Failed to process AR %r", ar_id, exc_info=True) + + # Arbitrary exceptions may contain details of the client's system or code, which + # shouldn't be exposed to the server (and to users of the function). + # Therefore, we only produce a limited amount of detail, and only in known failure cases. + error_message = "Unknown error" + + if isinstance(ex, ApiException): + if ex.status: + error_message = f"Received HTTP status {ex.status}" + else: + error_message = "Failed an API call" + elif isinstance(ex, urllib3.exceptions.RequestError): + if isinstance(ex, urllib3.exceptions.MaxRetryError): + ex_type = type(ex.reason) + else: + ex_type = type(ex) + + error_message = f"Failed to make an HTTP request to {ex.url} ({ex_type.__name__})" + elif isinstance(ex, urllib3.exceptions.HTTPError): + error_message = "Failed to make an HTTP request" + elif isinstance(ex, cvataa.BadFunctionError): + error_message = "Underlying function returned incorrect result: " + str(ex) + elif isinstance(ex, concurrent.futures.BrokenExecutor): + error_message = "Worker process crashed" + + try: + self._client.api_client.call_api( + "/api/functions/queues/{queue_id}/requests/{request_id}/fail", + "POST", + path_params={ + "queue_id": f"function:{self._function_id}", + "request_id": ar_id, + }, + body={"agent_id": self._agent_id, "exc_info": error_message}, + ) + except Exception: + self._client.logger.error("Couldn't fail AR %r", ar_id, exc_info=True) + else: + self._client.logger.info("AR %r failed", ar_id) + + def _poll_for_ar(self) -> Optional[dict]: + while True: + self._client.logger.info("Trying to acquire an annotation request...") + try: + _, response = self._client.api_client.call_api( + "/api/functions/queues/{queue_id}/requests/acquire", + "POST", + path_params={"queue_id": f"function:{self._function_id}"}, + body={"agent_id": self._agent_id, "request_category": "batch"}, + ) + break + except (urllib3.exceptions.HTTPError, ApiException) as ex: + if isinstance(ex, ApiException) and ex.status and 400 <= ex.status < 500: + # We did something wrong; no point in retrying. + raise + + self._client.logger.error("Acquire request failed; will retry", exc_info=True) + self._wait_between_polls() + + response_data = json.loads(response.data) + return response_data["ar_assignment"] + + def _calculate_result_for_detection_ar( + self, ar_id: str, ar_params + ) -> models.PatchedLabeledDataRequest: + if ar_params["type"] != "annotate_task": + raise RuntimeError(f"Unsupported AR type: {ar_params['type']!r}") + + if ar_params["task"] != self._cached_task_id: + # To avoid uncontrolled disk usage, + # we'll only keep one task in the cache at a time. + self._client.logger.info("Switched to a new task; clearing the cache...") + if self._client.config.cache_dir.exists(): + shutil.rmtree(self._client.config.cache_dir) + + ds = cvatds.TaskDataset(self._client, ar_params["task"], load_annotations=False) + + self._cached_task_id = ar_params["task"] + + # Fetching the dataset might take a while, so do a progress update to let the server + # know we're still alive. + self._update_ar(ar_id, 0) + last_update_timestamp = datetime.now(tz=timezone.utc) + + mapping = ar_params["mapping"] + conv_mask_to_poly = ar_params["conv_mask_to_poly"] + + spec_nm = _SpecNameMapping( + labels={k: _LabelNameMapping(v["name"]) for k, v in mapping.items()} + ) + + mapper = _AnnotationMapper( + self._client.logger, + self._function_spec.labels, + ds.labels, + allow_unmatched_labels=False, + spec_nm=spec_nm, + conv_mask_to_poly=conv_mask_to_poly, + ) + + all_annotations = models.PatchedLabeledDataRequest(shapes=[]) + + for sample_index, sample in enumerate(ds.samples): + context = _DetectionFunctionContextImpl( + frame_name=sample.frame_name, + conf_threshold=ar_params["threshold"], + conv_mask_to_poly=conv_mask_to_poly, + ) + shapes = self._executor.result( + self._executor.submit(_worker_job_detect, context, sample.media.load_image()) + ) + + mapper.validate_and_remap(shapes, sample.frame_index) + all_annotations.shapes.extend(shapes) + + current_timestamp = datetime.now(tz=timezone.utc) + + if current_timestamp >= last_update_timestamp + _UPDATE_INTERVAL: + self._update_ar(ar_id, (sample_index + 1) / len(ds.samples)) + last_update_timestamp = current_timestamp + + return all_annotations + + def _update_ar(self, ar_id: str, progress: float) -> None: + self._client.logger.info("Updating AR %r progress to %.2f%%", ar_id, progress * 100) + self._client.api_client.call_api( + "/api/functions/queues/{queue_id}/requests/{request_id}/update", + "POST", + path_params={"queue_id": f"function:{self._function_id}", "request_id": ar_id}, + body={"agent_id": self._agent_id, "progress": progress}, + ) + + +def run_agent( + client: Client, function_loader: FunctionLoader, function_id: int, *, burst: bool +) -> None: + with ( + _RecoverableExecutor(initializer=_worker_init, initargs=[function_loader]) as executor, + tempfile.TemporaryDirectory() as cache_dir, + ): + client.config.cache_dir = Path(cache_dir, "cache") + client.logger.info("Will store cache at %s", client.config.cache_dir) + + agent = _Agent(client, executor, function_id) + agent.run(burst=burst) diff --git a/cvat-cli/src/cvat_cli/_internal/commands_all.py b/cvat-cli/src/cvat_cli/_internal/commands_all.py index 758d6b1d05e8..5f293f0ce06f 100644 --- a/cvat-cli/src/cvat_cli/_internal/commands_all.py +++ b/cvat-cli/src/cvat_cli/_internal/commands_all.py @@ -3,11 +3,13 @@ # SPDX-License-Identifier: MIT from .command_base import CommandGroup, DeprecatedAlias +from .commands_functions import COMMANDS as COMMANDS_FUNCTIONS from .commands_projects import COMMANDS as COMMANDS_PROJECTS from .commands_tasks import COMMANDS as COMMANDS_TASKS COMMANDS = CommandGroup(description="Perform operations on CVAT resources.") +COMMANDS.add_command("function", COMMANDS_FUNCTIONS) COMMANDS.add_command("project", COMMANDS_PROJECTS) COMMANDS.add_command("task", COMMANDS_TASKS) diff --git a/cvat-cli/src/cvat_cli/_internal/commands_functions.py b/cvat-cli/src/cvat_cli/_internal/commands_functions.py new file mode 100644 index 000000000000..76ccc56b05e9 --- /dev/null +++ b/cvat-cli/src/cvat_cli/_internal/commands_functions.py @@ -0,0 +1,138 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import argparse +import json +import textwrap +from collections.abc import Sequence + +import cvat_sdk.auto_annotation as cvataa +from cvat_sdk import Client + +from .agent import FUNCTION_KIND_DETECTOR, FUNCTION_PROVIDER_NATIVE, run_agent +from .command_base import CommandGroup +from .common import FunctionLoader, configure_function_implementation_arguments + +COMMANDS = CommandGroup(description="Perform operations on CVAT lambda functions.") + + +@COMMANDS.command_class("create-native") +class FunctionCreateNative: + description = textwrap.dedent( + """\ + Create a CVAT function that can be powered by an agent running the given local function. + """ + ) + + def configure_parser(self, parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "name", + help="a human-readable name for the function", + ) + + configure_function_implementation_arguments(parser) + + def execute( + self, + client: Client, + *, + name: str, + function_loader: FunctionLoader, + ) -> None: + function = function_loader.load() + + remote_function = { + "provider": FUNCTION_PROVIDER_NATIVE, + "name": name, + } + + if isinstance(function.spec, cvataa.DetectionFunctionSpec): + remote_function["kind"] = FUNCTION_KIND_DETECTOR + remote_function["labels_v2"] = [] + + for label_spec in function.spec.labels: + if getattr(label_spec, "sublabels", None): + raise cvataa.BadFunctionError( + f"Function label {label_spec.name!r} has sublabels. This is currently not supported." + ) + + remote_function["labels_v2"].append( + { + "name": label_spec.name, + } + ) + else: + raise cvataa.BadFunctionError( + f"Unsupported function spec type: {type(function.spec).__name__}" + ) + + _, response = client.api_client.call_api( + "/api/functions", + "POST", + body=remote_function, + ) + + remote_function = json.loads(response.data) + + client.logger.info( + "Created function #%d: %s", remote_function["id"], remote_function["name"] + ) + print(remote_function["id"]) + + +@COMMANDS.command_class("delete") +class FunctionDelete: + description = "Delete a list of functions, ignoring those which don't exist." + + def configure_parser(self, parser: argparse.ArgumentParser) -> None: + parser.add_argument("function_ids", type=int, help="IDs of functions to delete", nargs="+") + + def execute(self, client: Client, *, function_ids: Sequence[int]) -> None: + for function_id in function_ids: + _, response = client.api_client.call_api( + "/api/functions/{function_id}", + "DELETE", + path_params={"function_id": function_id}, + _check_status=False, + ) + + if 200 <= response.status <= 299: + client.logger.info(f"Function #{function_id} deleted") + elif response.status == 404: + client.logger.warning(f"Function #{function_id} not found") + else: + client.logger.error( + f"Failed to delete function #{function_id}: " + f"{response.msg} (status {response.status})" + ) + + +@COMMANDS.command_class("run-agent") +class FunctionRunAgent: + description = "Process requests for a given native function, indefinitely." + + def configure_parser(self, parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "function_id", + type=int, + help="ID of the function to process requests for", + ) + + configure_function_implementation_arguments(parser) + + parser.add_argument( + "--burst", + action="store_true", + help="process all pending requests and then exit", + ) + + def execute( + self, + client: Client, + *, + function_id: int, + function_loader: FunctionLoader, + burst: bool, + ) -> None: + run_agent(client, function_loader, function_id, burst=burst) diff --git a/cvat-cli/src/cvat_cli/_internal/commands_tasks.py b/cvat-cli/src/cvat_cli/_internal/commands_tasks.py index 8c6782887d97..cbe2139cf457 100644 --- a/cvat-cli/src/cvat_cli/_internal/commands_tasks.py +++ b/cvat-cli/src/cvat_cli/_internal/commands_tasks.py @@ -5,12 +5,9 @@ from __future__ import annotations import argparse -import importlib -import importlib.util import textwrap from collections.abc import Sequence -from pathlib import Path -from typing import Any, Optional +from typing import Optional import cvat_sdk.auto_annotation as cvataa from attr.converters import to_bool @@ -19,13 +16,8 @@ from cvat_sdk.core.proxies.tasks import ResourceType from .command_base import CommandGroup, GenericCommand, GenericDeleteCommand, GenericListCommand -from .parsers import ( - BuildDictAction, - parse_function_parameter, - parse_label_arg, - parse_resource_type, - parse_threshold, -) +from .common import FunctionLoader, configure_function_implementation_arguments +from .parsers import parse_label_arg, parse_resource_type, parse_threshold COMMANDS = CommandGroup(description="Perform operations on CVAT tasks.") @@ -416,30 +408,7 @@ class TaskAutoAnnotate: def configure_parser(self, parser: argparse.ArgumentParser) -> None: parser.add_argument("task_id", type=int, help="task ID") - function_group = parser.add_mutually_exclusive_group(required=True) - - function_group.add_argument( - "--function-module", - metavar="MODULE", - help="qualified name of a module to use as the function", - ) - - function_group.add_argument( - "--function-file", - metavar="PATH", - type=Path, - help="path to a Python source file to use as the function", - ) - - parser.add_argument( - "--function-parameter", - "-p", - metavar="NAME=TYPE:VALUE", - type=parse_function_parameter, - action=BuildDictAction, - dest="function_parameters", - help="parameter for the function", - ) + configure_function_implementation_arguments(parser) parser.add_argument( "--clear-existing", @@ -471,29 +440,13 @@ def execute( client: Client, *, task_id: int, - function_module: Optional[str] = None, - function_file: Optional[Path] = None, - function_parameters: dict[str, Any], + function_loader: FunctionLoader, clear_existing: bool = False, allow_unmatched_labels: bool = False, conf_threshold: Optional[float], conv_mask_to_poly: bool, ) -> None: - if function_module is not None: - function = importlib.import_module(function_module) - elif function_file is not None: - module_spec = importlib.util.spec_from_file_location("__cvat_function__", function_file) - function = importlib.util.module_from_spec(module_spec) - module_spec.loader.exec_module(function) - else: - assert False, "function identification arguments missing" - - if hasattr(function, "create"): - # this is actually a function factory - function = function.create(**function_parameters) - else: - if function_parameters: - raise TypeError("function takes no parameters") + function = function_loader.load() cvataa.annotate_task( client, diff --git a/cvat-cli/src/cvat_cli/_internal/common.py b/cvat-cli/src/cvat_cli/_internal/common.py index 6f37e3d74eaa..e07d85c9b65e 100644 --- a/cvat-cli/src/cvat_cli/_internal/common.py +++ b/cvat-cli/src/cvat_cli/_internal/common.py @@ -5,17 +5,28 @@ import argparse import getpass +import importlib +import importlib.util import logging import os import sys from http.client import HTTPConnection +from pathlib import Path +from typing import Any, Optional +import attrs +import cvat_sdk.auto_annotation as cvataa from cvat_sdk.core.client import Client, Config from ..version import VERSION +from .parsers import BuildDictAction, parse_function_parameter from .utils import popattr +class CriticalError(Exception): + pass + + def get_auth(s): """Parse USER[:PASS] strings and prompt for password if none was supplied.""" @@ -102,3 +113,77 @@ def build_client(parsed_args: argparse.Namespace, logger: logging.Logger) -> Cli client.organization_slug = popattr(parsed_args, "organization") return client + + +def configure_function_implementation_arguments(parser: argparse.ArgumentParser) -> None: + function_group = parser.add_mutually_exclusive_group(required=True) + + function_group.add_argument( + "--function-module", + metavar="MODULE", + help="qualified name of a module to use as the function", + ) + + function_group.add_argument( + "--function-file", + metavar="PATH", + type=Path, + help="path to a Python source file to use as the function", + ) + + parser.add_argument( + "--function-parameter", + "-p", + metavar="NAME=TYPE:VALUE", + type=parse_function_parameter, + action=BuildDictAction, + dest="function_parameters", + help="parameter for the function", + ) + + original_executor = parser.get_default("_executor") + + def execute_with_function_loader( + client, + *, + function_module: Optional[str], + function_file: Optional[Path], + function_parameters: dict[str, Any], + **kwargs, + ): + original_executor( + client, + function_loader=FunctionLoader(function_module, function_file, function_parameters), + **kwargs, + ) + + parser.set_defaults(_executor=execute_with_function_loader) + + +@attrs.frozen +class FunctionLoader: + function_module: Optional[str] + function_file: Optional[Path] + function_parameters: dict[str, Any] + + def __attrs_post_init__(self): + assert self.function_module is not None or self.function_file is not None + + def load(self) -> cvataa.DetectionFunction: + if self.function_module is not None: + function = importlib.import_module(self.function_module) + else: + module_spec = importlib.util.spec_from_file_location( + "__cvat_function__", self.function_file + ) + function = importlib.util.module_from_spec(module_spec) + module_spec.loader.exec_module(function) + + if hasattr(function, "create"): + # this is actually a function factory + function = function.create(**self.function_parameters) + else: + if self.function_parameters: + raise TypeError("function takes no parameters") + + return function diff --git a/cvat-sdk/cvat_sdk/auto_annotation/driver.py b/cvat-sdk/cvat_sdk/auto_annotation/driver.py index 5ffdb36f5bee..42e17f93b6b2 100644 --- a/cvat-sdk/cvat_sdk/auto_annotation/driver.py +++ b/cvat-sdk/cvat_sdk/auto_annotation/driver.py @@ -23,24 +23,62 @@ class BadFunctionError(Exception): """ +@attrs.frozen +class _SublabelNameMapping: + name: str + + +@attrs.frozen +class _LabelNameMapping(_SublabelNameMapping): + sublabels: Optional[Mapping[str, _SublabelNameMapping]] = attrs.field( + kw_only=True, default=None + ) + + def map_sublabel(self, name: str): + if self.sublabels is None: + return _SublabelNameMapping(name) + + return self.sublabels.get(name) + + +@attrs.frozen +class _SpecNameMapping: + labels: Optional[Mapping[str, _LabelNameMapping]] = attrs.field(kw_only=True, default=None) + + def map_label(self, name: str): + if self.labels is None: + return _LabelNameMapping(name) + + return self.labels.get(name) + + class _AnnotationMapper: @attrs.frozen - class _MappedLabel: + class _LabelIdMapping: id: int - sublabel_mapping: Mapping[int, Optional[int]] + sublabels: Mapping[int, Optional[int]] expected_num_elements: int = 0 - _label_mapping: Mapping[int, Optional[_MappedLabel]] + _label_id_mappings: Mapping[int, Optional[_LabelIdMapping]] - def _build_mapped_label( - self, fun_label: models.ILabel, ds_labels_by_name: Mapping[str, models.ILabel] - ) -> Optional[_MappedLabel]: + def _build_label_id_mapping( + self, + fun_label: models.ILabel, + ds_labels_by_name: Mapping[str, models.ILabel], + *, + allow_unmatched_labels: bool, + spec_nm: _SpecNameMapping, + ) -> Optional[_LabelIdMapping]: if getattr(fun_label, "attributes", None): raise BadFunctionError(f"label attributes are currently not supported") - ds_label = ds_labels_by_name.get(fun_label.name) + label_nm = spec_nm.map_label(fun_label.name) + if label_nm is None: + return None + + ds_label = ds_labels_by_name.get(label_nm.name) if ds_label is None: - if not self._allow_unmatched_labels: + if not allow_unmatched_labels: raise BadFunctionError(f"label {fun_label.name!r} is not in dataset") self._logger.info( @@ -71,9 +109,14 @@ def _build_mapped_label( f"sublabel {fun_sl.name!r} of label {fun_label.name!r} has same ID as another sublabel ({fun_sl.id})" ) - ds_sl = ds_sublabels_by_name.get(fun_sl.name) + sublabel_nm = label_nm.map_sublabel(fun_sl.name) + if sublabel_nm is None: + sl_map[fun_sl.id] = None + continue + + ds_sl = ds_sublabels_by_name.get(sublabel_nm.name) if not ds_sl: - if not self._allow_unmatched_labels: + if not allow_unmatched_labels: raise BadFunctionError( f"sublabel {fun_sl.name!r} of label {fun_label.name!r} is not in dataset" ) @@ -88,8 +131,8 @@ def _build_mapped_label( sl_map[fun_sl.id] = ds_sl.id - return self._MappedLabel( - ds_label.id, sublabel_mapping=sl_map, expected_num_elements=len(ds_label.sublabels) + return self._LabelIdMapping( + ds_label.id, sublabels=sl_map, expected_num_elements=len(ds_label.sublabels) ) def __init__( @@ -100,26 +143,29 @@ def __init__( *, allow_unmatched_labels: bool, conv_mask_to_poly: bool, + spec_nm: _SpecNameMapping = _SpecNameMapping(), ) -> None: self._logger = logger - self._allow_unmatched_labels = allow_unmatched_labels self._conv_mask_to_poly = conv_mask_to_poly ds_labels_by_name = {ds_label.name: ds_label for ds_label in ds_labels} - self._label_mapping = {} + self._label_id_mappings = {} for fun_label in fun_labels: if not hasattr(fun_label, "id"): raise BadFunctionError(f"label {fun_label.name!r} has no ID") - if fun_label.id in self._label_mapping: + if fun_label.id in self._label_id_mappings: raise BadFunctionError( f"label {fun_label.name} has same ID as another label ({fun_label.id})" ) - self._label_mapping[fun_label.id] = self._build_mapped_label( - fun_label, ds_labels_by_name + self._label_id_mappings[fun_label.id] = self._build_label_id_mapping( + fun_label, + ds_labels_by_name, + allow_unmatched_labels=allow_unmatched_labels, + spec_nm=spec_nm, ) def validate_and_remap(self, shapes: list[models.LabeledShapeRequest], ds_frame: int) -> None: @@ -141,16 +187,16 @@ def validate_and_remap(self, shapes: list[models.LabeledShapeRequest], ds_frame: shape.frame = ds_frame try: - mapped_label = self._label_mapping[shape.label_id] + label_id_mapping = self._label_id_mappings[shape.label_id] except KeyError: raise BadFunctionError( f"function output shape with unknown label ID ({shape.label_id})" ) - if not mapped_label: + if not label_id_mapping: continue - shape.label_id = mapped_label.id + shape.label_id = label_id_mapping.id if getattr(shape, "attributes", None): raise BadFunctionError( @@ -184,7 +230,7 @@ def validate_and_remap(self, shapes: list[models.LabeledShapeRequest], ds_frame: ) try: - mapped_sl_id = mapped_label.sublabel_mapping[element.label_id] + mapped_sl_id = label_id_mapping.sublabels[element.label_id] except KeyError: raise BadFunctionError( f"function output shape with unknown sublabel ID ({element.label_id})" @@ -204,14 +250,14 @@ def validate_and_remap(self, shapes: list[models.LabeledShapeRequest], ds_frame: new_elements.append(element) - if len(new_elements) != mapped_label.expected_num_elements: + if len(new_elements) != label_id_mapping.expected_num_elements: # new_elements could only be shorter than expected, # because the reverse would imply that there are more distinct sublabel IDs # than are actually defined in the dataset. - assert len(new_elements) < mapped_label.expected_num_elements + assert len(new_elements) < label_id_mapping.expected_num_elements raise BadFunctionError( - f"function output skeleton with fewer elements than expected ({len(new_elements)} vs {mapped_label.expected_num_elements})" + f"function output skeleton with fewer elements than expected ({len(new_elements)} vs {label_id_mapping.expected_num_elements})" ) shape.elements[:] = new_elements diff --git a/dev/update_version.py b/dev/update_version.py index bc175aa16dd0..fbe5da9971c0 100755 --- a/dev/update_version.py +++ b/dev/update_version.py @@ -160,8 +160,8 @@ def apply(self, new_version: Version, *, verify_only: bool) -> bool: ), ReplacementRule( "cvat-cli/requirements/base.txt", - re.compile(r"^cvat-sdk~=[\d.]+$", re.M), - lambda v, m: f"cvat-sdk~={v.major}.{v.minor}.{v.patch}", + re.compile(r"^cvat-sdk==[\d.]+$", re.M), + lambda v, m: f"cvat-sdk=={v.major}.{v.minor}.{v.patch}", ), ] diff --git a/site/content/en/docs/api_sdk/cli/_index.md b/site/content/en/docs/api_sdk/cli/_index.md index ffa5be80676b..82bfad795fb0 100644 --- a/site/content/en/docs/api_sdk/cli/_index.md +++ b/site/content/en/docs/api_sdk/cli/_index.md @@ -29,6 +29,11 @@ The following subcommands are supported: - `backup` - back up a task - `auto-annotate` - automatically annotate a task using a local function +- Functions (Enterprise/Cloud only): + - `create-native` - create a function that can be powered by an agent + - `delete` - delete a function + - `run-agent` - process requests for a native function + ## Installation To install an [official release of CVAT CLI](https://pypi.org/project/cvat-cli/), use this command: @@ -316,3 +321,35 @@ see that command's examples for more information. ```bash cvat-cli project ls --json > list_of_projects.json ``` + +## Examples - functions + +**Note**: The functionality described in this section can only be used +with the CVAT Enterprise or CVAT Cloud. + +### Create + +- Create a function that uses a detection model from torchvision + and run an agent for it: + + ``` + cvat-cli function create-native "Faster R-CNN" \ + --function-module cvat_sdk.auto_annotation.functions.torchvision_detection \ + -p model_name=str:fasterrcnn_resnet50_fpn_v2 + cvat-cli function run-agent \ + --function-module cvat_sdk.auto_annotation.functions.torchvision_detection \ + -p model_name=str:fasterrcnn_resnet50_fpn_v2 + ``` + +These commands accept functions that implement the +{{< ilink "/docs/api_sdk/sdk/auto-annotation" "auto-annotation function interface" >}} +from the SDK, same as the `task auto-annotate` command. +See that command's examples for information on how to implement these functions +and specify them in the command line. + +### Delete + +- Delete functions with IDs 100 and 101: + ``` + cvat-cli function delete 100 101 + ```