# Review notes (text before the first "diff --git" header is skipped by git-apply/patch):
# - Line breaks and indentation were restored from a whitespace-collapsed copy of this
#   patch; blank context lines were placed to satisfy the hunk line counts. Verify the
#   context against the target tree before applying.
# - bundle.py, loaders/base.py: __eq__ returns NotImplemented for foreign types instead
#   of raising AttributeError (e.g. a loader compared with an unset, None bundle slot).
# - registry.py: use python-magic's cached module-level detector, print both loader
#   names consistently in the mismatch error, and flag the never-populated priority set.
# - Hunk headers were adjusted for the added lines; "index" hashes are the original ones.
diff --git a/requirements.txt b/requirements.txt
index 8bc75c2..304549e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,3 +14,4 @@ TextGrid==1.5
 tqdm==4.40.2
 ttkthemes==2.4.0
 xparser==0.0.4
+python-magic==0.4.15
diff --git a/ultratrace2/model/files/__init__.py b/ultratrace2/model/files/__init__.py
index bf613f7..e5a9908 100644
--- a/ultratrace2/model/files/__init__.py
+++ b/ultratrace2/model/files/__init__.py
@@ -19,7 +19,7 @@
     from .loaders import FLACLoader
 
     __register(
-        [".flac"], ["audio/flac"], FLACLoader,
+        [".flac"], ["audio/flac", "audio/x-flac"], FLACLoader,
     )
 except ImportError as e:
     logger.warning(e)
@@ -37,7 +37,9 @@
     from .loaders import MP3Loader
 
     __register(
-        [".mp3"], ["audio/mp3", "audio/MPA", "audio/mpa-robust"], MP3Loader,
+        [".mp3"],
+        ["audio/mp3", "audio/mpeg", "audio/MPA", "audio/mpa-robust"],
+        MP3Loader,
     )
 except ImportError as e:
     logger.warning(e)
diff --git a/ultratrace2/model/files/bundle.py b/ultratrace2/model/files/bundle.py
index 7bf73ae..49870cc 100644
--- a/ultratrace2/model/files/bundle.py
+++ b/ultratrace2/model/files/bundle.py
@@ -35,16 +35,25 @@ def has_impl(self) -> bool:
             for f in [self.alignment_file, self.image_set_file, self.sound_file]
         )
 
+    def get_alignment_file(self) -> Optional[AlignmentFileLoader]:
+        return self.alignment_file
+
     def set_alignment_file(self, alignment_file: AlignmentFileLoader) -> None:
         if self.alignment_file is not None:
             logger.warning("Overwriting existing alignment file")
         self.alignment_file = alignment_file
 
+    def get_image_set_file(self) -> Optional[ImageSetFileLoader]:
+        return self.image_set_file
+
     def set_image_set_file(self, image_set_file: ImageSetFileLoader) -> None:
         if self.image_set_file is not None:
             logger.warning("Overwriting existing image-set file")
         self.image_set_file = image_set_file
 
+    def get_sound_file(self) -> Optional[SoundFileLoader]:
+        return self.sound_file
+
     def set_sound_file(self, sound_file: SoundFileLoader) -> None:
         if self.sound_file is not None:
             logger.warning("Overwriting existing sound file")
@@ -53,6 +62,17 @@ def set_sound_file(self, sound_file: SoundFileLoader) -> None:
     def __repr__(self):
         return f'Bundle("{self.name}",{self.alignment_file},{self.image_set_file},{self.sound_file})'
 
+    def __eq__(self, other):
+        # Unrelated types are not comparable: defer instead of raising AttributeError.
+        if not isinstance(other, FileBundle):
+            return NotImplemented
+        return (
+            self.name == other.name
+            and self.alignment_file == other.alignment_file
+            and self.image_set_file == other.image_set_file
+            and self.sound_file == other.sound_file
+        )
+
 
 class FileBundleList:
 
