Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing skeleton track project export #8423

Merged
merged 6 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
### Fixed

- Fixing a problem when project export does not export skeleton tracks
(<https://github.com/cvat-ai/cvat/pull/8423>)
5 changes: 4 additions & 1 deletion cvat/apps/dataset_manager/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,8 +476,11 @@ def to_shapes(self, end_frame: int, *,

if track.get("elements"):
track_elements = TrackManager(track["elements"], self._dimension)
element_included_frames = set(track_shapes.keys())
if included_frames is not None:
element_included_frames = element_included_frames.intersection(included_frames)
element_shapes = track_elements.to_shapes(end_frame,
included_frames=set(track_shapes.keys()).intersection(included_frames or []),
included_frames=element_included_frames,
include_outside=True, # elements are controlled by the parent shape
use_server_track_ids=use_server_track_ids
)
Expand Down
2 changes: 1 addition & 1 deletion cvat/apps/dataset_manager/bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -1225,7 +1225,7 @@ def _export_track(self, track: dict, task_id: int, task_size: int, idx: int):
for i, element in enumerate(track.get("elements", []))]
)

def group_by_frame(self, include_empty=False):
def group_by_frame(self, include_empty: bool = False):
frames: Dict[Tuple[str, int], ProjectData.Frame] = {}
def get_frame(task_id: int, idx: int) -> ProjectData.Frame:
frame_info = self._frame_info[(task_id, idx)]
Expand Down
98 changes: 98 additions & 0 deletions cvat/apps/dataset_manager/tests/assets/annotations.json
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,104 @@
}
]
},
"skeleton track": {
"version": 0,
"tags": [],
"shapes": [],
"tracks": [
{
"frame": 10,
"group": 0,
"source": "file",
"shapes": [
{
"type": "skeleton",
"occluded": false,
"outside": false,
"z_order": 0,
"rotation": 0,
"points": [],
"frame": 10,
"attributes": []
}
],
"attributes": [],
"elements": [
{
"frame": 10,
"group": 0,
"source": "file",
"shapes": [
{
"type": "points",
"occluded": false,
"outside": false,
"z_order": 0,
"rotation": 0,
"points": [
613.99,
326.54
],
"frame": 10,
"attributes": []
},
{
"type": "points",
"occluded": false,
"outside": true,
"z_order": 0,
"rotation": 0,
"points": [
613.99,
326.54
],
"frame": 12,
"attributes": []
}
],
"attributes": [],
"label_id": null
},
{
"frame": 10,
"group": 0,
"source": "file",
"shapes": [
{
"type": "points",
"occluded": false,
"outside": false,
"z_order": 0,
"rotation": 0,
"points": [
613.99,
326.54
],
"frame": 10,
"attributes": []
},
{
"type": "points",
"occluded": false,
"outside": true,
"z_order": 0,
"rotation": 0,
"points": [
613.99,
326.54
],
"frame": 12,
"attributes": []
}
],
"attributes": [],
"label_id": null
}
],
"label_id": null
}
]
},
"ICDAR Localization 1.0": {
"version": 0,
"tags": [],
Expand Down
150 changes: 94 additions & 56 deletions cvat/apps/dataset_manager/tests/test_rest_api_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,13 @@ def _get_jobs(self, task_id):
)
return values

def _get_tasks(self, project_id):
with ForceLogin(self.admin, self.client):
values = get_paginated_collection(lambda page:
self.client.get("/api/tasks", data={"project_id": project_id, "page": page})
)
return values

def _get_request(self, path, user):
with ForceLogin(user, self.client):
response = self.client.get(path)
Expand Down Expand Up @@ -345,6 +352,13 @@ def _delete_project(self, project_id, user):
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
return response

@staticmethod
def _save_file_from_response(response, file_name):
if response.status_code == status.HTTP_200_OK:
content = b"".join(response.streaming_content)
with open(file_name, "wb") as f:
f.write(content)


