From fddfbd1169d42ec01eb0f37700d974a5a0f7c408 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Tue, 5 Jan 2021 00:47:57 -0500 Subject: [PATCH] guard for transducer --- .circleci/torchscript_bc_test/common.sh | 2 +- .circleci/unittest/linux/scripts/install.sh | 2 +- build_tools/setup_helpers/extension.py | 17 +++++++++++------ packaging/build_wheel.sh | 2 +- packaging/torchaudio/build.sh | 2 +- torchaudio/csrc/register.cpp | 2 ++ torchaudio/csrc/transducer.cpp | 4 ++++ 7 files changed, 21 insertions(+), 10 deletions(-) diff --git a/.circleci/torchscript_bc_test/common.sh b/.circleci/torchscript_bc_test/common.sh index 9b1d0e1ef4..6f64e7204c 100644 --- a/.circleci/torchscript_bc_test/common.sh +++ b/.circleci/torchscript_bc_test/common.sh @@ -67,5 +67,5 @@ build_master() { printf "* Installing torchaudio\n" cd "${_root_dir}" || exit 1 git submodule update --init --recursive - BUILD_SOX=1 python setup.py clean install + BUILD_TRANSDUCER=1 BUILD_SOX=1 python setup.py clean install } diff --git a/.circleci/unittest/linux/scripts/install.sh b/.circleci/unittest/linux/scripts/install.sh index 27dec79251..e7543c762a 100755 --- a/.circleci/unittest/linux/scripts/install.sh +++ b/.circleci/unittest/linux/scripts/install.sh @@ -38,7 +38,7 @@ conda install -y -c "pytorch-${UPLOAD_CHANNEL}" pytorch ${cudatoolkit} # 2. Install torchaudio printf "* Installing torchaudio\n" -BUILD_SOX=1 python setup.py install +BUILD_TRANSDUCER=1 BUILD_SOX=1 python setup.py install # 3. Install Test tools printf "* Installing test tools\n" diff --git a/build_tools/setup_helpers/extension.py b/build_tools/setup_helpers/extension.py index 6a30093238..07ee05b520 100644 --- a/build_tools/setup_helpers/extension.py +++ b/build_tools/setup_helpers/extension.py @@ -20,20 +20,21 @@ _TP_INSTALL_DIR = _TP_BASE_DIR / 'install' -def _get_build_sox(): - val = os.environ.get('BUILD_SOX', '0') +def _get_build(var): + val = os.environ.get(var, '0') trues = ['1', 'true', 'TRUE', 'on', 'ON', 'yes', 'YES'] falses = ['0', 'false', 'FALSE', 'off', 'OFF', 'no', 'NO'] if val in trues: return True if val not in falses: print( - f'WARNING: Unexpected environment variable value `BUILD_SOX={val}`. ' + f'WARNING: Unexpected environment variable value `{var}={val}`. ' f'Expected one of {trues + falses}') return False -_BUILD_SOX = _get_build_sox() +_BUILD_SOX = _get_build("BUILD_SOX") +_BUILD_TRANSDUCER = _get_build("BUILD_TRANSDUCER") def _get_eca(debug): @@ -42,6 +43,8 @@ def _get_eca(debug): eca += ["-O0", "-g"] else: eca += ["-O3"] + if _BUILD_TRANSDUCER: + eca += ['-DBUILD_TRANSDUCER'] return eca @@ -64,10 +67,11 @@ def _get_srcs(): def _get_include_dirs(): dirs = [ str(_ROOT_DIR), - str(_TP_BASE_DIR / 'transducer' / 'submodule' / 'include'), ] if _BUILD_SOX: dirs.append(str(_TP_INSTALL_DIR / 'include')) + if _BUILD_TRANSDUCER: + dirs.append(str(_TP_BASE_DIR / 'transducer' / 'submodule' / 'include')) return dirs @@ -95,7 +99,8 @@ def _get_extra_objects(): ] for lib in libs: objs.append(str(_TP_INSTALL_DIR / 'lib' / lib)) - objs.append(str(_TP_BASE_DIR / 'build' / 'transducer' / 'libwarprnnt.a')) + if _BUILD_TRANSDUCER: + objs.append(str(_TP_BASE_DIR / 'build' / 'transducer' / 'libwarprnnt.a')) return objs diff --git a/packaging/build_wheel.sh b/packaging/build_wheel.sh index ad45a08d59..957b41e048 100755 --- a/packaging/build_wheel.sh +++ b/packaging/build_wheel.sh @@ -15,5 +15,5 @@ if [[ "$OSTYPE" == "msys" ]]; then python_tag="$(echo "cp$PYTHON_VERSION" | tr -d '.')" python setup.py bdist_wheel --plat-name win_amd64 --python-tag $python_tag else - BUILD_SOX=1 python setup.py bdist_wheel + BUILD_TRANSDUCER=1 BUILD_SOX=1 python setup.py bdist_wheel fi diff --git a/packaging/torchaudio/build.sh b/packaging/torchaudio/build.sh index 99c17b6913..88bbfce375 100644 --- a/packaging/torchaudio/build.sh +++ b/packaging/torchaudio/build.sh @@ -1,4 +1,4 @@ #!/usr/bin/env bash set -ex -BUILD_SOX=1 python setup.py install --single-version-externally-managed --record=record.txt +BUILD_TRANSDUCER=1 BUILD_SOX=1 python setup.py install --single-version-externally-managed --record=record.txt diff --git a/torchaudio/csrc/register.cpp b/torchaudio/csrc/register.cpp index 7ae91f1d7a..18ac86f388 100644 --- a/torchaudio/csrc/register.cpp +++ b/torchaudio/csrc/register.cpp @@ -81,6 +81,7 @@ TORCH_LIBRARY(torchaudio, m) { ////////////////////////////////////////////////////////////////////////////// // transducer.cpp ////////////////////////////////////////////////////////////////////////////// + #ifdef BUILD_TRANSDUCER m.def("rnnt_loss(Tensor acts," "Tensor labels," "Tensor input_lengths," @@ -89,5 +90,6 @@ TORCH_LIBRARY(torchaudio, m) { "Tensor grads," "int blank_label," "int num_threads) -> int"); + #endif } #endif diff --git a/torchaudio/csrc/transducer.cpp b/torchaudio/csrc/transducer.cpp index c7eb337fe5..2d2b7a8b51 100644 --- a/torchaudio/csrc/transducer.cpp +++ b/torchaudio/csrc/transducer.cpp @@ -1,3 +1,5 @@ +#ifdef BUILD_TRANSDUCER + #include #include #include @@ -76,3 +78,5 @@ int64_t cpu_rnnt_loss(torch::Tensor acts, TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { m.impl("rnnt_loss", &cpu_rnnt_loss); } + +#endif