From 31f78d5a4bead9da6a0f3026746ee28def9e9a5d Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Thu, 29 Jun 2023 19:59:04 +0300 Subject: [PATCH] Fix file matching in annotation import for multiple dots in filenames (#6350) ### Motivation and context Fixes https://github.com/opencv/cvat/issues/6319 - Fixed invalid dataset root search, leading to invalid file matching - Restored detailed dataset import error messages - Added tests ### How has this been tested? ### Checklist - [x] I submit my changes into the `develop` branch - [ ] I have added a description of my changes into the [CHANGELOG](https://github.com/opencv/cvat/blob/develop/CHANGELOG.md) file - [ ] I have updated the documentation accordingly - [ ] I have added tests to cover my changes - [ ] I have linked related issues (see [GitHub docs]( https://help.github.com/en/github/managing-your-work-on-github/linking-a-pull-request-to-an-issue#linking-a-pull-request-to-an-issue-using-a-keyword)) - [ ] I have increased versions of npm packages if it is necessary ([cvat-canvas](https://github.com/opencv/cvat/tree/develop/cvat-canvas#versioning), [cvat-core](https://github.com/opencv/cvat/tree/develop/cvat-core#versioning), [cvat-data](https://github.com/opencv/cvat/tree/develop/cvat-data#versioning) and [cvat-ui](https://github.com/opencv/cvat/tree/develop/cvat-ui#versioning)) ### License - [x] I submit _my code changes_ under the same [MIT License]( https://github.com/opencv/cvat/blob/develop/LICENSE) that covers the project. Feel free to contact the maintainers if that's a concern. --- CHANGELOG.md | 1 + cvat/apps/dataset_manager/bindings.py | 60 +++++++--- cvat/apps/engine/views.py | 11 +- tests/python/rest_api/test_tasks.py | 151 ++++++++++++++++++++++++++ tests/python/shared/utils/helpers.py | 13 ++- 5 files changed, 210 insertions(+), 26 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e4d9c20a2eeb..655121c979f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - \[API\] Invalid schema for the owner field in several endpoints () - \[SDK\] Loading tasks that have been cached with the PyTorch adapter () +- The problem with importing annotations if dataset has extra dots in filenames () ### Security - More comprehensive SSRF mitigations were implemented. diff --git a/cvat/apps/dataset_manager/bindings.py b/cvat/apps/dataset_manager/bindings.py index b130e4d08c8c..4eefaa115e62 100644 --- a/cvat/apps/dataset_manager/bindings.py +++ b/cvat/apps/dataset_manager/bindings.py @@ -212,7 +212,7 @@ def __init__(self, self._create_callback = create_callback self._MAX_ANNO_SIZE = 30000 self._frame_info = {} - self._frame_mapping = {} + self._frame_mapping: Dict[str, int] = {} self._frame_step = db_task.data.get_frame_step() self._db_data = db_task.data self._use_server_track_ids = use_server_track_ids @@ -613,28 +613,37 @@ def __len__(self): raise NotImplementedError() @staticmethod - def _get_filename(path): + def _get_filename(path: str) -> str: return osp.splitext(path)[0] - def match_frame(self, path, root_hint=None, path_has_ext=True): + def match_frame(self, + path: str, root_hint: Optional[str] = None, *, path_has_ext: bool = True + ) -> Optional[int]: if path_has_ext: path = self._get_filename(path) + match = self._frame_mapping.get(path) + if not match and root_hint and not path.startswith(root_hint): path = osp.join(root_hint, path) match = self._frame_mapping.get(path) + return match - def match_frame_fuzzy(self, path): + def match_frame_fuzzy(self, path: str, *, path_has_ext: bool = True) -> Optional[int]: # Preconditions: # - The input dataset is full, i.e. all items present. Partial dataset # matching can't be correct for all input cases. # - path is the longest path of input dataset in terms of path parts - path = Path(self._get_filename(path)).parts + if path_has_ext: + path = self._get_filename(path) + + path = Path(path).parts for p, v in self._frame_mapping.items(): if Path(p).parts[-len(path):] == path: # endswith() for paths return v + return None class JobData(CommonData): @@ -1254,20 +1263,30 @@ def task_data(self): def _get_filename(path): return osp.splitext(path)[0] - def match_frame(self, path: str, subset: str=dm.DEFAULT_SUBSET_NAME, root_hint: str=None, path_has_ext: bool=True): + def match_frame(self, + path: str, subset: str = dm.DEFAULT_SUBSET_NAME, + root_hint: str = None, path_has_ext: bool = True + ) -> Optional[int]: if path_has_ext: path = self._get_filename(path) + match_task, match_frame = self._frame_mapping.get((subset, path), (None, None)) + if not match_frame and root_hint and not path.startswith(root_hint): path = osp.join(root_hint, path) match_task, match_frame = self._frame_mapping.get((subset, path), (None, None)) + return match_task, match_frame - def match_frame_fuzzy(self, path): - path = Path(self._get_filename(path)).parts + def match_frame_fuzzy(self, path: str, *, path_has_ext: bool = True) -> Optional[int]: + if path_has_ext: + path = self._get_filename(path) + + path = Path(path).parts for (_subset, _path), (_tid, frame_number) in self._frame_mapping.items(): if Path(_path).parts[-len(path):] == path : return frame_number + return None def split_dataset(self, dataset: dm.Dataset): @@ -1814,7 +1833,11 @@ def convert_cvat_anno_to_dm( return converter.convert() -def match_dm_item(item, instance_data, root_hint=None): +def match_dm_item( + item: dm.DatasetItem, + instance_data: Union[ProjectData, CommonData], + root_hint: Optional[str] = None +) -> int: is_video = instance_data.meta[instance_data.META_FIELD]['mode'] == 'interpolation' frame_number = None @@ -1832,20 +1855,23 @@ def match_dm_item(item, instance_data, root_hint=None): "'%s' with any task frame" % item.id) return frame_number -def find_dataset_root(dm_dataset, instance_data: Union[ProjectData, CommonData]): - longest_path = max(dm_dataset, key=lambda x: len(Path(x.id).parts), - default=None) - if longest_path is None: +def find_dataset_root( + dm_dataset: dm.IDataset, instance_data: Union[ProjectData, CommonData] +) -> Optional[str]: + longest_path_item = max(dm_dataset, key=lambda item: len(Path(item.id).parts), default=None) + if longest_path_item is None: return None - longest_path = longest_path.id + longest_path = longest_path_item.id - longest_match = instance_data.match_frame_fuzzy(longest_path) - if longest_match is None: + matched_frame_number = instance_data.match_frame_fuzzy(longest_path, path_has_ext=False) + if matched_frame_number is None: return None - longest_match = osp.dirname(instance_data.frame_info[longest_match]['path']) + + longest_match = osp.dirname(instance_data.frame_info[matched_frame_number]['path']) prefix = longest_match[:-len(osp.dirname(longest_path)) or None] if prefix.endswith('/'): prefix = prefix[:-1] + return prefix def import_dm_annotations(dm_dataset: dm.Dataset, instance_data: Union[ProjectData, CommonData]): diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py index d4ae6316ce98..225fd729cd27 100644 --- a/cvat/apps/engine/views.py +++ b/cvat/apps/engine/views.py @@ -2917,12 +2917,11 @@ def _import_annotations(request, rq_id_template, rq_func, db_obj, format_name, elif rq_job.is_failed or \ rq_job.is_deferred and rq_job.dependency and rq_job.dependency.is_failed: exc_info = process_failed_job(rq_job) - # RQ adds a prefix with exception class name - import_error_prefix = '{}.{}'.format( - CvatImportError.__module__, CvatImportError.__name__) - if import_error_prefix in exc_info: - return Response(data="The annotations that were uploaded are not correct", - status=status.HTTP_400_BAD_REQUEST) + + import_error_prefix = f'{CvatImportError.__module__}.{CvatImportError.__name__}:' + if exc_info.startswith("Traceback") and import_error_prefix in exc_info: + exc_message = exc_info.split(import_error_prefix)[-1].strip() + return Response(data=exc_message, status=status.HTTP_400_BAD_REQUEST) else: return Response(data=exc_info, status=status.HTTP_500_INTERNAL_SERVER_ERROR) diff --git a/tests/python/rest_api/test_tasks.py b/tests/python/rest_api/test_tasks.py index d03b857ca601..14ce2f5fbf6f 100644 --- a/tests/python/rest_api/test_tasks.py +++ b/tests/python/rest_api/test_tasks.py @@ -2172,3 +2172,154 @@ def test_check_import_cache_after_previous_interrupted_upload(self, tasks_with_s if not number_of_files: break assert not number_of_files + + +class TestImportWithComplexFilenames: + @staticmethod + def _make_client() -> Client: + return Client(BASE_URL, config=Config(status_check_period=0.01)) + + @pytest.fixture( + autouse=True, + scope="class", + # classmethod way may not work in some versions + # https://github.com/opencv/cvat/actions/runs/5336023573/jobs/9670573955?pr=6350 + name="TestImportWithComplexFilenames.setup_class", + ) + @classmethod + def setup_class( + cls, restore_db_per_class, tmp_path_factory: pytest.TempPathFactory, admin_user: str + ): + cls.tmp_dir = tmp_path_factory.mktemp(cls.__class__.__name__) + cls.client = cls._make_client() + cls.user = admin_user + cls.format_name = "PASCAL VOC 1.1" + + with cls.client: + cls.client.login((cls.user, USER_PASS)) + + cls._init_tasks() + + @classmethod + def _create_task_with_annotations(cls, filenames: List[str]): + images = generate_image_files(len(filenames), filenames=filenames) + + source_archive_path = cls.tmp_dir / "source_data.zip" + with zipfile.ZipFile(source_archive_path, "w") as zip_file: + for image in images: + zip_file.writestr(image.name, image.getvalue()) + + task = cls.client.tasks.create_from_data( + { + "name": "test_images_with_dots", + "labels": [{"name": "cat"}, {"name": "dog"}], + }, + resources=[source_archive_path], + ) + + labels = task.get_labels() + task.set_annotations( + models.LabeledDataRequest( + shapes=[ + models.LabeledShapeRequest( + frame=frame_id, + label_id=labels[0].id, + type="rectangle", + points=[1, 1, 2, 2], + ) + for frame_id in range(len(filenames)) + ], + ) + ) + + return task + + @classmethod + def _init_tasks(cls): + cls.flat_filenames = [ + "filename0.jpg", + "file.name1.jpg", + "fi.le.na.me.2.jpg", + ".filename3.jpg", + "..filename..4.jpg", + "..filename..5.png..jpg", + ] + + cls.nested_filenames = [ + f"{prefix}/{fn}" + for prefix, fn in zip( + [ + "ab/cd", + "ab/cd", + "ab", + "ab", + "cd/ef", + "cd/ef", + "cd", + "", + ], + cls.flat_filenames, + ) + ] + + cls.data = {} + for (kind, filenames), prefix in product( + [("flat", cls.flat_filenames), ("nested", cls.nested_filenames)], ["", "pre/fix"] + ): + key = kind + if prefix: + key += "_prefixed" + + task = cls._create_task_with_annotations( + [f"{prefix}/{fn}" if prefix else fn for fn in filenames] + ) + + dataset_file = cls.tmp_dir / f"{key}_dataset.zip" + task.export_dataset(cls.format_name, dataset_file, include_images=False) + + cls.data[key] = (task, dataset_file) + + @pytest.mark.parametrize( + "task_kind, annotation_kind, expect_success", + [ + ("flat", "flat", True), + ("flat", "flat_prefixed", False), + ("flat", "nested", False), + ("flat", "nested_prefixed", False), + ("flat_prefixed", "flat", True), # allow this for better UX + ("flat_prefixed", "flat_prefixed", True), + ("flat_prefixed", "nested", False), + ("flat_prefixed", "nested_prefixed", False), + ("nested", "flat", False), + ("nested", "flat_prefixed", False), + ("nested", "nested", True), + ("nested", "nested_prefixed", False), + ("nested_prefixed", "flat", False), + ("nested_prefixed", "flat_prefixed", False), + ("nested_prefixed", "nested", True), # allow this for better UX + ("nested_prefixed", "nested_prefixed", True), + ], + ) + def test_import_annotations(self, task_kind, annotation_kind, expect_success): + # Tests for regressions about https://github.com/opencv/cvat/issues/6319 + # + # X annotations must be importable to X prefixed cases + # with and without dots in filenames. + # + # Nested structures can potentially be matched to flat ones and vise-versa, + # but it's not supported now, as it may lead to some errors in matching. + + task: Task = self.data[task_kind][0] + dataset_file = self.data[annotation_kind][1] + + if expect_success: + task.import_annotations(self.format_name, dataset_file) + + assert set(s.frame for s in task.get_annotations().shapes) == set( + range(len(self.flat_filenames)) + ) + else: + with pytest.raises(exceptions.ApiException) as capture: + task.import_annotations(self.format_name, dataset_file) + + assert b"Could not match item id" in capture.value.body diff --git a/tests/python/shared/utils/helpers.py b/tests/python/shared/utils/helpers.py index d8987a7ba496..289d24e7966f 100644 --- a/tests/python/shared/utils/helpers.py +++ b/tests/python/shared/utils/helpers.py @@ -4,7 +4,7 @@ import subprocess from io import BytesIO -from typing import List +from typing import List, Optional from PIL import Image @@ -21,11 +21,18 @@ def generate_image_file(filename="image.png", size=(50, 50), color=(0, 0, 0)): return f -def generate_image_files(count, prefixes=None) -> List[BytesIO]: +def generate_image_files( + count, prefixes=None, *, filenames: Optional[List[str]] = None +) -> List[BytesIO]: + assert not (prefixes and filenames), "prefixes cannot be used together with filenames" + assert not prefixes or len(prefixes) == count + assert not filenames or len(filenames) == count + images = [] for i in range(count): prefix = prefixes[i] if prefixes else "" - image = generate_image_file(f"{prefix}{i}.jpeg", color=(i, i, i)) + filename = f"{prefix}{i}.jpeg" if not filenames else filenames[i] + image = generate_image_file(filename, color=(i, i, i)) images.append(image) return images