Skip to content

Commit

Permalink
Add first ideas
Browse files Browse the repository at this point in the history
  • Loading branch information
hagenw committed Jul 8, 2024
1 parent 5f4a293 commit 6dcd0a7
Show file tree
Hide file tree
Showing 3 changed files with 1,317 additions and 69 deletions.
241 changes: 172 additions & 69 deletions audinterface/core/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,51 +248,83 @@ def _process_file(
) -> typing.Tuple[
typing.List[typing.Any],
typing.List[str],
typing.List[pd.Timedelta],
typing.List[pd.Timedelta],
typing.Optional[typing.List[pd.Timedelta]],
typing.Optional[typing.List[pd.Timedelta]],
]:
r"""Process a file.
Args:
file: file path
root: optional root path of file
start: start time to read media file
end: end time to read media file
process_func_args: arguments to pass to process function
Returns:
result of processing function, files, starts, ends
"""
if start is not None:
start = utils.to_timedelta(start, self.sampling_rate)
if end is not None:
end = utils.to_timedelta(end, self.sampling_rate)

signal, sampling_rate = utils.read_audio(
file,
start=start,
end=end,
root=root,
)
ext = audeer.file_extension(file).lower()

y, files, starts, ends = self._process_signal(
signal,
sampling_rate,
idx=idx,
root=root,
file=file,
process_func_args=process_func_args,
)

