diff --git a/torchaudio/__init__.py b/torchaudio/__init__.py index 9c8a4e4875..16419ff3ff 100644 --- a/torchaudio/__init__.py +++ b/torchaudio/__init__.py @@ -1,3 +1,4 @@ +from . import extension from torchaudio._internal import module_utils as _mod_utils from torchaudio import ( compliance, diff --git a/torchaudio/csrc/register.cpp b/torchaudio/csrc/register.cpp new file mode 100644 index 0000000000..81b1a84c96 --- /dev/null +++ b/torchaudio/csrc/register.cpp @@ -0,0 +1,18 @@ +#ifndef TORCHAUDIO_REGISTER_H +#define TORCHAUDIO_REGISTER_H + +#include + +namespace torchaudio { +namespace { + +static auto registerSignalInfo = + torch::class_("torchaudio", "SignalInfo") + .def(torch::init()) + .def("get_sample_rate", &SignalInfo::getSampleRate) + .def("get_num_channels", &SignalInfo::getNumChannels) + .def("get_num_samples", &SignalInfo::getNumSamples); + +} // namespace +} // namespace torchaudio +#endif diff --git a/torchaudio/csrc/typedefs.cpp b/torchaudio/csrc/typedefs.cpp new file mode 100644 index 0000000000..7b81d665dc --- /dev/null +++ b/torchaudio/csrc/typedefs.cpp @@ -0,0 +1,23 @@ +#include + +namespace torchaudio { +SignalInfo::SignalInfo( + const int64_t sample_rate_, + const int64_t num_channels_, + const int64_t num_samples_) + : sample_rate(sample_rate_), + num_channels(num_channels_), + num_samples(num_samples_){}; + +int64_t SignalInfo::getSampleRate() const { + return sample_rate; +} + +int64_t SignalInfo::getNumChannels() const { + return num_channels; +} + +int64_t SignalInfo::getNumSamples() const { + return num_samples; +} +} // namespace torchaudio diff --git a/torchaudio/csrc/typedefs.h b/torchaudio/csrc/typedefs.h new file mode 100644 index 0000000000..646ed09f3d --- /dev/null +++ b/torchaudio/csrc/typedefs.h @@ -0,0 +1,23 @@ +#ifndef TORCHAUDIO_TYPDEFS_H +#define TORCHAUDIO_TYPDEFS_H + +#include + +namespace torchaudio { +struct SignalInfo : torch::CustomClassHolder { + int64_t sample_rate; + int64_t num_channels; + int64_t num_samples; + + SignalInfo( + const int64_t sample_rate_, + const int64_t num_channels_, + const int64_t num_samples_); + int64_t getSampleRate() const; + int64_t getNumChannels() const; + int64_t getNumSamples() const; +}; + +} // namespace torchaudio + +#endif diff --git a/torchaudio/extension/__init__.py b/torchaudio/extension/__init__.py new file mode 100644 index 0000000000..d9b6c76fac --- /dev/null +++ b/torchaudio/extension/__init__.py @@ -0,0 +1,7 @@ +from .extension import ( + _init_extension, +) + +_init_extension() + +del _init_extension diff --git a/torchaudio/extension/extension.py b/torchaudio/extension/extension.py new file mode 100644 index 0000000000..4a2ab82124 --- /dev/null +++ b/torchaudio/extension/extension.py @@ -0,0 +1,49 @@ +import warnings +import importlib +from collections import namedtuple + +import torch +from torchaudio._internal import module_utils as _mod_utils + + +def _init_extension(): + ext = 'torchaudio._torchaudio' + if _mod_utils.is_module_available(ext): + _init_script_module(ext) + else: + warnings.warn('torchaudio C++ extension is not available.') + _init_dummy_module() + + +def _init_script_module(module): + path = importlib.util.find_spec(module).origin + torch.classes.load_library(path) + torch.ops.load_library(path) + + +def _init_dummy_module(): + class SignalInfo: + """Data class for audio format information + + Used when torchaudio C++ extension is not available for annotating + sox_io backend functions so that torchaudio is still importable + without extension. + This class has to implement the same interface as C++ equivalent. + """ + def __init__(self, sample_rate: int, num_channels: int, num_samples: int): + self.sample_rate = sample_rate + self.num_channels = num_channels + self.num_samples = num_samples + + def get_sample_rate(self): + return self.sample_rate + + def get_num_channels(self): + return self.num_channels + + def get_num_samples(self): + return self.num_samples + + DummyModule = namedtuple('torchaudio', ['SignalInfo']) + module = DummyModule(SignalInfo) + setattr(torch.classes, 'torchaudio', module)