Skip to content

Commit

Permalink
guard for transducer
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentqb committed Jan 5, 2021
1 parent ab23f3c commit 64c8220
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .circleci/torchscript_bc_test/common.sh
Original file line number Diff line number Diff line change
Expand Up @@ -67,5 +67,5 @@ build_master() {
printf "* Installing torchaudio\n"
cd "${_root_dir}" || exit 1
git submodule update --init --recursive
BUILD_SOX=1 python setup.py clean install
BUILD_TRANSDUCER=1 BUILD_SOX=1 python setup.py clean install
}
2 changes: 1 addition & 1 deletion .circleci/unittest/linux/scripts/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ conda install -y -c "pytorch-${UPLOAD_CHANNEL}" pytorch ${cudatoolkit}

# 2. Install torchaudio
printf "* Installing torchaudio\n"
BUILD_SOX=1 python setup.py install
BUILD_TRANSDUCER=1 BUILD_SOX=1 python setup.py install

# 3. Install Test tools
printf "* Installing test tools\n"
Expand Down
17 changes: 11 additions & 6 deletions build_tools/setup_helpers/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,21 @@
_TP_INSTALL_DIR = _TP_BASE_DIR / 'install'


def _get_build_sox():
val = os.environ.get('BUILD_SOX', '0')
def _get_build(var):
val = os.environ.get(var, '0')
trues = ['1', 'true', 'TRUE', 'on', 'ON', 'yes', 'YES']
falses = ['0', 'false', 'FALSE', 'off', 'OFF', 'no', 'NO']
if val in trues:
return True
if val not in falses:
print(
f'WARNING: Unexpected environment variable value `BUILD_SOX={val}`. '
f'WARNING: Unexpected environment variable value `{var}={val}`. '
f'Expected one of {trues + falses}')
return False


_BUILD_SOX = _get_build_sox()
_BUILD_SOX = _get_build("BUILD_SOX")
_BUILD_TRANSDUCER = _get_build("BUILD_TRANSDUCER")


def _get_eca(debug):
Expand All @@ -42,6 +43,8 @@ def _get_eca(debug):
eca += ["-O0", "-g"]
else:
eca += ["-O3"]
if _BUILD_TRANSDUCER:
eca += ['-DBUILD_TRANSDUCER']
return eca


Expand All @@ -64,10 +67,11 @@ def _get_srcs():
def _get_include_dirs():
dirs = [
str(_ROOT_DIR),
str(_TP_BASE_DIR / 'transducer' / 'submodule' / 'include'),
]
if _BUILD_SOX:
dirs.append(str(_TP_INSTALL_DIR / 'include'))
if _BUILD_TRANSDUCER:
dirs.append(str(_TP_BASE_DIR / 'transducer' / 'submodule' / 'include'))
return dirs


Expand Down Expand Up @@ -95,7 +99,8 @@ def _get_extra_objects():
]
for lib in libs:
objs.append(str(_TP_INSTALL_DIR / 'lib' / lib))
objs.append(str(_TP_BASE_DIR / 'build' / 'transducer' / 'libwarprnnt.a'))
if _BUILD_TRANSDUCER:
objs.append(str(_TP_BASE_DIR / 'build' / 'transducer' / 'libwarprnnt.a'))
return objs


Expand Down
2 changes: 1 addition & 1 deletion packaging/build_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ if [[ "$OSTYPE" == "msys" ]]; then
python_tag="$(echo "cp$PYTHON_VERSION" | tr -d '.')"
python setup.py bdist_wheel --plat-name win_amd64 --python-tag $python_tag
else
BUILD_SOX=1 python setup.py bdist_wheel
BUILD_TRANSDUCER=1 BUILD_SOX=1 python setup.py bdist_wheel
fi
2 changes: 1 addition & 1 deletion packaging/torchaudio/build.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/usr/bin/env bash
set -ex

BUILD_SOX=1 python setup.py install --single-version-externally-managed --record=record.txt
BUILD_TRANSDUCER=1 BUILD_SOX=1 python setup.py install --single-version-externally-managed --record=record.txt
2 changes: 2 additions & 0 deletions torchaudio/csrc/register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ TORCH_LIBRARY(torchaudio, m) {
//////////////////////////////////////////////////////////////////////////////
// transducer.cpp
//////////////////////////////////////////////////////////////////////////////
#ifdef BUILD_TRANSDUCER
m.def("rnnt_loss(Tensor acts,"
"Tensor labels,"
"Tensor input_lengths,"
Expand All @@ -89,5 +90,6 @@ TORCH_LIBRARY(torchaudio, m) {
"Tensor grads,"
"int blank_label,"
"int num_threads) -> int");
#endif
}
#endif
4 changes: 4 additions & 0 deletions torchaudio/csrc/transducer.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#ifdef BUILD_TRANSDUCER

#include <iostream>
#include <numeric>
#include <string>
Expand Down Expand Up @@ -76,3 +78,5 @@ int64_t cpu_rnnt_loss(torch::Tensor acts,
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
m.impl("rnnt_loss", &cpu_rnnt_loss);
}

#endif

0 comments on commit 64c8220

Please sign in to comment.