From 05f1d99153119753cb713fd2e2e3fb0eac64ce4f Mon Sep 17 00:00:00 2001 From: Roman Donchenko Date: Tue, 17 Oct 2023 19:22:09 +0300 Subject: [PATCH] Don't fetch existing annotations in `cvat_sdk.auto_annotation.annotate_task` We don't need existing annotations in order to re-annotate a task, but they were being fetched anyway, because that's how the underlying `TaskDataset` class works. Add an option to `TaskDataset` to disable annotation loading, and use it in `auto_annotate` to prevent those unnecessary fetches. --- cvat-sdk/cvat_sdk/auto_annotation/driver.py | 2 +- cvat-sdk/cvat_sdk/datasets/common.py | 10 +++-- cvat-sdk/cvat_sdk/datasets/task_dataset.py | 43 ++++++++++++++------- tests/python/sdk/test_datasets.py | 14 +++++++ 4 files changed, 50 insertions(+), 19 deletions(-) diff --git a/cvat-sdk/cvat_sdk/auto_annotation/driver.py b/cvat-sdk/cvat_sdk/auto_annotation/driver.py index 8c1c71b46e8b..d6294f44f8f6 100644 --- a/cvat-sdk/cvat_sdk/auto_annotation/driver.py +++ b/cvat-sdk/cvat_sdk/auto_annotation/driver.py @@ -268,7 +268,7 @@ def annotate_task( if pbar is None: pbar = NullProgressReporter() - dataset = TaskDataset(client, task_id) + dataset = TaskDataset(client, task_id, load_annotations=False) assert isinstance(function.spec, DetectionFunctionSpec) diff --git a/cvat-sdk/cvat_sdk/datasets/common.py b/cvat-sdk/cvat_sdk/datasets/common.py index c621a2d2ed33..b407c490802c 100644 --- a/cvat-sdk/cvat_sdk/datasets/common.py +++ b/cvat-sdk/cvat_sdk/datasets/common.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: MIT import abc -from typing import List +from typing import List, Optional import attrs import attrs.validators @@ -53,8 +53,12 @@ class Sample: frame_name: str """File name of the frame in its task.""" - annotations: FrameAnnotations - """Annotations belonging to the frame.""" + annotations: Optional[FrameAnnotations] + """ + Annotations belonging to the frame. + + Will be None if the dataset was created without loading annotations. + """ media: MediaElement """Media data of the frame.""" diff --git a/cvat-sdk/cvat_sdk/datasets/task_dataset.py b/cvat-sdk/cvat_sdk/datasets/task_dataset.py index 111528d43715..aa4f05f574e0 100644 --- a/cvat-sdk/cvat_sdk/datasets/task_dataset.py +++ b/cvat-sdk/cvat_sdk/datasets/task_dataset.py @@ -6,14 +6,14 @@ import zipfile from concurrent.futures import ThreadPoolExecutor -from typing import Sequence +from typing import Sequence, Iterable import PIL.Image import cvat_sdk.core import cvat_sdk.core.exceptions import cvat_sdk.models as models -from cvat_sdk.datasets.caching import UpdatePolicy, make_cache_manager +from cvat_sdk.datasets.caching import UpdatePolicy, make_cache_manager, CacheManager from cvat_sdk.datasets.common import FrameAnnotations, MediaElement, Sample, UnsupportedDatasetError _NUM_DOWNLOAD_THREADS = 4 @@ -49,12 +49,17 @@ def __init__( task_id: int, *, update_policy: UpdatePolicy = UpdatePolicy.IF_MISSING_OR_STALE, + load_annotations: bool = True, ) -> None: """ Creates a dataset corresponding to the task with ID `task_id` on the server that `client` is connected to. `update_policy` determines when and if the local cache will be updated. + + `load_annotations` determines whether annotations will be loaded from + the server. If set to False, the `annotations` field in the samples will + be set to None. """ self._logger = client.logger @@ -102,6 +107,26 @@ def ensure_chunk(chunk_index): self._logger.info("All chunks downloaded") + if load_annotations: + self._load_annotations(cache_manager, sorted(active_frame_indexes)) + else: + self._frame_annotations = { + frame_index: None for frame_index in sorted(active_frame_indexes) + } + + # TODO: tracks? + + self._samples = [ + Sample( + frame_index=k, + frame_name=data_meta.frames[k].name, + annotations=v, + media=self._TaskMediaElement(self, k), + ) + for k, v in self._frame_annotations.items() + ] + + def _load_annotations(self, cache_manager: CacheManager, frame_indexes: Iterable[int]) -> None: annotations = cache_manager.ensure_task_model( self._task.id, "annotations.json", @@ -111,7 +136,7 @@ def ensure_chunk(chunk_index): ) self._frame_annotations = { - frame_index: FrameAnnotations() for frame_index in sorted(active_frame_indexes) + frame_index: FrameAnnotations() for frame_index in frame_indexes } for tag in annotations.tags: @@ -123,18 +148,6 @@ def ensure_chunk(chunk_index): if shape.frame in self._frame_annotations: self._frame_annotations[shape.frame].shapes.append(shape) - # TODO: tracks? - - self._samples = [ - Sample( - frame_index=k, - frame_name=data_meta.frames[k].name, - annotations=v, - media=self._TaskMediaElement(self, k), - ) - for k, v in self._frame_annotations.items() - ] - @property def labels(self) -> Sequence[models.ILabel]: """ diff --git a/tests/python/sdk/test_datasets.py b/tests/python/sdk/test_datasets.py index 35b2339ec67e..d5fbc0957eb7 100644 --- a/tests/python/sdk/test_datasets.py +++ b/tests/python/sdk/test_datasets.py @@ -206,3 +206,17 @@ def test_update(self, monkeypatch: pytest.MonkeyPatch): ) assert dataset.samples[6].annotations.shapes[0].label_id == self.expected_labels[0].id + + def test_no_annotations(self): + dataset = cvatds.TaskDataset(self.client, self.task.id, load_annotations=False) + + for index, sample in enumerate(dataset.samples): + assert sample.frame_index == index + assert sample.frame_name == self.images[index].name + + actual_image = sample.media.load_image() + expected_image = PIL.Image.open(self.images[index]) + + assert actual_image == expected_image + + assert sample.annotations is None