Skip to content

Commit

Permalink
Add SignalInfo typedef, and extension module (#718)
Browse files Browse the repository at this point in the history
This is a part of PRs to add new "sox_io" backend. #726

This PR adds `SignalInfo` structure, which is data exchange interface between Python and C++ in coming TorchScript-based sox IO backend.
For the case, where C++ extension is not available (i.e. Windows), this PR also adds dummy class and module that will be substituted.
This logic is implemented in `torchaudio.extension` moduel.
  • Loading branch information
mthrok authored Jun 18, 2020
1 parent bc1df48 commit f8eac89
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 0 deletions.
1 change: 1 addition & 0 deletions torchaudio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import extension
from torchaudio._internal import module_utils as _mod_utils
from torchaudio import (
compliance,
Expand Down
18 changes: 18 additions & 0 deletions torchaudio/csrc/register.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef TORCHAUDIO_REGISTER_H
#define TORCHAUDIO_REGISTER_H

#include <torchaudio/csrc/typedefs.h>

namespace torchaudio {
namespace {

static auto registerSignalInfo =
torch::class_<SignalInfo>("torchaudio", "SignalInfo")
.def(torch::init<int64_t, int64_t, int64_t>())
.def("get_sample_rate", &SignalInfo::getSampleRate)
.def("get_num_channels", &SignalInfo::getNumChannels)
.def("get_num_samples", &SignalInfo::getNumSamples);

} // namespace
} // namespace torchaudio
#endif
23 changes: 23 additions & 0 deletions torchaudio/csrc/typedefs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include <torchaudio/csrc/typedefs.h>

namespace torchaudio {
SignalInfo::SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_samples_)
: sample_rate(sample_rate_),
num_channels(num_channels_),
num_samples(num_samples_){};

int64_t SignalInfo::getSampleRate() const {
return sample_rate;
}

int64_t SignalInfo::getNumChannels() const {
return num_channels;
}

int64_t SignalInfo::getNumSamples() const {
return num_samples;
}
} // namespace torchaudio
23 changes: 23 additions & 0 deletions torchaudio/csrc/typedefs.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef TORCHAUDIO_TYPDEFS_H
#define TORCHAUDIO_TYPDEFS_H

#include <torch/script.h>

namespace torchaudio {
struct SignalInfo : torch::CustomClassHolder {
int64_t sample_rate;
int64_t num_channels;
int64_t num_samples;

SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_samples_);
int64_t getSampleRate() const;
int64_t getNumChannels() const;
int64_t getNumSamples() const;
};

} // namespace torchaudio

#endif
7 changes: 7 additions & 0 deletions torchaudio/extension/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .extension import (
_init_extension,
)

_init_extension()

del _init_extension
49 changes: 49 additions & 0 deletions torchaudio/extension/extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import warnings
import importlib
from collections import namedtuple

import torch
from torchaudio._internal import module_utils as _mod_utils


def _init_extension():
ext = 'torchaudio._torchaudio'
if _mod_utils.is_module_available(ext):
_init_script_module(ext)
else:
warnings.warn('torchaudio C++ extension is not available.')
_init_dummy_module()


def _init_script_module(module):
path = importlib.util.find_spec(module).origin
torch.classes.load_library(path)
torch.ops.load_library(path)


def _init_dummy_module():
class SignalInfo:
"""Data class for audio format information
Used when torchaudio C++ extension is not available for annotating
sox_io backend functions so that torchaudio is still importable
without extension.
This class has to implement the same interface as C++ equivalent.
"""
def __init__(self, sample_rate: int, num_channels: int, num_samples: int):
self.sample_rate = sample_rate
self.num_channels = num_channels
self.num_samples = num_samples

def get_sample_rate(self):
return self.sample_rate

def get_num_channels(self):
return self.num_channels

def get_num_samples(self):
return self.num_samples

DummyModule = namedtuple('torchaudio', ['SignalInfo'])
module = DummyModule(SignalInfo)
setattr(torch.classes, 'torchaudio', module)

0 comments on commit f8eac89

Please sign in to comment.