@@ -115,7 +135,7 @@ def build_from_dir(
                     continue
 
                 if name not in bundles:
-                    bundles[name] = FileBundle(filepath)
+                    bundles[name] = FileBundle(name)
 
                 try:
                     loaded_file = file_loader.from_file(filepath)
diff --git a/ultratrace2/model/files/loaders/base.py b/ultratrace2/model/files/loaders/base.py
index 99b4b6e..9822f8a 100644
--- a/ultratrace2/model/files/loaders/base.py
+++ b/ultratrace2/model/files/loaders/base.py
@@ -36,6 +36,17 @@ def from_file(cls: Type[Self], path: str) -> Self:
         NB: If this concrete method fails to load the data at the given path, then it
         should throw a `FileLoadError`."""
 
+    def __eq__(self, other):
+        # Check the type first so that comparing against None (an unset bundle
+        # slot) or a foreign object never reaches `other.get_path()`.
+        if type(self) is not type(other):
+            return NotImplemented
+        return self.get_path() == other.get_path()
+
+    @staticmethod
+    def get_priority() -> int:
+        return 0
+
 
 class IntervalBase(Protocol):
     def get_start(self) -> float:
diff --git a/ultratrace2/model/files/registry.py b/ultratrace2/model/files/registry.py
index 4766fb3..34fc9df 100644
--- a/ultratrace2/model/files/registry.py
+++ b/ultratrace2/model/files/registry.py
@@ -1,19 +1,28 @@
-import mimetypes
+import magic  # type: ignore
 import os
 
-from collections import defaultdict
-from typing import DefaultDict, Optional, Sequence, Set, Type
+from typing import Dict, Mapping, Optional, Sequence, Set, Type, Union
+
+from .loaders.base import (
+    FileLoaderBase,
+    AlignmentFileLoader,
+    ImageSetFileLoader,
+    SoundFileLoader,
+)
 
-from .loaders.base import FileLoaderBase
+AbstractLoader = Union[
+    Type[AlignmentFileLoader], Type[ImageSetFileLoader], Type[SoundFileLoader]
+]
 
 
 # global maps
-__extension_to_loaders_map: DefaultDict[str, Set[Type[FileLoaderBase]]] = defaultdict(
-    set
-)
-__mime_type_to_loaders_map: DefaultDict[str, Set[Type[FileLoaderBase]]] = defaultdict(
-    set
-)
+__extension_to_loaders_map: Dict[str, Type[FileLoaderBase]] = {}
+__mime_type_to_loaders_map: Dict[str, Type[FileLoaderBase]] = {}
+__loader_priorities_map: Mapping[AbstractLoader, Set[int]] = {
+    AlignmentFileLoader: set(),
+    ImageSetFileLoader: set(),
+    SoundFileLoader: set(),
+}
 
 
 def register_loader_for_extensions_and_mime_types(
@@ -30,32 +39,47 @@ def register_loader_for_extensions_and_mime_types(
         loader_cls: A file loader which knows how to load files with the given file extensions and MIME types
     """
 
+    loader_type: AbstractLoader
+    if issubclass(loader_cls, AlignmentFileLoader):
+        loader_type = AlignmentFileLoader
+    elif issubclass(loader_cls, ImageSetFileLoader):
+        loader_type = ImageSetFileLoader
+    elif issubclass(loader_cls, SoundFileLoader):
+        loader_type = SoundFileLoader
+    else:
+        raise ValueError(f"Invalid loader class: {loader_cls.__name__}")
+    # NOTE(review): nothing ever adds `priority` to this set, so the duplicate check
+    # below cannot fire yet - confirm whether registration is meant to record it.
+    priority = loader_cls.get_priority()
+    if priority in __loader_priorities_map[loader_type]:
+        raise ValueError(
+            f"Cannot have duplicate priorities for loader type {loader_type.__name__}"
+        )
+
     for extension in extensions:
-        __extension_to_loaders_map[extension].add(loader_cls)
+        __extension_to_loaders_map[extension] = loader_cls
 
     for mime_type in mime_types:
-        __mime_type_to_loaders_map[mime_type].add(loader_cls)
+        __mime_type_to_loaders_map[mime_type] = loader_cls
 
 
 def get_loader_for(path: str) -> Optional[Type[FileLoaderBase]]:
     _, extension = os.path.splitext(path.lower())
-    mime_type, _ = mimetypes.guess_type(path)
+    mime_type = magic.from_file(path, mime=True)
 
     if mime_type is None:
         # Early return since we can't possibly match anymore
         return None
 
-    loader_clses_by_extension = __extension_to_loaders_map[extension]
-    loader_clses_by_mime_type = __mime_type_to_loaders_map[mime_type]
+    loader_cls_by_extension = __extension_to_loaders_map.get(extension, None)
+    loader_cls_by_mime_type = __mime_type_to_loaders_map.get(mime_type, None)
 
     # NB: Use set-intersection (could potentially use set-union instead).
-    loader_clses = loader_clses_by_extension & loader_clses_by_mime_type
-
-    if len(loader_clses) == 0:
+    if loader_cls_by_extension is None or loader_cls_by_mime_type is None:
         return None
-    elif len(loader_clses) == 1:
-        return loader_clses.pop()
+    if loader_cls_by_extension == loader_cls_by_mime_type:
+        return loader_cls_by_extension
     else:
-        raise NotImplementedError(
-            f"Found multiple Loaders for path '{path}': {','.join(map(str, loader_clses))}"
+        raise ValueError(
+            f"Warning: got {loader_cls_by_extension.__name__} for {extension} and {loader_cls_by_mime_type.__name__} for {mime_type}"
         )
diff --git a/ultratrace2/model/files/tests/test_bundle.py b/ultratrace2/model/files/tests/test_bundle.py
index fb0e558..6881191 100644
--- a/ultratrace2/model/files/tests/test_bundle.py
+++ b/ultratrace2/model/files/tests/test_bundle.py
@@ -1,6 +1,16 @@
+from typing import Dict, Mapping, Sequence, Tuple, Type
+
+import os
 import pytest
 
 from ..bundle import FileBundle, FileBundleList
+from ..loaders import DICOMLoader, FLACLoader, MP3Loader, TextGridLoader, WAVLoader
+from ..loaders.base import (
+    FileLoaderBase,
+    AlignmentFileLoader,
+    ImageSetFileLoader,
+    SoundFileLoader,
+)
 
 
 @pytest.mark.parametrize(
@@ -16,12 +26,84 @@
         dict(alignment_file=None, image_set_file=None, sound_file=None),
     ],
 )
-def test_empty_file_bundle_constructor(kwargs) -> None:
+def test_empty_file_bundle_constructor(kwargs: Mapping[str, None]) -> None:
     fb = FileBundle("test", **kwargs)
     assert not fb.has_impl()
     assert str(fb) == 'Bundle("test",None,None,None)'
 
 
+@pytest.mark.parametrize(
+    "kwargs",
+    [
+        dict(
+            alignment_file=TextGridLoader.from_file(
+                "./test-data/example-bundles/ex000/file00.TextGrid"
+            )
+        ),
+        dict(
+            image_set_file=DICOMLoader.from_file(
+                "./test-data/example-bundles/ex004/file00.dicom"
+            )
+        ),
+        dict(
+            sound_file=MP3Loader.from_file(
+                "./test-data/example-bundles/ex002/file00.mp3"
+            )
+        ),
+        dict(
+            alignment_file=TextGridLoader.from_file(
+                "./test-data/example-bundles/ex000/file00.TextGrid"
+            ),
+            image_set_file=DICOMLoader.from_file(
+                "./test-data/example-bundles/ex004/file00.dicom"
+            ),
+        ),
+        dict(
+            alignment_file=TextGridLoader.from_file(
+                "./test-data/example-bundles/ex000/file00.TextGrid"
+            ),
+            sound_file=MP3Loader.from_file(
+                "./test-data/example-bundles/ex002/file00.mp3"
+            ),
+        ),
+        dict(
+            image_set_file=DICOMLoader.from_file(
+                "./test-data/example-bundles/ex004/file00.dicom"
+            ),
+            sound_file=MP3Loader.from_file(
+                "./test-data/example-bundles/ex002/file00.mp3"
+            ),
+        ),
+        dict(
+            alignment_file=TextGridLoader.from_file(
+                "./test-data/example-bundles/ex000/file00.TextGrid"
+            ),
+            image_set_file=DICOMLoader.from_file(
+                "./test-data/example-bundles/ex004/file00.dicom"
+            ),
+            sound_file=MP3Loader.from_file(
+                "./test-data/example-bundles/ex002/file00.mp3"
+            ),
+        ),
+    ],
+)
+def test_file_bundle_constructor(kwargs: Mapping[str, FileLoaderBase]) -> None:
+    fb = FileBundle("test", **kwargs)
+    assert fb.has_impl()
+    if "alignment_file" in kwargs:
+        alignment_file = fb.get_alignment_file()
+        assert isinstance(alignment_file, AlignmentFileLoader)
+        assert alignment_file.get_path() == kwargs["alignment_file"].get_path()
+    if "image_set_file" in kwargs:
+        image_set_file = fb.get_image_set_file()
+        assert isinstance(image_set_file, ImageSetFileLoader)
+        assert image_set_file.get_path() == kwargs["image_set_file"].get_path()
+    if "sound_file" in kwargs:
+        sound_file = fb.get_sound_file()
+        assert isinstance(sound_file, SoundFileLoader)
+        assert sound_file.get_path() == kwargs["sound_file"].get_path()
+
+
 def test_build_from_nonexistent_dir(mocker) -> None:
     mock_file_bundle_list_constructor = mocker.patch(
         "ultratrace2.model.files.bundle.FileBundleList.__init__", return_value=None,
@@ -32,11 +114,181 @@ def test_build_from_nonexistent_dir(mocker) -> None:
 
 
 @pytest.mark.parametrize(
-    "source_dir,bundle_map", [("./test-data/example-bundles/ex000", {})],
+    "source_dir,expected_file_map,should_emit_warning",
+    [
+        (
+            "./test-data/example-bundles/ex000",
+            {"file00": [(TextGridLoader, "file00.TextGrid")]},
+            False,
+        ),
+        (
+            "./test-data/example-bundles/ex001",
+            {
+                "file00": [(TextGridLoader, "file00.TextGrid")],
+                "file01": [(TextGridLoader, "file01.TextGrid")],
+            },
+            False,
+        ),
+        (
+            "./test-data/example-bundles/ex002",
+            {
+                "file00": [
+                    (MP3Loader, "file00.mp3"),
+                    (TextGridLoader, "file00.TextGrid"),
+                ]
+            },
+            False,
+        ),
+        (
+            "./test-data/example-bundles/ex003",
+            {
+                "file00": [(TextGridLoader, "file00.TextGrid")],
+                "file01": [(MP3Loader, "file01.mp3")],
+            },
+            False,
+        ),
+        (
+            "./test-data/example-bundles/ex004",
+            {
+                "file00": [
+                    (DICOMLoader, "file00.dicom"),
+                    (MP3Loader, "file00.mp3"),
+                    (TextGridLoader, "file00.TextGrid"),
+                ]
+            },
+            False,
+        ),
+        ("./test-data/example-bundles/ex005", {}, False),
+        ("./test-data/example-bundles/ex006", {}, True),
+        (
+            "./test-data/example-bundles/ex007",
+            {"file00": [(TextGridLoader, "file00.TextGrid")]},
+            True,
+        ),
+        (
+            "./test-data/example-bundles/ex008",
+            {
+                "file00": [
+                    (DICOMLoader, "file00.dicom"),
+                    (MP3Loader, "file00.mp3"),
+                    (TextGridLoader, "file00.TextGrid"),
+                ],
+                "file01": [
+                    (DICOMLoader, "file01.dicom"),
+                    (MP3Loader, "file01.mp3"),
+                    (TextGridLoader, "file01.TextGrid"),
+                ],
+                "file02": [
+                    (DICOMLoader, "file02.dicom"),
+                    (MP3Loader, "file02.mp3"),
+                    (TextGridLoader, "file02.TextGrid"),
+                ],
+            },
+            False,
+        ),
+        (
+            "./test-data/example-bundles/ex009",
+            {"file00": [(WAVLoader, "file00.wav")]},
+            True,
+        ),
+        (
+            "./test-data/example-bundles/ex010",
+            {"file00": [(TextGridLoader, "sub00/file00.TextGrid")]},
+            False,
+        ),
+        (
+            "./test-data/example-bundles/ex011",
+            {
+                "file00": [(TextGridLoader, "sub00/file00.TextGrid")],
+                "file01": [(TextGridLoader, "sub01/sub00/sub00/sub00/file01.TextGrid")],
+            },
+            False,
+        ),
+        (
+            "./test-data/example-bundles/ex012",
+            {
+                "file00": [
+                    (MP3Loader, "sub01/sub00/sub00/sub00/file00.mp3"),
+                    (TextGridLoader, "sub00/file00.TextGrid"),
+                ]
+            },
+            False,
+        ),
+        (
+            "./test-data/example-bundles/ex013",
+            {"file00": [(TextGridLoader, "sub01/file00.TextGrid")]},
+            True,
+        ),
+        (
+            "./test-data/example-bundles/ex014",
+            {"link00": [(TextGridLoader, "../ex004/file00.TextGrid")]},
+            False,
+        ),
+        (
+            "./test-data/example-bundles/ex015",
+            {
+                "file00": [(MP3Loader, "file00.mp3")],
+                "link00": [(TextGridLoader, "../ex004/file00.TextGrid")],
+            },
+            False,
+        ),
+        (
+            "./test-data/example-bundles/ex016",
+            {
+                "link00": [
+                    (MP3Loader, "link00.mp3"),
+                    (TextGridLoader, "../ex004/file00.TextGrid"),
+                ]
+            },
+            False,
+        ),
+        (
+            "./test-data/ftyers",
+            {
+                "20150629171639": [
+                    (DICOMLoader, "20150629171639.dicom"),
+                    (FLACLoader, "20150629171639.flac"),
+                    (TextGridLoader, "20150629171639.TextGrid"),
+                ],
+            },
+            True,
+        ),
+    ],
 )
-def test_build_from_dir(mocker, source_dir, bundle_map) -> None:
+def test_build_from_dir(
+    mocker,
+    source_dir: str,
+    expected_file_map: Dict[str, Sequence[Tuple[Type[FileLoaderBase], str]]],
+    should_emit_warning: bool,
+) -> None:
     mock_file_bundle_list_constructor = mocker.patch(
         "ultratrace2.model.files.bundle.FileBundleList.__init__", return_value=None,
     )
+    mock_warning = mocker.patch("ultratrace2.model.files.bundle.logger.warning")
     FileBundleList.build_from_dir(source_dir)
-    mock_file_bundle_list_constructor.assert_called_with(bundle_map)
+    expected_bundles = {}
+    for expected_name, expected_files in expected_file_map.items():
+        alignment_file = None
+        image_set_file = None
+        sound_file = None
+        for loader, source_subpath in expected_files:
+            source_path = os.path.abspath(os.path.join(source_dir, source_subpath))
+            if issubclass(loader, AlignmentFileLoader):
+                alignment_file = loader.from_file(source_path)
+            elif issubclass(loader, ImageSetFileLoader):
+                image_set_file = loader.from_file(source_path)
+            elif issubclass(loader, SoundFileLoader):
+                sound_file = loader.from_file(source_path)
+            else:
+                raise RuntimeError("malformed input")
+        expected_bundles[expected_name] = FileBundle(
+            name=expected_name,
+            alignment_file=alignment_file,
+            image_set_file=image_set_file,
+            sound_file=sound_file,
+        )
+    mock_file_bundle_list_constructor.assert_called_with(expected_bundles)
+    if should_emit_warning:
+        mock_warning.assert_called_once()
+    else:
+        mock_warning.assert_not_called()