-
Notifications
You must be signed in to change notification settings - Fork 664
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add SignalInfo typedef, and extension module (#718)
This is a part of PRs to add new "sox_io" backend. #726 This PR adds `SignalInfo` structure, which is data exchange interface between Python and C++ in coming TorchScript-based sox IO backend. For the case, where C++ extension is not available (i.e. Windows), this PR also adds dummy class and module that will be substituted. This logic is implemented in `torchaudio.extension` moduel.
- Loading branch information
Showing
6 changed files
with
121 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#ifndef TORCHAUDIO_REGISTER_H | ||
#define TORCHAUDIO_REGISTER_H | ||
|
||
#include <torchaudio/csrc/typedefs.h> | ||
|
||
namespace torchaudio { | ||
namespace { | ||
|
||
static auto registerSignalInfo = | ||
torch::class_<SignalInfo>("torchaudio", "SignalInfo") | ||
.def(torch::init<int64_t, int64_t, int64_t>()) | ||
.def("get_sample_rate", &SignalInfo::getSampleRate) | ||
.def("get_num_channels", &SignalInfo::getNumChannels) | ||
.def("get_num_samples", &SignalInfo::getNumSamples); | ||
|
||
} // namespace | ||
} // namespace torchaudio | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#include <torchaudio/csrc/typedefs.h> | ||
|
||
namespace torchaudio { | ||
SignalInfo::SignalInfo( | ||
const int64_t sample_rate_, | ||
const int64_t num_channels_, | ||
const int64_t num_samples_) | ||
: sample_rate(sample_rate_), | ||
num_channels(num_channels_), | ||
num_samples(num_samples_){}; | ||
|
||
int64_t SignalInfo::getSampleRate() const { | ||
return sample_rate; | ||
} | ||
|
||
int64_t SignalInfo::getNumChannels() const { | ||
return num_channels; | ||
} | ||
|
||
int64_t SignalInfo::getNumSamples() const { | ||
return num_samples; | ||
} | ||
} // namespace torchaudio |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#ifndef TORCHAUDIO_TYPDEFS_H | ||
#define TORCHAUDIO_TYPDEFS_H | ||
|
||
#include <torch/script.h> | ||
|
||
namespace torchaudio { | ||
struct SignalInfo : torch::CustomClassHolder { | ||
int64_t sample_rate; | ||
int64_t num_channels; | ||
int64_t num_samples; | ||
|
||
SignalInfo( | ||
const int64_t sample_rate_, | ||
const int64_t num_channels_, | ||
const int64_t num_samples_); | ||
int64_t getSampleRate() const; | ||
int64_t getNumChannels() const; | ||
int64_t getNumSamples() const; | ||
}; | ||
|
||
} // namespace torchaudio | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .extension import ( | ||
_init_extension, | ||
) | ||
|
||
_init_extension() | ||
|
||
del _init_extension |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import warnings | ||
import importlib | ||
from collections import namedtuple | ||
|
||
import torch | ||
from torchaudio._internal import module_utils as _mod_utils | ||
|
||
|
||
def _init_extension(): | ||
ext = 'torchaudio._torchaudio' | ||
if _mod_utils.is_module_available(ext): | ||
_init_script_module(ext) | ||
else: | ||
warnings.warn('torchaudio C++ extension is not available.') | ||
_init_dummy_module() | ||
|
||
|
||
def _init_script_module(module): | ||
path = importlib.util.find_spec(module).origin | ||
torch.classes.load_library(path) | ||
torch.ops.load_library(path) | ||
|
||
|
||
def _init_dummy_module(): | ||
class SignalInfo: | ||
"""Data class for audio format information | ||
Used when torchaudio C++ extension is not available for annotating | ||
sox_io backend functions so that torchaudio is still importable | ||
without extension. | ||
This class has to implement the same interface as C++ equivalent. | ||
""" | ||
def __init__(self, sample_rate: int, num_channels: int, num_samples: int): | ||
self.sample_rate = sample_rate | ||
self.num_channels = num_channels | ||
self.num_samples = num_samples | ||
|
||
def get_sample_rate(self): | ||
return self.sample_rate | ||
|
||
def get_num_channels(self): | ||
return self.num_channels | ||
|
||
def get_num_samples(self): | ||
return self.num_samples | ||
|
||
DummyModule = namedtuple('torchaudio', ['SignalInfo']) | ||
module = DummyModule(SignalInfo) | ||
setattr(torch.classes, 'torchaudio', module) |