class TaskDumpUploadTest(_DbTestBase):
def test_api_v2_dump_and_upload_annotations_with_objects_type_is_shape(self):
Expand Down Expand Up @@ -415,10 +429,7 @@ def test_api_v2_dump_and_upload_annotations_with_objects_type_is_shape(self):
}
response = self._get_request_with_data(url, data, user)
self.assertEqual(response.status_code, edata['code'])
if response.status_code == status.HTTP_200_OK:
content = BytesIO(b"".join(response.streaming_content))
with open(file_zip_name, "wb") as f:
f.write(content.getvalue())
self._save_file_from_response(response, file_zip_name)
self.assertEqual(osp.exists(file_zip_name), edata['file_exists'])

# Upload annotations with objects type is shape
Expand Down Expand Up @@ -526,10 +537,7 @@ def test_api_v2_dump_annotations_with_objects_type_is_track(self):
}
response = self._get_request_with_data(url, data, user)
self.assertEqual(response.status_code, edata['code'])
if response.status_code == status.HTTP_200_OK:
content = BytesIO(b"".join(response.streaming_content))
with open(file_zip_name, "wb") as f:
f.write(content.getvalue())
self._save_file_from_response(response, file_zip_name)
self.assertEqual(osp.exists(file_zip_name), edata['file_exists'])
# Upload annotations with objects type is track
for upload_format in upload_formats:
Expand Down Expand Up @@ -616,10 +624,7 @@ def test_api_v2_dump_tag_annotations(self):
}
response = self._get_request_with_data(url, data, user)
self.assertEqual(response.status_code, edata['code'])
if response.status_code == status.HTTP_200_OK:
content = BytesIO(b"".join(response.streaming_content))
with open(file_zip_name, "wb") as f:
f.write(content.getvalue())
self._save_file_from_response(response, file_zip_name)
self.assertEqual(osp.exists(file_zip_name), edata['file_exists'])

def test_api_v2_dump_and_upload_annotations_with_objects_are_different_images(self):
Expand Down Expand Up @@ -859,10 +864,7 @@ def test_api_v2_export_dataset(self):
}
response = self._get_request_with_data(url, data, user)
self.assertEqual(response.status_code, edata["code"])
if response.status_code == status.HTTP_200_OK:
content = BytesIO(b"".join(response.streaming_content))
with open(file_zip_name, "wb") as f:
f.write(content.getvalue())
self._save_file_from_response(response, file_zip_name)
self.assertEqual(response.status_code, edata['code'])
self.assertEqual(osp.exists(file_zip_name), edata['file_exists'])

Expand Down Expand Up @@ -1685,9 +1687,7 @@ def patched_osp_exists(path: str):
response = self._get_request_with_data(download_url, download_params, self.admin)
self.assertEqual(response.status_code, status.HTTP_200_OK)

content = BytesIO(b"".join(response.streaming_content))
with open(osp.join(temp_dir, "export.zip"), "wb") as f:
f.write(content.getvalue())
self._save_file_from_response(response, osp.join(temp_dir, "export.zip"))

mock_osp_exists.assert_called()

Expand Down Expand Up @@ -2046,6 +2046,22 @@ def test_cleanup_can_be_called_with_old_signature_and_values(self):


class ProjectDumpUpload(_DbTestBase):
def _get_download_project_dataset_response(self, url, user, dump_format_name, edata):
data = {
"format": dump_format_name,
}
response = self._get_request_with_data(url, data, user)
self.assertEqual(response.status_code, edata["accept code"])

response = self._get_request_with_data(url, data, user)
self.assertEqual(response.status_code, edata["create code"])

data = {
"format": dump_format_name,
"action": "download",
}
return self._get_request_with_data(url, data, user)

def test_api_v2_export_import_dataset(self):
test_name = self._testMethodName
dump_formats = dm.views.get_export_formats()
Expand Down Expand Up @@ -2095,28 +2111,9 @@ def test_api_v2_export_import_dataset(self):

user_name = edata['name']
file_zip_name = osp.join(test_dir, f'{test_name}_{user_name}_{dump_format_name}.zip')
data = {
"format": dump_format_name,
}

response = self._get_request_with_data(url, data, user)
self.assertEqual(response.status_code, edata["accept code"])