def precision_offset(duration, sampling_rate):
# Ensure we get the same precision
# by storing what is lost due to rounding
# when reading the file
duration_at_sample = utils.to_timedelta(
audmath.samples(duration.total_seconds(), sampling_rate) / sampling_rate
# Text files
if ext in ["json", "txt"]:
data = utils.read_text(file, root=root)
y = self._call_data(
data,
idx=idx,
root=root,
file=file,
process_func_args=process_func_args,
)
return duration - duration_at_sample
files = [file]
starts = None
ends = None

if self.win_dur is not None:
if start is not None:
starts = starts + start
ends = ends + start
# Audio/video files
else:
if start is not None and not pd.isna(start):
starts[0] += start
ends[0] += start - precision_offset(start, sampling_rate)
if self.keep_nat and (end is None or pd.isna(end)):
ends[0] = pd.NaT
if end is not None and not pd.isna(end):
ends[-1] += precision_offset(end, sampling_rate)
signal, sampling_rate = utils.read_audio(
file,
start=start,
end=end,
root=root,
)

y, files, starts, ends = self._process_signal(
signal,
sampling_rate,
idx=idx,
root=root,
file=file,
process_func_args=process_func_args,
)

def precision_offset(duration, sampling_rate):
# Ensure we get the same precision
# by storing what is lost due to rounding
# when reading the file
duration_at_sample = utils.to_timedelta(
audmath.samples(duration.total_seconds(), sampling_rate)
/ sampling_rate
)
return duration - duration_at_sample

if self.win_dur is not None:
if start is not None:
starts = starts + start
ends = ends + start
else:
if start is not None and not pd.isna(start):
starts[0] += start
ends[0] += start - precision_offset(start, sampling_rate)
if self.keep_nat and (end is None or pd.isna(end)):
ends[0] = pd.NaT
if end is not None and not pd.isna(end):
ends[-1] += precision_offset(end, sampling_rate)

return y, files, starts, ends

Expand Down Expand Up @@ -348,7 +380,6 @@ def process_file(
end=end,
process_func_args=process_func_args,
)

index = audformat.segmented_index(files, starts, ends)

if len(y) == 0:
Expand Down Expand Up @@ -714,7 +745,7 @@ def _process_signal(
def process_signal(
self,
signal: np.ndarray,
sampling_rate: int,
sampling_rate: int = None,
*,
file: str = None,
start: Timestamp = None,
Expand Down Expand Up @@ -768,24 +799,31 @@ def process_signal(
process_func_args=process_func_args,
)
else:
if start is not None:
start = utils.to_timedelta(start, sampling_rate)
if end is not None:
end = utils.to_timedelta(end, sampling_rate)

y, files, starts, ends = self._process_signal(
signal,
sampling_rate,
file=file,
start=start,
end=end,
process_func_args=process_func_args,
)
# Text files
if sampling_rate is None:
pass
# Implement

if file is not None:
index = audformat.segmented_index(files, starts, ends)
# Audio/video files
else:
index = utils.signal_index(starts, ends)
if start is not None:
start = utils.to_timedelta(start, sampling_rate)
if end is not None:
end = utils.to_timedelta(end, sampling_rate)

y, files, starts, ends = self._process_signal(
signal,
sampling_rate,
file=file,
start=start,
end=end,
process_func_args=process_func_args,
)

if file is not None:
index = audformat.segmented_index(files, starts, ends)
else:
index = utils.signal_index(starts, ends)

if len(y) == 0:
return pd.Series([], index, dtype=object)
Expand Down Expand Up @@ -920,7 +958,28 @@ def _call(
file: str = None,
process_func_args: typing.Dict[str, typing.Any] = None,
) -> typing.Any:
r"""Call processing function, possibly pass special args."""
r"""Call processing function on audio/video files.
Assumes a ``numpy`` array as signal,
with channels and samples as dimensions.
The signal is resampled and/or remixed,
if required.
Special arguments are extracted,
and passed to the processing function.
Args:
signal: signal values
sampling_rate: sampling rate in Hz
idx: index
root: root path
file: file path
process_func_args: processing function arguments
Returns:
result of processing function
"""
signal, sampling_rate = utils.preprocess_signal(
signal,
sampling_rate,
Expand All @@ -931,14 +990,7 @@ def _call(
)

process_func_args = process_func_args or self.process_func_args
special_args = {}
for key, value in [
("idx", idx),
("root", root),
("file", file),
]:
if key in self._process_func_signature and key not in process_func_args:
special_args[key] = value
special_args = self._special_args(idx, root, file, process_func_args)

def _helper(x):
if self.process_func_is_mono:
Expand Down Expand Up @@ -973,18 +1025,66 @@ def _helper(x):

return y

def _call_data(
self,
data: typing.Any,
*,
idx: int = 0,
root: str = None,
file: str = None,
process_func_args: typing.Dict[str, typing.Any] = None,
) -> typing.Any:
r"""Call processing function on general data."""
process_func_args = process_func_args or self.process_func_args
special_args = self._special_args(idx, root, file, process_func_args)
y = self.process_func(data, **special_args, **process_func_args)
return y

def _special_args(
self,
idx: int,
root: typing.Optional[str],
file: typing.Optional[str],
process_func_args: typing.Dict[str, typing.Any] = None,
) -> typing.Dict[str, typing.Union[int, str]]:
r"""Identify special arguments in processing function.
If one of the arguments of the processing function is named
``"idx"``, ``"root"``, or ``"file"``,
and not provided in ``process_func_args``,
it is identified as a special argument.
Args:
idx: index
root: root path
file: file path
process_func_args: processing function arguments
Returns:
special arguments dictionary
"""
special_args = {}
for key, value in [("idx", idx), ("root", root), ("file", file)]:
if key in self._process_func_signature and key not in process_func_args:
special_args[key] = value
return special_args

def __call__(
self,
signal: np.ndarray,
sampling_rate: int,
sampling_rate: int = None,
) -> typing.Any:
r"""Apply processing to signal.
This function processes the signal **without** transforming the output
into a :class:`pd.Series`. Instead, it will return the raw processed
signal. However, if channel selection, mixdown and/or resampling
is enabled, the signal will be first remixed and resampled if the
input sampling rate does not fit the expected sampling rate.
This function processes the signal
**without** transforming the output into a :class:`pd.Series`.
Instead, it will return the raw processed signal.
However,
if channel selection, mixdown and/or resampling is enabled,
and ``sampling_rate`` is not ``None``,
the signal will be first remixed and resampled
if the input sampling rate does not fit the expected sampling rate.
Args:
signal: signal values
Expand All @@ -998,4 +1098,7 @@ def __call__(
RuntimeError: if channel selection is invalid
"""
return self._call(signal, sampling_rate)
if sampling_rate is not None:
return self._call(signal, sampling_rate)
else:
return self._call_data(signal)
32 changes: 32 additions & 0 deletions audinterface/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import collections
import json
import os
import typing

Expand Down Expand Up @@ -148,6 +149,37 @@ def read_audio(
return signal, sampling_rate


def read_text(
file: str,
*,
root: str = None,
) -> typing.Union[dict, str]:
"""Reads text file.
Args:
file: path to audio file
root: root folder
Returns:
dictionary with values,
if ``file`` is a json file,
else content of file as string
"""
if root is not None and not os.path.isabs(file):
file = os.path.join(root, file)

ext = audeer.file_extension(file).lower()
if ext == "json":
with open(file) as json_file:
data = json.load(f)
elif ext == "txt":
with open(file) as txt_file:
data = txt_file.read()

return data


def segment_to_indices(
signal: np.ndarray,
sampling_rate: int,
Expand Down
Loading

0 comments on commit 6dcd0a7

Please sign in to comment.