Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix exporting projects with honeypots #8597

Merged
merged 10 commits into from
Oct 31, 2024
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
### Fixed

- Exporting projects with tasks containing honeypots. Honeypots are no longer exported.
(<https://github.com/cvat-ai/cvat/pull/8597>)
75 changes: 51 additions & 24 deletions cvat/apps/dataset_manager/bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ def __init__(self,
self._db_data: models.Data = db_task.data
self._use_server_track_ids = use_server_track_ids
self._required_frames = included_frames
self._initialized_included_frames: Optional[Set[int]] = None
self._db_subset = db_task.subset

super().__init__(db_task)
Expand Down Expand Up @@ -536,12 +537,14 @@ def shapes(self):
yield self._export_labeled_shape(shape)

def get_included_frames(self):
    """Return the set of relative frame indices that should be exported.

    A frame is included when it is not deleted, not excluded (e.g. a
    honeypot/placeholder frame), and is required by the requested frame
    selection. The result is computed once and memoized in
    ``self._initialized_included_frames``, since the set is queried
    repeatedly during export and its inputs do not change afterwards.
    """
    if self._initialized_included_frames is None:
        self._initialized_included_frames = {
            i for i in self.rel_range
            if not self._is_frame_deleted(i)
            and not self._is_frame_excluded(i)
            and self._is_frame_required(i)
        }
    return self._initialized_included_frames

def _is_frame_deleted(self, frame):
return frame in self._deleted_frames
Expand Down Expand Up @@ -1112,7 +1115,10 @@ def _init_frame_info(self):
} for frame in range(task.data.size)})
else:
self._frame_info.update({(task.id, self.rel_frame_id(task.id, db_image.frame)): {
"path": mangle_image_name(db_image.path, defaulted_subset, original_names),
# do not modify honeypot names since they will be excluded from the dataset
# and their quantity should not affect the validation frame name
"path": mangle_image_name(db_image.path, defaulted_subset, original_names) \
if not db_image.is_placeholder else db_image.path,
"id": db_image.id,
"width": db_image.width,
"height": db_image.height,
Expand Down Expand Up @@ -1271,25 +1277,36 @@ def get_frame(task_id: int, idx: int) -> ProjectData.Frame:
return frames[(frame_info["subset"], abs_frame)]

if include_empty:
for ident in sorted(self._frame_info):
if ident not in self._deleted_frames:
get_frame(*ident)
for task_id, frame in sorted(self._frame_info):
if not self._tasks_data.get(task_id):
self.init_task_data(task_id)

task_included_frames = self._tasks_data[task_id].get_included_frames()
if (task_id, frame) not in self._deleted_frames and frame in task_included_frames:
Marishka17 marked this conversation as resolved.
Show resolved Hide resolved
get_frame(task_id, frame)

for task_data in self.task_data:
task: Task = task_data.db_instance

for task in self._db_tasks.values():
anno_manager = AnnotationManager(
self._annotation_irs[task.id], dimension=self._annotation_irs[task.id].dimension
)
task_included_frames = task_data.get_included_frames()

for shape in sorted(
anno_manager.to_shapes(
task.data.size,
included_frames=task_included_frames,
include_outside=False,
use_server_track_ids=self._use_server_track_ids
),
key=lambda shape: shape.get("z_order", 0)
):
if (task.id, shape['frame']) not in self._frame_info or (task.id, shape['frame']) in self._deleted_frames:
if shape['frame'] in task_data.deleted_frames:
continue

assert (task.id, shape['frame']) in self._frame_info

if 'track_id' in shape:
if shape['outside']:
continue
Expand Down Expand Up @@ -1368,23 +1385,33 @@ def soft_attribute_import(self, value: bool):
for task_data in self._tasks_data.values():
task_data.soft_attribute_import = value


def init_task_data(self, task_id: int) -> TaskData:
    """Create, cache, and return a TaskData wrapper for one project task.

    The created instance is stored in ``self._tasks_data`` so later lookups
    (e.g. via the ``task_data`` property) reuse it instead of rebuilding.

    Raises:
        Exception: if ``task_id`` does not belong to this project.
    """
    try:
        task = self._db_tasks[task_id]
    except KeyError as ex:
        raise Exception("There is no such task in the project") from ex

    task_data = TaskData(
        annotation_ir=self._annotation_irs[task_id],
        db_task=task,
        host=self._host,
        # project-level export may run without per-task annotation writers,
        # in which case there is nothing to call back into
        create_callback=self._task_annotations[task_id].create \
            if self._task_annotations is not None else None,
    )
    # Split the overall annotation budget evenly between the project's tasks.
    # NOTE: len(self._db_tasks) >= 1 here — the lookup above succeeded, so
    # the dict is non-empty and this division cannot raise ZeroDivisionError.
    task_data._MAX_ANNO_SIZE //= len(self._db_tasks)
    task_data.soft_attribute_import = self.soft_attribute_import
    self._tasks_data[task_id] = task_data

    return task_data

Comment on lines +1388 to +1407
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Raise a more specific exception in init_task_data

Instead of raising a generic Exception when a task is not found, consider raising a more specific exception like KeyError or creating a custom exception. This improves error handling and clarity.

Handle potential division by zero in _MAX_ANNO_SIZE calculation

In line 1402, ensure that len(self._db_tasks) is not zero before performing integer division to prevent a possible ZeroDivisionError.

Apply this diff to handle the potential error:

+ if len(self._db_tasks) == 0:
+     raise ValueError("No tasks available to initialize task data.")
  task_data._MAX_ANNO_SIZE //= len(self._db_tasks)

Committable suggestion was skipped due to low confidence.

@property
def task_data(self):
    """Yield a TaskData for every task of the project, in dict order.

    Previously-initialized instances cached in ``self._tasks_data`` are
    reused; missing ones are built (and cached) via ``init_task_data``.
    """
    # Iterate the dict directly — only the keys are needed here.
    for task_id in self._db_tasks:
        if task_id in self._tasks_data:
            yield self._tasks_data[task_id]
        else:
            yield self.init_task_data(task_id)

@staticmethod
def _get_filename(path):
Expand Down
4 changes: 2 additions & 2 deletions cvat/apps/dataset_manager/formats/cvat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1384,8 +1384,8 @@ def dump_media_files(instance_data: CommonData, img_dir: str, project_data: Proj
out_type=FrameOutputType.BUFFER,
)
for frame_id, frame in zip(instance_data.rel_range, frames):
if (project_data is not None and (instance_data.db_instance.id, frame_id) in project_data.deleted_frames) \
or frame_id in instance_data.deleted_frames:
# exclude deleted frames and honeypots
if frame_id in instance_data.deleted_frames or frame_id in instance_data._excluded_frames:
continue
frame_name = instance_data.frame_info[frame_id]['path'] if project_data is None \
else project_data.frame_info[(instance_data.db_instance.id, frame_id)]['path']
Expand Down
63 changes: 62 additions & 1 deletion tests/python/rest_api/test_projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,14 @@
patch_method,
post_method,
)
from shared.utils.helpers import generate_image_files

from .utils import CollectionSimpleFilterTestBase, export_project_backup, export_project_dataset
from .utils import (
CollectionSimpleFilterTestBase,
create_task,
export_project_backup,
export_project_dataset,
)


@pytest.mark.usefixtures("restore_db_per_class")
Expand Down Expand Up @@ -1038,6 +1044,61 @@ def test_creates_subfolders_for_subsets_on_export(
len([f for f in zip_file.namelist() if f.startswith(folder_prefix)]) > 0
), f"No {folder_prefix} in {zip_file.namelist()}"

def test_export_project_with_honeypots(
    self,
    admin_user: str,
):
    # Regression test for exporting projects whose tasks contain honeypots:
    # the honeypot (placeholder) frames must appear in neither the exported
    # images nor the exported annotations.
    project_spec = {
        "name": "Project with honeypots",
        "labels": [{"name": "cat"}],
    }

    with make_api_client(admin_user) as api_client:
        project, _ = api_client.projects_api.create(project_spec)

    image_files = generate_image_files(3)
    image_names = [i.name for i in image_files]

    task_params = {
        "name": "Task with honeypots",
        # one frame per job, so every annotation job receives a honeypot
        # copy of the validation frame
        "segment_size": 1,
        "project_id": project.id,
    }

    data_params = {
        "image_quality": 70,
        "client_files": image_files,
        "sorting_method": "random",
        "validation_params": {
            # "gt_pool" mode duplicates the selected validation frames into
            # regular jobs as honeypots
            "mode": "gt_pool",
            "frame_selection_method": "manual",
            "frames_per_job_count": 1,
            "frames": [image_files[-1].name],
        },
    }

    create_task(admin_user, spec=task_params, data=data_params)

    dataset = export_project_dataset(
        admin_user, api_version=2, save_images=True, id=project.id, format="COCO 1.0"
    )

    with zipfile.ZipFile(io.BytesIO(dataset)) as zip_file:
        subset_path = "images/default"
        # only the original uploaded images are exported — no honeypot copies
        assert (
            sorted(
                [
                    f[len(subset_path) + 1 :]
                    for f in zip_file.namelist()
                    if f.startswith(subset_path)
                ]
            )
            == image_names
        )
        with zip_file.open("annotations/instances_default.json") as anno_file:
            annotations = json.load(anno_file)
            # annotation image entries must likewise cover exactly the
            # original images
            assert sorted([a["file_name"] for a in annotations["images"]]) == image_names


@pytest.mark.usefixtures("restore_db_per_function")
class TestPatchProjectLabel:
Expand Down
1 change: 1 addition & 0 deletions tests/python/rest_api/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

@pytest.mark.usefixtures("restore_db_per_class")
@pytest.mark.usefixtures("restore_redis_inmem_per_function")
@pytest.mark.usefixtures("restore_redis_ondisk_per_function")
@pytest.mark.timeout(30)
class TestRequestsListFilters(CollectionSimpleFilterTestBase):

Expand Down
Loading