response = self._get_request_with_data(url, data, user)
self.assertEqual(response.status_code, edata["create code"])

data = {
"format": dump_format_name,
"action": "download",
}
response = self._get_request_with_data(url, data, user)
response = self._get_download_project_dataset_response(url, user, dump_format_name, edata)
self.assertEqual(response.status_code, edata["code"])

if response.status_code == status.HTTP_200_OK:
content = BytesIO(b"".join(response.streaming_content))
with open(file_zip_name, "wb") as f:
f.write(content.getvalue())

self._save_file_from_response(response, file_zip_name)
self.assertEqual(response.status_code, edata['code'])
self.assertEqual(osp.exists(file_zip_name), edata['file_exists'])

Expand Down Expand Up @@ -2177,22 +2174,63 @@ def test_api_v2_export_annotations(self):

user_name = edata['name']
file_zip_name = osp.join(test_dir, f'{test_name}_{user_name}_{dump_format_name}.zip')
data = {
"format": dump_format_name,
}
response = self._get_request_with_data(url, data, user)
self.assertEqual(response.status_code, edata["accept code"])
response = self._get_request_with_data(url, data, user)
self.assertEqual(response.status_code, edata["create code"])
data = {
"format": dump_format_name,
"action": "download",
}
response = self._get_request_with_data(url, data, user)
response = self._get_download_project_dataset_response(url, user, dump_format_name, edata)
self.assertEqual(response.status_code, edata["code"])
if response.status_code == status.HTTP_200_OK:
content = BytesIO(b"".join(response.streaming_content))
with open(file_zip_name, "wb") as f:
f.write(content.getvalue())
self._save_file_from_response(response, file_zip_name)
self.assertEqual(response.status_code, edata['code'])
self.assertEqual(osp.exists(file_zip_name), edata['file_exists'])

def test_api_v2_dump_upload_annotations_with_objects_type_is_track(self):
test_name = self._testMethodName
upload_format_name = dump_format_name = "COCO Keypoints 1.0"
user = self.admin
edata = {'name': 'admin', 'code': status.HTTP_200_OK, 'create code': status.HTTP_201_CREATED,
'accept code': status.HTTP_202_ACCEPTED, 'file_exists': True, 'annotation_loaded': True}

with TestDir() as test_dir:
# Dump annotations with objects type is track
# create task with annotations
project_dict = copy.deepcopy(projects['main'])
task_dict = copy.deepcopy(tasks[dump_format_name])
project_dict["labels"] = task_dict["labels"]
del task_dict["labels"]
for label in project_dict["labels"]:
label["attributes"] = [{
"name": "is_crowd",
"mutable": False,
"input_type": "checkbox",
"default_value": "false",
"values": ["false", "true"]
}]
project = self._create_project(project_dict)
pid = project['id']
video = self._generate_task_videos(1)
task_dict['project_id'] = pid
task = self._create_task(task_dict, video)
task_id = task["id"]
self._create_annotations(task, "skeleton track", "default")
# dump annotations
url = self._generate_url_dump_project_dataset(project['id'], dump_format_name)

self._clear_rq_jobs() # clean up from previous tests and iterations

file_zip_name = osp.join(test_dir, f'{test_name}_{dump_format_name}.zip')
response = self._get_download_project_dataset_response(url, user, dump_format_name, edata)
self.assertEqual(response.status_code, edata['code'])
self._save_file_from_response(response, file_zip_name)
self.assertEqual(osp.exists(file_zip_name), True)

data_from_task_before_upload = self._get_data_from_task(task_id, True)

# Upload annotations with objects type is track
project = self._create_project(project_dict)
url = self._generate_url_upload_project_dataset(project["id"], upload_format_name)

with open(file_zip_name, 'rb') as binary_file:
response = self._post_request_with_data(url, {"dataset_file": binary_file}, user)
self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED)

# equals annotations
new_task = self._get_tasks(project["id"])[0]
data_from_task_after_upload = self._get_data_from_task(new_task["id"], True)
compare_datasets(data_from_task_before_upload, data_from_task_after_upload)
Loading