Skip to content

Commit

Permalink
Merge pull request #125 from SwatPhonLab/test-file-bundles
Browse files Browse the repository at this point in the history
Test file bundles
  • Loading branch information
keggsmurph21 authored Jan 26, 2020
2 parents c09f08f + cc1eef0 commit b14866b
Show file tree
Hide file tree
Showing 6 changed files with 330 additions and 29 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ TextGrid==1.5
tqdm==4.40.2
ttkthemes==2.4.0
xparser==0.0.4
python-magic==0.4.15
6 changes: 4 additions & 2 deletions ultratrace2/model/files/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .loaders import FLACLoader

__register(
[".flac"], ["audio/flac"], FLACLoader,
[".flac"], ["audio/flac", "audio/x-flac"], FLACLoader,
)
except ImportError as e:
logger.warning(e)
Expand All @@ -37,7 +37,9 @@
from .loaders import MP3Loader

__register(
[".mp3"], ["audio/mp3", "audio/MPA", "audio/mpa-robust"], MP3Loader,
[".mp3"],
["audio/mp3", "audio/mpeg", "audio/MPA", "audio/mpa-robust"],
MP3Loader,
)
except ImportError as e:
logger.warning(e)
Expand Down
19 changes: 18 additions & 1 deletion ultratrace2/model/files/bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,25 @@ def has_impl(self) -> bool:
for f in [self.alignment_file, self.image_set_file, self.sound_file]
)

def get_alignment_file(self) -> Optional[AlignmentFileLoader]:
return self.alignment_file

def set_alignment_file(self, alignment_file: AlignmentFileLoader) -> None:
if self.alignment_file is not None:
logger.warning("Overwriting existing alignment file")
self.alignment_file = alignment_file

def get_image_set_file(self) -> Optional[ImageSetFileLoader]:
return self.image_set_file

def set_image_set_file(self, image_set_file: ImageSetFileLoader) -> None:
if self.image_set_file is not None:
logger.warning("Overwriting existing image-set file")
self.image_set_file = image_set_file

def get_sound_file(self) -> Optional[SoundFileLoader]:
return self.sound_file

def set_sound_file(self, sound_file: SoundFileLoader) -> None:
if self.sound_file is not None:
logger.warning("Overwriting existing sound file")
Expand All @@ -53,6 +62,14 @@ def set_sound_file(self, sound_file: SoundFileLoader) -> None:
def __repr__(self):
return f'Bundle("{self.name}",{self.alignment_file},{self.image_set_file},{self.sound_file})'

def __eq__(self, other):
return (
self.name == other.name
and self.alignment_file == other.alignment_file
and self.image_set_file == other.image_set_file
and self.sound_file == other.sound_file
)


class FileBundleList:

Expand Down Expand Up @@ -115,7 +132,7 @@ def build_from_dir(
continue

if name not in bundles:
bundles[name] = FileBundle(filepath)
bundles[name] = FileBundle(name)

try:
loaded_file = file_loader.from_file(filepath)
Expand Down
7 changes: 7 additions & 0 deletions ultratrace2/model/files/loaders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ def from_file(cls: Type[Self], path: str) -> Self:
NB: If this concrete method fails to load the data at the given path, then
it should throw a `FileLoadError`."""

def __eq__(self, other):
return self.get_path() == other.get_path() and type(self) == type(other)

@staticmethod
def get_priority() -> int:
return 0


class IntervalBase(Protocol):
def get_start(self) -> float:
Expand Down
66 changes: 44 additions & 22 deletions ultratrace2/model/files/registry.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
import mimetypes
import magic # type: ignore
import os

from collections import defaultdict
from typing import DefaultDict, Optional, Sequence, Set, Type
from typing import Dict, Mapping, Optional, Sequence, Set, Type, Union

from .loaders.base import (
FileLoaderBase,
AlignmentFileLoader,
ImageSetFileLoader,
SoundFileLoader,
)

from .loaders.base import FileLoaderBase

AbstractLoader = Union[
Type[AlignmentFileLoader], Type[ImageSetFileLoader], Type[SoundFileLoader]
]

# global maps
__extension_to_loaders_map: DefaultDict[str, Set[Type[FileLoaderBase]]] = defaultdict(
set
)
__mime_type_to_loaders_map: DefaultDict[str, Set[Type[FileLoaderBase]]] = defaultdict(
set
)
__extension_to_loaders_map: Dict[str, Type[FileLoaderBase]] = {}
__mime_type_to_loaders_map: Dict[str, Type[FileLoaderBase]] = {}
__loader_priorities_map: Mapping[AbstractLoader, Set[int]] = {
AlignmentFileLoader: set(),
ImageSetFileLoader: set(),
SoundFileLoader: set(),
}


def register_loader_for_extensions_and_mime_types(
Expand All @@ -30,32 +39,45 @@ def register_loader_for_extensions_and_mime_types(
loader_cls: A file loader which knows how to load files with the given file extensions and MIME types
"""

loader_type: AbstractLoader
if issubclass(loader_cls, AlignmentFileLoader):
loader_type = AlignmentFileLoader
elif issubclass(loader_cls, ImageSetFileLoader):
loader_type = ImageSetFileLoader
elif issubclass(loader_cls, SoundFileLoader):
loader_type = SoundFileLoader
else:
raise ValueError(f"Invalid loader class: {loader_cls.__name__}")
priority = loader_cls.get_priority()
if priority in __loader_priorities_map[loader_type]:
raise ValueError(
f"Cannot have duplicate priorities for loader type {loader_type.__name__}"
)

for extension in extensions:
__extension_to_loaders_map[extension].add(loader_cls)
__extension_to_loaders_map[extension] = loader_cls

for mime_type in mime_types:
__mime_type_to_loaders_map[mime_type].add(loader_cls)
__mime_type_to_loaders_map[mime_type] = loader_cls


def get_loader_for(path: str) -> Optional[Type[FileLoaderBase]]:

_, extension = os.path.splitext(path.lower())
mime_type, _ = mimetypes.guess_type(path)
mime_type = magic.Magic(mime=True).from_file(path)
if mime_type is None:
# Early return since we can't possibly match anymore
return None

loader_clses_by_extension = __extension_to_loaders_map[extension]
loader_clses_by_mime_type = __mime_type_to_loaders_map[mime_type]
loader_cls_by_extension = __extension_to_loaders_map.get(extension, None)
loader_cls_by_mime_type = __mime_type_to_loaders_map.get(mime_type, None)

# NB: Use set-intersection (could potentially use set-union instead).
loader_clses = loader_clses_by_extension & loader_clses_by_mime_type

if len(loader_clses) == 0:
if loader_cls_by_extension is None or loader_cls_by_mime_type is None:
return None
elif len(loader_clses) == 1:
return loader_clses.pop()
if loader_cls_by_extension == loader_cls_by_mime_type:
return loader_cls_by_extension
else:
raise NotImplementedError(
f"Found multiple Loaders for path '{path}': {','.join(map(str, loader_clses))}"
raise ValueError(
f"Warning: got {loader_cls_by_extension.__name__} for {extension} and {loader_cls_by_mime_type} for {mime_type}"
)
Loading

0 comments on commit b14866b

Please sign in to comment.