Skip to content

Commit

Permalink
Don't fetch existing annotations in `cvat_sdk.auto_annotation.annotat…
Browse files Browse the repository at this point in the history
…e_task`

We don't need existing annotations in order to re-annotate a task, but they
were being fetched anyway, because that's how the underlying `TaskDataset`
class works.

Add an option to `TaskDataset` to disable annotation loading, and use it in
`auto_annotate` to prevent those unnecessary fetches.
  • Loading branch information
SpecLad committed Oct 17, 2023
1 parent e8db2c3 commit 05f1d99
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 19 deletions.
2 changes: 1 addition & 1 deletion cvat-sdk/cvat_sdk/auto_annotation/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def annotate_task(
if pbar is None:
pbar = NullProgressReporter()

dataset = TaskDataset(client, task_id)
dataset = TaskDataset(client, task_id, load_annotations=False)

assert isinstance(function.spec, DetectionFunctionSpec)

Expand Down
10 changes: 7 additions & 3 deletions cvat-sdk/cvat_sdk/datasets/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: MIT

import abc
from typing import List
from typing import List, Optional

import attrs
import attrs.validators
Expand Down Expand Up @@ -53,8 +53,12 @@ class Sample:
frame_name: str
"""File name of the frame in its task."""

annotations: FrameAnnotations
"""Annotations belonging to the frame."""
annotations: Optional[FrameAnnotations]
"""
Annotations belonging to the frame.
Will be None if the dataset was created without loading annotations.
"""

media: MediaElement
"""Media data of the frame."""
43 changes: 28 additions & 15 deletions cvat-sdk/cvat_sdk/datasets/task_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

import zipfile
from concurrent.futures import ThreadPoolExecutor
from typing import Sequence
from typing import Sequence, Iterable

import PIL.Image

import cvat_sdk.core
import cvat_sdk.core.exceptions
import cvat_sdk.models as models
from cvat_sdk.datasets.caching import UpdatePolicy, make_cache_manager
from cvat_sdk.datasets.caching import UpdatePolicy, make_cache_manager, CacheManager
from cvat_sdk.datasets.common import FrameAnnotations, MediaElement, Sample, UnsupportedDatasetError

_NUM_DOWNLOAD_THREADS = 4
Expand Down Expand Up @@ -49,12 +49,17 @@ def __init__(
task_id: int,
*,
update_policy: UpdatePolicy = UpdatePolicy.IF_MISSING_OR_STALE,
load_annotations: bool = True,
) -> None:
"""
Creates a dataset corresponding to the task with ID `task_id` on the
server that `client` is connected to.
`update_policy` determines when and if the local cache will be updated.
`load_annotations` determines whether annotations will be loaded from
the server. If set to False, the `annotations` field in the samples will
be set to None.
"""

self._logger = client.logger
Expand Down Expand Up @@ -102,6 +107,26 @@ def ensure_chunk(chunk_index):

self._logger.info("All chunks downloaded")

if load_annotations:
self._load_annotations(cache_manager, sorted(active_frame_indexes))
else:
self._frame_annotations = {
frame_index: None for frame_index in sorted(active_frame_indexes)
}

# TODO: tracks?

self._samples = [
Sample(
frame_index=k,
frame_name=data_meta.frames[k].name,
annotations=v,
media=self._TaskMediaElement(self, k),
)
for k, v in self._frame_annotations.items()
]

def _load_annotations(self, cache_manager: CacheManager, frame_indexes: Iterable[int]) -> None:
annotations = cache_manager.ensure_task_model(
self._task.id,
"annotations.json",
Expand All @@ -111,7 +136,7 @@ def ensure_chunk(chunk_index):
)

self._frame_annotations = {
frame_index: FrameAnnotations() for frame_index in sorted(active_frame_indexes)
frame_index: FrameAnnotations() for frame_index in frame_indexes
}

for tag in annotations.tags:
Expand All @@ -123,18 +148,6 @@ def ensure_chunk(chunk_index):
if shape.frame in self._frame_annotations:
self._frame_annotations[shape.frame].shapes.append(shape)

# TODO: tracks?

self._samples = [
Sample(
frame_index=k,
frame_name=data_meta.frames[k].name,
annotations=v,
media=self._TaskMediaElement(self, k),
)
for k, v in self._frame_annotations.items()
]

@property
def labels(self) -> Sequence[models.ILabel]:
"""
Expand Down
14 changes: 14 additions & 0 deletions tests/python/sdk/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,17 @@ def test_update(self, monkeypatch: pytest.MonkeyPatch):
)

assert dataset.samples[6].annotations.shapes[0].label_id == self.expected_labels[0].id

def test_no_annotations(self):
dataset = cvatds.TaskDataset(self.client, self.task.id, load_annotations=False)

for index, sample in enumerate(dataset.samples):
assert sample.frame_index == index
assert sample.frame_name == self.images[index].name

actual_image = sample.media.load_image()
expected_image = PIL.Image.open(self.images[index])

assert actual_image == expected_image

assert sample.annotations is None

0 comments on commit 05f1d99

Please sign in to comment.